
"""
deepdataspace.model.dataset

The dataset model.
"""

import importlib
import json
import logging
import os
import time
import uuid
from typing import Dict

from pymongo.collection import Collection
from pymongo.typings import _DocumentType

from deepdataspace.constants import AnnotationType
from deepdataspace.constants import DatasetStatus
from deepdataspace.constants import FileReadMode
from deepdataspace.constants import LabelType
from deepdataspace.constants import RedisKey
from deepdataspace.globals import Redis
from deepdataspace.model._base import BaseModel
from deepdataspace.model.category import Category
from deepdataspace.model.image import Image
from deepdataspace.model.image import ImageModel
from deepdataspace.model.label import Label
from deepdataspace.utils.file import create_file_url
from deepdataspace.utils.string import get_str_md5

logger = logging.getLogger("io.model.dataset")


def current_ts():
    return int(time.time())


class DataSet(BaseModel):
    """
    | DataSet is a collection of images.
    | This only saves metadata of the dataset, not the images.
    | Every dataset has a corresponding individual collection to save the images.

    Attributes:
    -----------
    name: str
        The dataset name.
    id: str
        The dataset id.
    path: str
        The dataset directory path.
    type: str
        The dataset type, see :class:`deepdataspace.constants.DatasetType`.
    status: str
        The current status of the dataset, with `DatasetStatus.Waiting` as the default.
        See :class:`deepdataspace.constants.DatasetStatus`.
    detail_status: dict
        The detailed status of every importer/processor.
        See :class:`deepdataspace.constants.DatasetStatus`.
    flag_export_link: str
        The dataset flag export link.
    object_types: list
        A list indicating what kinds of objects this dataset contains.
        See :class:`deepdataspace.constants.AnnotationType`.
    num_images: int
        The number of images in this dataset.
    files: dict
        A dictionary containing the relevant files of this dataset.
    cover_url: str
        The cover image url.
    description: str
        The dataset description.
    description_func: str
        An import path of a function to generate the description for this dataset.
    group_id: str
        The group id associated with this dataset.
    group_name: str
        The group name associated with this dataset.
    """
    @classmethod
    def get_collection(cls, *args, **kwargs) -> Collection[_DocumentType]:
        """
        Datasets are stored in the `datasets` collection.
        """

        return cls.db["datasets"]
    # the mandatory fields
    name: str  # the dataset name

    # the optional fields
    id: str = None  # the dataset id
    path: str = None  # the dataset directory path
    type: str = None  # the dataset type
    status: str = DatasetStatus.Waiting  # the current status of the dataset
    detail_status: dict = {}  # detailed status of every importer/processor
    flag_export_link: str = None  # the dataset flag export link
    object_types: list = []  # what kind of objects this dataset contains
    num_images: int = 0  # the number of images in this dataset
    files: dict = {}  # the relevant files of this dataset
    cover_url: str = None  # the cover image url
    description: str = None  # the dataset description
    description_func: str = None  # an import path of a function to generate the description
    group_id: str = None  # the group id associated with this dataset
    group_name: str = None  # the group name associated with this dataset

    _batch_queue: Dict[int, ImageModel] = {}  # the buffer queue for batch adding images
    _batch_size: int = 100  # the buffer queue is flushed to database every _batch_size images
    @classmethod
    def create_dataset(cls,
                       name: str,
                       id_: str = None,
                       type: str = None,
                       path: str = None,
                       files: dict = None,
                       description: str = None,
                       description_func: str = None,
                       ) -> "DataSet":
        """
        Create a dataset.
        Multiple datasets can have the same name.
        If you want to create a unique dataset, please specify a unique id value.

        :param name: the dataset name. Multiple datasets can have the same name.
        :param id_: the optional dataset id. If provided, a unique dataset will be created with the id value.
        :param type: the optional dataset type, can be "tsv" or "coco2017".
        :param path: the optional dataset directory path.
        :param files: the optional dataset relevant files. The key is the file info, the value is the file path.
        :param description: the optional dataset description.
        :param description_func: an import path of a function to generate the description.
            The function takes the dataset instance as its only argument and returns a string.
            If provided, it takes precedence over the description str.
        :return: the dataset object.
        """

        if id_:
            dataset = DataSet.find_one({"id": id_})
            if dataset is not None:
                dataset.type = type or dataset.type
                dataset.path = path or dataset.path
                dataset.files = files or dataset.files
                dataset.name = name
                dataset.save()
                return dataset
        else:
            id_ = uuid.uuid4().hex

        files = files or {}
        dataset = cls(name=name, id=id_, type=type, path=path, files=files,
                      status=DatasetStatus.Ready,
                      description=description, description_func=description_func)
        dataset.post_init()
        dataset.save()
        return dataset
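    # A usage sketch (hypothetical names and paths, assuming a configured MongoDB
    # connection behind BaseModel):
    #
    #   dataset = DataSet.create_dataset(name="coco_val2017",
    #                                    type="coco2017",
    #                                    path="/data/coco/val2017")
    #   # calling again with the same id_ updates the dataset instead of creating one
    #   same = DataSet.create_dataset(name="coco_val2017", id_=dataset.id)
    #   assert same.id == dataset.id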
    @classmethod
    def get_importing_dataset(cls,
                              name: str,
                              id_: str = None,
                              type: str = None,
                              path: str = None,
                              files: dict = None,
                              ) -> "DataSet":
        """
        This is the same as create_dataset,
        except that if the dataset is new, its status will be set to "waiting" instead of "ready".
        """

        if id_:
            dataset = DataSet.find_one({"id": id_})
            if dataset is not None:
                dataset.type = type or dataset.type
                dataset.path = path or dataset.path
                dataset.files = files or dataset.files
                dataset.name = name
                dataset.save()
                return dataset
        else:
            id_ = uuid.uuid4().hex

        files = files or {}
        dataset = cls(name=name, id=id_, type=type, path=path, files=files,
                      status=DatasetStatus.Waiting)
        dataset.post_init()
        dataset.save()
        return dataset
    def _add_cover(self, force_update: bool = False):
        has_cover = bool(self.cover_url)
        if has_cover and not force_update:
            return

        IModel = Image(self.id)
        images = list(IModel.find_many({}, sort=[("idx", 1)], size=1))
        if not images:
            return

        self.cover_url = images[0].url.strip()
        self.save()
    def add_image(self,
                  uri: str,
                  thumb_uri: str = None,
                  width: int = None,
                  height: int = None,
                  id: int = None,
                  metadata: dict = None,
                  flag: int = 0,
                  flag_ts: int = 0,
                  ) -> ImageModel:
        """
        Add an image to the dataset.
        Note that the same image will be added to the dataset multiple times
        if the same uri is provided without the same image id.

        :param uri: the image uri, can be a local file path starting with "file://" or a remote url starting with "http://".
        :param thumb_uri: the image thumbnail uri, which can also be a local file path or a remote url.
        :param width: the image width of full resolution.
        :param height: the image height of full resolution.
        :param id: the image id. If not provided, the image id will be the current number of images in the dataset.
        :param metadata: any data that needs to be stored along with the image.
        :param flag: the image flag, 0 for not flagged, 1 for positive, 2 for negative.
        :param flag_ts: the timestamp of the image flag.
        :return: the image object.
        """

        full_uri = uri
        thumb_uri = full_uri if thumb_uri is None else thumb_uri
        if full_uri.startswith("file://"):
            full_uri = create_file_url(full_uri[7:], read_mode=FileReadMode.Binary)
        if thumb_uri.startswith("file://"):
            thumb_uri = create_file_url(thumb_uri[7:], read_mode=FileReadMode.Binary)

        metadata = json.dumps(metadata) if metadata else "{}"

        image = None
        Model = Image(self.id)
        if id is not None:
            image = Model.find_one({"id": id})
            if image is not None:
                # keep only user-labeled objects on re-import
                user_objects = list(filter(lambda obj: obj.label_type == LabelType.User, image.objects))
                image.objects = user_objects

        if image is None:
            image_id = id or self.num_images
            image = Model(id=image_id, idx=self.num_images,
                          type=self.type, dataset_id=self.id,
                          url=thumb_uri, url_full_res=full_uri,
                          width=width, height=height,
                          flag=flag, flag_ts=flag_ts,
                          metadata=metadata)
        else:  # please don't change idx in this case
            image.url = thumb_uri or image.url
            image.url_full_res = uri or image.url_full_res
            image.width = width or image.width
            image.height = height or image.height
            image.flag = flag or image.flag
            image.flag_ts = flag_ts or image.flag_ts
            image.metadata = metadata or image.metadata

        image.post_init()
        image._dataset = self  # this saves a db query
        image.save()
        self.num_images = Model.count_num({})
        self._add_cover()

        # save whitelist to redis
        whitelist_dirs = set()
        self._add_local_file_url_to_whitelist(image.url, whitelist_dirs)
        self._add_local_file_url_to_whitelist(image.url_full_res, whitelist_dirs)
        if whitelist_dirs:
            Redis.sadd(RedisKey.DatasetImageDirs, *whitelist_dirs)
        return image
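    # A usage sketch (hypothetical local path; "file://" uris are rewritten into
    # local file-serving urls, other uris are stored as-is):
    #
    #   image = dataset.add_image(uri="file:///data/images/0001.jpg",
    #                             width=1920, height=1080,
    #                             metadata={"source": "camera_1"})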
    def batch_add_image(self,
                        uri: str,
                        thumb_uri: str = None,
                        width: int = None,
                        height: int = None,
                        id_: int = None,
                        metadata: dict = None,
                        flag: int = 0,
                        flag_ts: int = 0,
                        ) -> ImageModel:
        """
        The batch version of add_image, optimized for database performance.
        This method is not thread safe, please make sure only one thread is calling it.
        After the batch add is finished, finish_batch_add_image must be called to save
        the remaining changes to database (see the usage sketch after finish_batch_add_image).

        :param uri: the image uri, can be a local file path starting with "file://" or a remote url starting with "http://".
        :param thumb_uri: the image thumbnail uri, which can also be a local file path or a remote url.
        :param width: the image width of full resolution.
        :param height: the image height of full resolution.
        :param id_: the image id. If not provided, the image id will be the current number of images in the dataset.
        :param metadata: any data that needs to be stored along with the image.
        :param flag: the image flag, 0 for not flagged, 1 for positive, 2 for negative.
        :param flag_ts: the timestamp of the image flag.
        :return: the image object.
        """

        full_uri = uri
        thumb_uri = full_uri if thumb_uri is None else thumb_uri
        if full_uri.startswith("file://"):
            full_uri = create_file_url(full_uri[7:], read_mode=FileReadMode.Binary)
        if thumb_uri.startswith("file://"):
            thumb_uri = create_file_url(thumb_uri[7:], read_mode=FileReadMode.Binary)

        metadata = metadata or {}
        metadata = json.dumps(metadata)

        # if id is not set,
        # we use a negative value to indicate we are adding a new image instead of updating an existing one
        id_ = id_ if id_ is not None else -self.num_images
        idx = -1  # we decide the idx later

        Model = Image(self.id)
        image = Model(id=id_, idx=idx,
                      type=self.type, dataset_id=self.id,
                      url=thumb_uri, url_full_res=full_uri,
                      width=width, height=height,
                      flag=flag, flag_ts=flag_ts,
                      metadata=metadata,
                      )
        self._batch_queue[id_] = image
        self.num_images += 1
        image._dataset = self  # this saves a db query
        return image
    @staticmethod
    def _add_local_file_url_to_whitelist(url: str, whitelist: set):
        if not url or not url.startswith("/files/local_files"):
            return

        path = url.split("/")
        path = "/".join(path[7:])
        whitelist.add(os.path.dirname(path))

    def _batch_save_image_batch(self):
        """
        The internal function to flush the batch queue to database.
        """

        if not self._batch_queue:
            return

        waiting_labels = dict()
        waiting_categories = dict()
        object_types = set()

        IModel = Image(self.id)
        idx = IModel.count_num({})
        whitelist_dirs = set()
        for image_id, image in self._batch_queue.items():
            for obj in image.objects:
                # setup label
                label_id = waiting_labels.get(obj.label_name, None)
                if label_id is None:
                    label_id = get_str_md5(f"{self.id}_{obj.label_name}")
                    label = Label(name=obj.label_name, id=label_id,
                                  type=obj.label_type, dataset_id=self.id)
                    label.batch_save(batch_size=self._batch_size)
                    waiting_labels[obj.label_name] = label_id
                obj.label_id = label_id

                # setup category
                category_id = waiting_categories.get(obj.category_name, None)
                if category_id is None:
                    category_id = get_str_md5(f"{self.id}_{obj.category_name}")
                    category = Category(name=obj.category_name, id=category_id, dataset_id=self.id)
                    category.batch_save(batch_size=self._batch_size)
                    waiting_categories[obj.category_name] = category_id
                obj.category_id = category_id

                # setup object types
                if AnnotationType.Classification not in object_types:
                    object_types.add(AnnotationType.Classification)
                if obj.bounding_box and AnnotationType.Detection not in object_types:
                    object_types.add(AnnotationType.Detection)
                if obj.segmentation and AnnotationType.Segmentation not in object_types:
                    object_types.add(AnnotationType.Segmentation)
                if obj.alpha and AnnotationType.Matting not in object_types:
                    object_types.add(AnnotationType.Matting)
                self._add_local_file_url_to_whitelist(obj.alpha, whitelist_dirs)
                if obj.points and AnnotationType.KeyPoints not in object_types:
                    object_types.add(AnnotationType.KeyPoints)

            # setup image
            image.idx = idx
            image.id = idx if image.id < 0 else image.id
            image.batch_save(batch_size=self._batch_size, set_on_insert={"idx": image.idx})
            idx += 1

            self._add_local_file_url_to_whitelist(image.url, whitelist_dirs)
            self._add_local_file_url_to_whitelist(image.url_full_res, whitelist_dirs)

        # finish batch saves
        IModel.finish_batch_save()
        Label.finish_batch_save()
        Category.finish_batch_save()

        # setup dataset
        self.object_types = list(sorted(object_types))
        self.num_images = IModel.count_num({})
        self.save()

        # save whitelist to redis
        if whitelist_dirs:
            Redis.sadd(RedisKey.DatasetImageDirs, *whitelist_dirs)

        self._batch_queue.clear()

    def batch_save_image(self, enforce: bool = False):
        batch_is_full = len(self._batch_queue) >= self._batch_size
        if batch_is_full or enforce:
            self._batch_save_image_batch()
            return True
        return False
    def finish_batch_add_image(self):
        """
        This method should be called after all batch_add_image calls are finished.
        It saves all images left in the buffer queue to database.
        """

        self._batch_save_image_batch()
        self._add_cover()
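    # A batch import sketch: batch_add_image only buffers images in memory, and
    # finish_batch_add_image flushes whatever remains in the queue, so the two
    # must be used together from a single thread (`uris` is a hypothetical
    # iterable of image uris):
    #
    #   for uri in uris:
    #       dataset.batch_add_image(uri=uri)
    #   dataset.finish_batch_add_image()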
    def eval_description(self):
        """
        Evaluate the description function and return the generated description.
        """

        if self.description_func is not None:
            try:
                module, func = self.description_func.rsplit(".", 1)
                description_module = importlib.import_module(module)
                description_func = getattr(description_module, func)
                return description_func(self)
            except (ImportError, AttributeError):
                msg = f"Cannot import description_func[{self.description_func}] for dataset[{self.id}]"
                logger.warning(msg)
                return self.description or self.path
            except Exception:
                logger.warning(f"Failed to eval description_func[{self.description_func}] for dataset[{self.id}]")
                return self.description or self.path
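    # A sketch of a description_func target, assuming a hypothetical module
    # `my_pkg.descriptions` referenced as
    # description_func="my_pkg.descriptions.describe":
    #
    #   def describe(dataset) -> str:
    #       return f"{dataset.name}: {dataset.num_images} images under {dataset.path}"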
    @staticmethod
    def cascade_delete(dataset: "DataSet"):
        """
        Cascade delete the dataset, along with all its images, labels, categories and objects.
        """

        if dataset is None:
            return

        dataset_id = dataset.id
        print(f"dataset [{dataset_id}] is found, deleting...")

        print(f"dataset [{dataset_id}] is found, deleting categories...")
        Category.delete_many({"dataset_id": dataset_id})

        print(f"dataset [{dataset_id}] is found, deleting labels...")
        Label.delete_many({"dataset_id": dataset_id})

        print(f"dataset [{dataset_id}] is found, deleting images...")
        Image(dataset_id).get_collection().drop()

        DataSet.delete_many({"id": dataset_id})
        print(f"dataset [{dataset_id}] is deleted.")
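
# A deletion sketch (hypothetical dataset id; cascade_delete is a no-op for None):
#
#   dataset = DataSet.find_one({"id": "some-dataset-id"})
#   DataSet.cascade_delete(dataset)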