"""
deepdataspace.process.calculate_fnfp

This module defines the processor for calculating false negative and false positive analytics for a dataset.
"""

import logging
from typing import Dict
from typing import List

from deepdataspace.algos import calculate_fnfp
from deepdataspace.constants import AnnotationType
from deepdataspace.constants import LabelCompareResult
from deepdataspace.constants import LabelType
from deepdataspace.model.image import Image
from deepdataspace.model.label import Label
from deepdataspace.process.processor import BaseProcessor

# Module-level logger shared by everything in this file; it is registered under the
# fixed name "process.calculate_fnfp" rather than the module's import path.
logger = logging.getLogger("process.calculate_fnfp")

class FNFPCalculator(BaseProcessor):
    """
    This processor calculates false negative and false positive analytics for a dataset.

    The work is split in two passes over the dataset images:

    1. ``calculate_detection_thresh`` derives, for every prediction label set,
       the list of (precision, confidence threshold, recall) entries.
    2. ``calculate_detection_result`` uses those entries to count false negatives
       and false positives per image and to tag every compared object.
    """

    @classmethod
    def dependencies(cls) -> List[str]:
        """
        This processor depends on nothing.

        :return: an empty list.
        """

        return []

    @classmethod
    def should_auto_run(cls) -> bool:
        """
        This processor should run automatically on program start-up.

        :return: always True.
        """

        return True

    def can_process(self):
        """
        This processor can process any dataset.

        :return: always True.
        """

        return True

    @staticmethod
    def calculate_detection_thresh(dataset_id: str):
        """
        For each prediction label set of the dataset, calculate the confidence
        thresholds and store them on the label as ``compare_precisions``.

        :param dataset_id: the id of the dataset whose images are scanned.
        :return: a dict mapping every prediction label id to its list of
                 ``{"precision", "threshold", "recall"}`` dicts, sorted by
                 precision in descending order.
        """

        logger.info(f"calculate_detection_thresh for dataset[{dataset_id}] starts")

        # prepare data for calculation
        all_gt = []  # ground truths of the whole dataset: [image_idx, cat_id, bbox]
        all_det = {}  # label_id -> detections: [image_idx, cat_id, bbox, conf]
        all_cat = {}  # original category id -> dense integer index
        images = Image(dataset_id).find_many({})
        for image_idx, image in enumerate(images):
            image_width = image.width
            image_height = image.height

            for obj in image.objects:
                if not obj.bounding_box:
                    continue

                if obj.is_group is True:
                    continue

                # scale the stored box coordinates by the image size
                bbox = obj.bounding_box
                x1, y1, x2, y2 = bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"]
                x1, x2 = x1 * image_width, x2 * image_width
                y1, y2 = y1 * image_height, y2 * image_height
                bbox = [x1, y1, x2, y2]

                # the first time a category is seen it gets the next free index
                cat_id = obj.category_id
                cat_id = all_cat.setdefault(cat_id, len(all_cat))

                label_id = obj.label_id
                if obj.label_type == LabelType.GroundTruth:
                    gt = [image_idx, cat_id, bbox]
                    all_gt.append(gt)
                elif obj.label_type == LabelType.Prediction:
                    conf = obj.conf
                    det = [image_idx, cat_id, bbox, conf]
                    label_all_det = all_det.setdefault(label_id, [])
                    label_all_det.append(det)

        logger.info("calculate_detection_thresh prepare data done")

        # calculate thresholds for every label set
        all_thresholds = {}
        for label_id, label_all_det in all_det.items():
            thresholds = calculate_fnfp.calculate_thresholds(all_gt, label_all_det)
            thresholds.sort(key=lambda x: x["precision_thresh"], reverse=True)
            # entries whose precision is exactly 0.0 or 1.0 are dropped
            thresholds = [
                {
                    "precision": t["precision_thresh"],
                    "threshold": t["conf_thresh"],
                    "recall": t["recall"],
                }
                for t in thresholds
                if t["precision_thresh"] != 0.0 and t["precision_thresh"] != 1.0
            ]
            all_thresholds[label_id] = thresholds

            set_data = {"compare_precisions": thresholds}
            Label.update_one({"id": label_id}, set_data)
            logger.info(f"updated compare_precisions of label[{label_id}] to {thresholds}")

        return all_thresholds

    @staticmethod
    def calculate_detection_result(dataset_id: str, label_id: str, thresholds: List[Dict]):
        """
        For given label set, calculate fn/fp analytics for each image with given precision thresholds.

        :param dataset_id: the id of the dataset whose images are analysed.
        :param label_id: the prediction label set compared against the ground truth.
        :param thresholds: dicts carrying at least the keys ``"threshold"``
                           (confidence threshold) and ``"precision"``.
        :return: None, the results are written onto the images and their objects.
        """

        logger.info(f"calculate_detection_result starts, label_id={label_id}")

        category_id_map = {}  # original category id -> dense integer index
        images = Image(dataset_id).find_many({})
        for image in images:
            image_width = image.width
            image_height = image.height

            # prepare all ground truth and detection for this image
            all_gt = []
            all_det = []
            gt_io_map = []  # position in all_gt -> index in image.objects
            det_io_map = []  # position in all_det -> index in image.objects
            for idx, obj in enumerate(image.objects):
                if obj.bounding_box is None:  # skip object without bounding box
                    continue

                if obj.is_group is True:  # skip group object
                    continue

                bbox = obj.bounding_box
                x1, y1, x2, y2 = bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"]
                x1, x2 = x1 * image_width, x2 * image_width
                y1, y2 = y1 * image_height, y2 * image_height
                bbox = [x1, y1, x2, y2]

                cat_id = obj.category_id
                cat_id = category_id_map.setdefault(cat_id, len(category_id_map))

                if obj.label_type == LabelType.GroundTruth:
                    gt = [cat_id, *bbox]
                    all_gt.append(gt)
                    gt_io_map.append(idx)
                elif obj.label_type == LabelType.Prediction:
                    if obj.label_id != label_id:
                        continue
                    det = [cat_id, *bbox, obj.conf]
                    all_det.append(det)
                    det_io_map.append(idx)

            # calculate fn/fp
            gt_results, det_results = calculate_fnfp.calculate_fnfp(all_gt, all_det)

            # save the result in database for further query
            for threshold in thresholds:
                conf_thresh = threshold["threshold"]
                precision_thresh = str(int(threshold["precision"] * 100))  # keys cannot contain a dot (.)

                gt_valids = []
                det_valids = []
                det_idx_valids = {}
                for idx, det in enumerate(all_det):
                    # det[5] is the confidence appended after cat_id and the 4 box values
                    if det[5] >= conf_thresh:
                        det_valids.append(det_results[idx])
                        det_idx_valids[idx] = True

                # a ground truth counts as found only if its matched detection passed the threshold
                for gt_matched_det_idx in gt_results:
                    if det_idx_valids.get(gt_matched_det_idx, False) is True:
                        gt_valids.append(1)
                    else:
                        gt_valids.append(0)

                num_fn = len(gt_valids) - sum(gt_valids)
                num_fp = len(det_valids) - sum(det_valids)
                image.num_fn.setdefault(label_id, {})[precision_thresh] = num_fn
                image.num_fp.setdefault(label_id, {})[precision_thresh] = num_fp

                for idx, (gt_valid, gt_result) in enumerate(zip(gt_valids, gt_results)):
                    idx = gt_io_map[idx]
                    obj = image.objects[idx]
                    if gt_valid == 1:
                        compare_result = LabelCompareResult.OK
                    else:
                        compare_result = LabelCompareResult.FalseNegative
                        cat_id = obj.category_id
                        counter = image.num_fn_cat.setdefault(label_id, {}).setdefault(cat_id, {})
                        counter.setdefault(precision_thresh, 0)
                        counter[precision_thresh] += 1
                    obj.compare_result[precision_thresh] = compare_result
                    obj.matched_det_idx = int(gt_result)

                # NOTE(review): this walks every detection, including those below
                # conf_thresh, while num_fp above only counts detections that passed
                # the threshold -- confirm the difference is intended.
                for idx, det_result in enumerate(det_results):
                    idx = det_io_map[idx]
                    obj = image.objects[idx]
                    if det_result == 1:
                        compare_result = LabelCompareResult.OK
                    else:
                        compare_result = LabelCompareResult.FalsePositive
                        cat_id = obj.category_id
                        counter = image.num_fp_cat.setdefault(label_id, {}).setdefault(cat_id, {})
                        counter.setdefault(precision_thresh, 0)
                        counter[precision_thresh] += 1
                    obj.compare_result[precision_thresh] = compare_result
                    obj.matched_det_idx = None

            # NOTE(review): persisting the image here is reconstructed from the
            # "save the result in database" comment and the "updated ..." log lines
            # below; confirm save() is the right persistence call on the Image model.
            image.save()
            logger.debug(f"updated num_fp of image[{image.id}] of label[{label_id}] to {image.num_fp}")
            logger.debug(f"updated num_fn of image[{image.id}] of label[{label_id}] to {image.num_fn}")

    def process_dataset(self):
        """
        The major steps of calculate fnfp for the dataset.

        Reads the dataset from ``self.dataset``; returns early, doing nothing,
        when the dataset has no detection annotations.
        """

        dataset = self.dataset
        logger.info(f"process_dataset starts, dataset_id={dataset.id}, dataset_name={dataset.name}")

        # skip task if no detection is found
        obj_types = dataset.object_types
        if AnnotationType.Detection not in obj_types:
            msg = f"dataset[{dataset.id}].object_types={obj_types}, no {AnnotationType.Detection}, skip task..."
            logger.info(msg)
            return

        # calculate detection thresholds for every label set
        all_thresholds = self.calculate_detection_thresh(dataset.id)

        # for each label set, calculate the per-image fn/fp results
        for label_id, label_thresholds in all_thresholds.items():
            self.calculate_detection_result(dataset.id, label_id, label_thresholds)