Source code for deepdataspace.process.calculate_fnfp

"""
deepdataspace.process.calculate_fnfp

This module defines the processor that calculates false negative and false positive analytics for a dataset.
"""

import logging
from typing import Dict
from typing import List

from deepdataspace.algos import calculate_fnfp
from deepdataspace.constants import AnnotationType
from deepdataspace.constants import LabelCompareResult
from deepdataspace.constants import LabelType
from deepdataspace.model.image import Image
from deepdataspace.model.label import Label
from deepdataspace.process.processor import BaseProcessor

logger = logging.getLogger("process.calculate_fnfp")


class FNFPCalculator(BaseProcessor):
    """
    This processor calculates false negative and false positive analytics for a dataset.

    The work is done in two passes, driven by :meth:`process_dataset`:

    1. :meth:`calculate_detection_thresh` collects every ground truth and
       prediction box of the dataset, asks ``calculate_fnfp.calculate_thresholds``
       for the thresholds of each prediction label set and stores them on the
       label as ``compare_precisions``.
    2. :meth:`calculate_detection_result` applies those thresholds image by image
       and stores the fn/fp counters and the per-object compare results.
    """

    @classmethod
    def dependencies(cls) -> List[str]:
        """
        This processor depends on nothing.
        """
        return []

    @classmethod
    def should_auto_run(cls) -> bool:
        """
        This processor should run automatically on program start-up.
        """
        return True

    def can_process(self):
        """
        This processor can process any dataset.
        """
        return True

    @staticmethod
    def _scaled_bbox(obj, image_width, image_height) -> List:
        """
        Return ``[xmin, ymin, xmax, ymax]`` of ``obj.bounding_box`` with the x
        values multiplied by ``image_width`` and the y values by ``image_height``.

        NOTE(review): presumably the stored box is normalized to [0, 1] and this
        converts it to pixels -- confirm against the importer.
        """
        bbox = obj.bounding_box
        return [
            bbox["xmin"] * image_width,
            bbox["ymin"] * image_height,
            bbox["xmax"] * image_width,
            bbox["ymax"] * image_height,
        ]

    @staticmethod
    def calculate_detection_thresh(dataset_id: str):
        """
        For each prediction label set of the dataset, calculate the thresholds
        and save them to the label as ``compare_precisions``.

        Objects without a bounding box and group objects are ignored.

        :param dataset_id: id of the dataset whose images are scanned.
        :return: a dict mapping ``label_id`` to a list of
            ``{"precision", "threshold", "recall"}`` dicts, sorted by precision in
            descending order. Entries whose precision is exactly 0.0 or 1.0 are dropped.
        """

        logger.info(f"calculate_detection_thresh for dataset[{dataset_id}] starts")

        # prepare data for calculation
        all_gt = []  # [image_idx, cat_idx, bbox] of every ground truth box
        all_det = {}  # label_id -> [[image_idx, cat_idx, bbox, conf], ...]
        all_cat = {}  # category_id -> dense index, in order of first appearance
        images = Image(dataset_id).find_many({})
        for image_idx, image in enumerate(images):
            image_width = image.width
            image_height = image.height

            for obj in image.objects:
                if not obj.bounding_box:
                    continue
                if obj.is_group is True:
                    continue

                bbox = FNFPCalculator._scaled_bbox(obj, image_width, image_height)

                # the index is assigned before the label type is checked, so every
                # object with a box registers its category
                cat_id = all_cat.setdefault(obj.category_id, len(all_cat))

                if obj.label_type == LabelType.GroundTruth:
                    all_gt.append([image_idx, cat_id, bbox])
                elif obj.label_type == LabelType.Prediction:
                    det = [image_idx, cat_id, bbox, obj.conf]
                    all_det.setdefault(obj.label_id, []).append(det)
        logger.info("calculate_detection_thresh prepare data done")

        # calculate thresholds for every label set
        all_thresholds = {}
        for label_id, label_all_det in all_det.items():
            thresholds = calculate_fnfp.calculate_thresholds(all_gt, label_all_det)
            thresholds.sort(key=lambda x: x["precision_thresh"], reverse=True)
            thresholds = [
                {
                    "precision": t["precision_thresh"],
                    "threshold": t["conf_thresh"],
                    "recall": t["recall"],
                }
                for t in thresholds
                if t["precision_thresh"] != 0.0 and t["precision_thresh"] != 1.0
            ]

            all_thresholds[label_id] = thresholds
            set_data = {"compare_precisions": thresholds}
            Label.update_one({"id": label_id}, set_data)
            logger.info(f"updated compare_precisions of label[{label_id}] to {thresholds}")

        return all_thresholds

    @staticmethod
    def calculate_detection_result(dataset_id: str, label_id: str, thresholds: List[Dict]):
        """
        For given label set, calculate fn/fp analytics for each image with given precision thresholds.

        For every image the ground truth boxes and the prediction boxes of
        ``label_id`` are handed to ``calculate_fnfp.calculate_fnfp``. Then, for each
        threshold, the image counters ``num_fn``, ``num_fp``, ``num_fn_cat`` and
        ``num_fp_cat`` and the ``compare_result`` / ``matched_det_idx`` of the
        involved objects are updated and the image is saved.

        :param dataset_id: id of the dataset whose images are processed.
        :param label_id: id of the prediction label set to compare with ground truth.
        :param thresholds: dicts with at least the keys ``"threshold"`` (confidence)
            and ``"precision"``, as produced by :meth:`calculate_detection_thresh`.
        """

        logger.info(f"calculate_detection_result starts, label_id={label_id}")

        category_id_map = {}  # category_id -> dense index, shared across images
        images = Image(dataset_id).find_many({})
        for image in images:
            image_width = image.width
            image_height = image.height

            # prepare all ground truth and detection for this image
            all_gt = []
            all_det = []
            gt_io_map = []  # position in all_gt -> index in image.objects
            det_io_map = []  # position in all_det -> index in image.objects
            for idx, obj in enumerate(image.objects):
                # skip object without bounding box; a falsy check (same as in
                # calculate_detection_thresh) also skips an empty box instead of
                # failing on the missing coordinate keys
                if not obj.bounding_box:
                    continue
                if obj.is_group is True:  # skip group object
                    continue

                bbox = FNFPCalculator._scaled_bbox(obj, image_width, image_height)
                cat_id = category_id_map.setdefault(obj.category_id, len(category_id_map))

                if obj.label_type == LabelType.GroundTruth:
                    all_gt.append([cat_id, *bbox])
                    gt_io_map.append(idx)
                elif obj.label_type == LabelType.Prediction:
                    if obj.label_id != label_id:
                        continue
                    # layout: [cat_idx, x1, y1, x2, y2, conf] -> conf is det[5]
                    all_det.append([cat_id, *bbox, obj.conf])
                    det_io_map.append(idx)

            # calculate fn/fp
            gt_results, det_results = calculate_fnfp.calculate_fnfp(all_gt, all_det)

            # save the result in database for further query
            for threshold in thresholds:
                conf_thresh = threshold["threshold"]
                # the key must not contain a dot (.), so the precision is stored
                # as an integer percentage string
                # NOTE(review): int() truncates, e.g. 0.29 * 100 == 28.999999999999996
                # gives "28". Left unchanged because other readers of these keys
                # presumably derive them the same way -- confirm before using round().
                precision_thresh = str(int(threshold["precision"] * 100))

                # The per-category counters are incremented rather than assigned,
                # so drop whatever is already stored under this key. Otherwise a
                # second run (or two thresholds truncating to the same key) would
                # add to the old totals while num_fn/num_fp are simply overwritten.
                for stale in image.num_fn_cat.get(label_id, {}).values():
                    stale.pop(precision_thresh, None)
                for stale in image.num_fp_cat.get(label_id, {}).values():
                    stale.pop(precision_thresh, None)

                # detections that reach the confidence threshold
                det_valids = []
                valid_det_idx = set()
                for det_idx, det in enumerate(all_det):
                    if det[5] >= conf_thresh:
                        det_valids.append(det_results[det_idx])
                        valid_det_idx.add(det_idx)

                # a ground truth counts as found only if its matched detection
                # is one of the valid ones
                gt_valids = [1 if matched in valid_det_idx else 0 for matched in gt_results]

                num_fn = len(gt_valids) - sum(gt_valids)
                num_fp = len(det_valids) - sum(det_valids)
                image.num_fn.setdefault(label_id, {})[precision_thresh] = num_fn
                image.num_fp.setdefault(label_id, {})[precision_thresh] = num_fp

                for gt_idx, (gt_valid, gt_result) in enumerate(zip(gt_valids, gt_results)):
                    obj = image.objects[gt_io_map[gt_idx]]
                    if gt_valid == 1:
                        compare_result = LabelCompareResult.OK
                    else:
                        compare_result = LabelCompareResult.FalseNegative
                        counter = image.num_fn_cat.setdefault(label_id, {}).setdefault(obj.category_id, {})
                        counter[precision_thresh] = counter.get(precision_thresh, 0) + 1
                    obj.compare_result[precision_thresh] = compare_result
                    obj.matched_det_idx = int(gt_result)

                # every detection of the label set is visited here, not only the
                # ones that reach conf_thresh
                for det_idx, det_result in enumerate(det_results):
                    obj = image.objects[det_io_map[det_idx]]
                    if det_result == 1:
                        compare_result = LabelCompareResult.OK
                    else:
                        compare_result = LabelCompareResult.FalsePositive
                        counter = image.num_fp_cat.setdefault(label_id, {}).setdefault(obj.category_id, {})
                        counter[precision_thresh] = counter.get(precision_thresh, 0) + 1
                    obj.compare_result[precision_thresh] = compare_result
                    obj.matched_det_idx = None

            image.save()
            logger.debug(f"updated num_fp of image[{image.id}] of label[{label_id}] to {image.num_fp}")
            logger.debug(f"updated num_fn of image[{image.id}] of label[{label_id}] to {image.num_fn}")

    def process_dataset(self):
        """
        The major steps of calculate fnfp for the dataset.

        Returns early, doing nothing, when the dataset has no detection annotations.
        """

        dataset = self.dataset
        logger.info(f"process_dataset starts, dataset_id={dataset.id}, dataset_name={dataset.name}")

        # skip task if no detection is found
        obj_types = dataset.object_types
        if AnnotationType.Detection not in obj_types:
            msg = f"dataset[{dataset.name}].object_types={obj_types}, no {AnnotationType.Detection}, skip task..."
            logger.info(msg)
            return

        # calculate detection thresholds for every label set
        all_thresholds = self.calculate_detection_thresh(dataset.id)

        # for each label set, calculate the fn/fp result of every image
        for label_id, label_thresholds in all_thresholds.items():
            self.calculate_detection_result(dataset.id, label_id, label_thresholds)