# Source code for deepdataspace.plugins.tsv.process

"""
deepdataspace.plugins.tsv.process

Implement all processors for TSV datasets.
"""

import os
import json
import logging
from typing import List

import numpy as np

from deepdataspace.model.image import Image
from deepdataspace.constants import DatasetType
from deepdataspace.algos.refine_by_seed import refine
from deepdataspace.process.processor import BaseProcessor

# Module logger; the explicit dotted name (rather than __name__) is kept so
# existing logging configuration that targets "plugins.tsv.process" still applies.
logger = logging.getLogger(name="plugins.tsv.process")


class RankByFlags(BaseProcessor):
    """Re-rank the images of a TSV dataset using flagged seed images.

    Images whose ``flag`` is 1 are passed to :func:`refine` as positive seeds
    and images whose ``flag`` is 2 as negative seeds, together with the
    embeddings read from the dataset's ``"Embedding"`` file. The order returned
    by :func:`refine` is then written back into each image's ``idx`` field.
    """

    @classmethod
    def dependencies(cls) -> List[str]:
        """Return the names of processors this one depends on (none)."""
        return []

    @classmethod
    def should_auto_run(cls) -> bool:
        """Return ``False``, i.e. this processor does not request an automatic run."""
        return False

    def can_process(self):
        """Return ``True`` only when the bound dataset is of type ``DatasetType.TSV``."""
        return self.dataset.type == DatasetType.TSV

    @staticmethod
    def _load_embeddings(path):
        """Parse *path* as one JSON value per line and stack the values with ``np.array``."""
        with open(path, "r", encoding="utf-8") as fp:
            return np.array([json.loads(line) for line in fp])

    def process_dataset(self):
        """Re-rank the dataset's images from their flags and embeddings.

        Skips (with a warning) when the dataset has no embedding file on disk,
        or when that file yields no embedding values.
        """
        logger.info("process_dataset starts, dataset_id=%s, subset_name=%s",
                    self.dataset_id, self.dataset.name)

        dataset = self.dataset
        embd_file = dataset.files.get("Embedding", None)
        if embd_file is None or not os.path.exists(embd_file):
            logger.warning("dataset[%s] does not have an embedding file, skip task", dataset.name)
            return

        # Resolve the image model once and reuse it, so every queued batch
        # update and the final flush are issued through the same object.
        image_model = Image(dataset.id)

        # Load seeds: flag 1 -> positive seeds, flag 2 -> negative seeds.
        pos_ids = [img.id for img in image_model.find_many({"flag": 1})]
        neg_ids = [img.id for img in image_model.find_many({"flag": 2})]

        embeddings = self._load_embeddings(embd_file)
        if embeddings.size == 0:
            # An empty file gives nothing to rank against; skip like a missing file.
            logger.warning("dataset[%s] has an empty embedding file, skip task", dataset.name)
            return

        # Rerank by seeds and embeddings.
        sorted_ids = refine(pos_ids, neg_ids, embeddings)

        # Write the new rank into each image's idx field. The ids are cast to
        # int because refine presumably yields numpy integers (TODO confirm).
        for new_idx, image_id in enumerate(sorted_ids):
            image_model.batch_update({"id": int(image_id)}, {"idx": int(new_idx)})
        image_model.finish_batch_update()