Source code for deepdataspace.plugins.tsv.server

"""
deepdataspace.plugins.tsv.server

Add RESTful APIs for TSV datasets.
"""

import json
import logging
import os
import time
from dataclasses import asdict
from dataclasses import dataclass

from rest_framework.views import APIView

from deepdataspace.constants import ErrCode
from deepdataspace.constants import DatasetStatus
from deepdataspace.constants import TaskName
from deepdataspace.constants import TaskStatus
from deepdataspace.globals import Redis
from deepdataspace.model import DataSet
from deepdataspace.plugins.tsv.process import RankByFlags
from deepdataspace.utils.http import Argument
from deepdataspace.utils.http import format_response
from deepdataspace.utils.http import parse_arguments
from deepdataspace.utils.http import raise_exception

logger = logging.getLogger("plugins.tsv.server")


@dataclass
class Task:
    """
    Record describing one rerank_by_flags task.

    Instances are converted with ``dataclasses.asdict`` and stored as JSON,
    so the field names below are also the keys of the serialized record.
    """

    # Task identifier; the view in this module uses the celery task id with
    # the dashes removed.
    id: str
    # Task name; the view in this module passes TaskName.ReRankByFlags.
    name: str
    # Task status; the view in this module starts it at TaskStatus.Waiting.
    status: str
    # Time the task was triggered; the view passes int(time.time() * 1000).
    trigger_at: int
    # Start time; the view passes 0 at creation.
    # NOTE(review): presumably updated by the worker later -- confirm.
    start_at: int
    # Finish time; the view passes 0 at creation.
    # NOTE(review): presumably updated by the worker later -- confirm.
    finish_at: int
class ReRankImagesByFlagsTasksView(APIView):
    """
    API view whose POST handler starts a rerank_by_flags task for a dataset.
    """

    # Arguments accepted by ``post``: a required string ``dataset_id``.
    # NOTE(review): the "json" location presumably means the JSON request
    # body -- confirm against deepdataspace.utils.http.Argument.
    post_args = [Argument("dataset_id", str, "json", required=True)]

    def post(self, request):
        """
        Trigger a task to rerank images by dataset flags.

        Steps, in order:

        1. Parse ``dataset_id`` from the request and look the dataset up.
        2. Report ``ErrCode.DatasetNotFound`` if there is no such dataset,
           ``ErrCode.DatasetNotReadable`` if its status is one of
           ``DatasetStatus.BatchProcessing_``, and
           ``ErrCode.DatasetMissingEmbdFile`` if the "Embedding" file entry
           is absent or does not exist on disk.
        3. Dispatch ``RankByFlags`` asynchronously for the dataset path.
        4. Store the new task record as JSON in Redis under
           ``task:<task_id>`` and return the same record in the response.

        The code relies on ``raise_exception`` not returning normally.
        """
        (dataset_id,) = parse_arguments(request, self.post_args)

        dataset = DataSet.find_one({"id": dataset_id})
        if dataset is None:
            raise_exception(
                ErrCode.DatasetNotFound,
                f"dataset_id[{dataset_id}] not found",
            )

        status = dataset.status
        if status in DatasetStatus.BatchProcessing_:
            raise_exception(
                ErrCode.DatasetNotReadable,
                f"dataset_id[{dataset_id}] is in status[{status}] now, try again later",
            )

        embedding_path = dataset.files.get("Embedding", None)
        has_embedding = embedding_path is not None and os.path.exists(embedding_path)
        if not has_embedding:
            raise_exception(
                ErrCode.DatasetMissingEmbdFile,
                f"dataset_id[{dataset_id}] dose not have an embedding file",
            )

        async_result = RankByFlags.run_async(dataset.path, enforce=True)

        # The stored task id is the celery id with the dashes stripped.
        task_id = async_result.id.replace("-", "")
        # Trigger time is recorded in milliseconds since the epoch.
        triggered_ms = int(time.time() * 1000)
        record = Task(
            task_id,
            TaskName.ReRankByFlags,
            TaskStatus.Waiting,
            triggered_ms,
            0,
            0,
        )

        payload = asdict(record)
        Redis.set(f"task:{task_id}", json.dumps(payload))
        return format_response(payload)
class ReRankImagesByFlagsTaskView(APIView):
    """
    API view whose GET handler returns the stored record of a
    rerank_by_flags task.
    """

    def get(self, request, task_id: str):
        """
        Query result of rerank_by_flags task.

        Reads the JSON record stored in Redis under ``task:<task_id>`` and
        returns its decoded content in the response. When no value is
        stored for that key, ``ErrCode.ReRankByFlagTaskNotFound`` is
        reported through ``raise_exception``.
        """
        stored = Redis.get(f"task:{task_id}")
        if not stored:
            raise_exception(
                ErrCode.ReRankByFlagTaskNotFound,
                f"task_id[{task_id}] not found",
            )

        return format_response(json.loads(stored))