"""
deepdataspace.algos.calculate_fnfp
Compare predictions to ground truths and find the FPs and FNs.
"""
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
def calculate_iou(all_gt: np.ndarray, all_det: np.ndarray) -> np.ndarray:
    """
    For every ground truth, calculate its IoU to every prediction.

    :param all_gt: (np.ndarray) Shape (G, 4), 4 present [x1, y1, x2, y2]
    :param all_det: (np.ndarray) Shape (D, 4), 4 present [x1, y1, x2, y2]
    :return iou: (np.ndarray) Shape (G, D). Pairs whose union area is not
        positive (e.g. two zero-area boxes) get an IoU of 0 instead of NaN/inf.
    """
    all_gt = np.asarray(all_gt)
    all_det = np.asarray(all_det)
    # An empty list converts to shape (0,); give it the expected two dimensions
    # so the broadcasting below yields a (G, 0) or (0, D) result.
    if all_gt.size == 0:
        all_gt = all_gt.reshape(0, 4)
    if all_det.size == 0:
        all_det = all_det.reshape(0, 4)

    # (G, 1, 4) against (D, 4) broadcasts every coordinate to (G, D).
    gt = all_gt[:, np.newaxis, :]

    xmin = np.maximum(gt[:, :, 0], all_det[:, 0])
    ymin = np.maximum(gt[:, :, 1], all_det[:, 1])
    xmax = np.minimum(gt[:, :, 2], all_det[:, 2])
    ymax = np.minimum(gt[:, :, 3], all_det[:, 3])

    # Clamp at 0 so non-overlapping boxes contribute no intersection.
    intersection = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    gt_area = (gt[:, :, 2] - gt[:, :, 0]) * (gt[:, :, 3] - gt[:, :, 1])
    det_area = (all_det[:, 2] - all_det[:, 0]) * (all_det[:, 3] - all_det[:, 1])
    union = gt_area + det_area - intersection

    # Divide as before, but silence the 0/0 warning and overwrite the
    # undefined entries with 0 so callers never see NaN/inf.
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = intersection / union
    iou[~(union > 0)] = 0
    return iou
def calculate_thresholds(all_gt: List[List],
                         all_det: List[List],
                         iou_thresh: float = 0.5) -> List[Dict[str, float]]:
    """
    For given IoU thresh, calculate confidence thresh for precisions from 0.0 to 1.0 .

    Detections are visited in descending confidence order and greedily matched
    to the not-yet-matched ground truth of the same image and category with
    the highest IoU. The resulting precision curve is made non-increasing
    (each point takes the max of itself and every lower-confidence point) and
    bucketed in steps of 0.1.

    :param all_gt: All ground truth objects from a subset.

        .. code-block:: python

            [
                # [image_id, category_id, [x1, y1, x2, y2]]
                [1, 1, [10, 20, 30, 40]],
            ]

    :param all_det: All prediction objects from a subset.

        .. code-block:: python

            [
                # [image_id, category_id, [x1, y1, x2, y2], conf]
                [1, 1, [10, 20, 33, 43], 0.8],
            ]

    :param iou_thresh: Minimum IoU for a detection to count as correct.
    :return conf thresholds: list of 11 dicts, one per precision bucket
        0.0, 0.1, ..., 1.0. A bucket is filled from the lowest-confidence
        detection whose precision falls into it; buckets that no detection
        falls into keep -1 in ``conf_thresh``, ``recall`` and ``precision``.

        .. code-block:: python

            [
                {
                    "conf_thresh": 0.8,
                    "recall": 0.5,
                    "precision": 0.75,
                    "precision_thresh": 0.7
                }
            ]
    """
    # Highest confidence first; the sort is stable, so ties keep input order.
    all_det = sorted(all_det, key=lambda det: -det[-1])

    # Group ground truths per image:
    # {"$img_id": {"category": np.ndarray, "bbox": np.ndarray of [x1, y1, x2, y2]}}
    imgid2gt = {}
    for gt in all_gt:
        img = imgid2gt.setdefault(gt[0], {"category": [], "bbox": []})
        img["category"].append(gt[1])
        img["bbox"].append(gt[2])
    for img in imgid2gt.values():
        img["category"] = np.array(img["category"])
        img["bbox"] = np.array(img["bbox"])

    # Group detection boxes per image. Because all_det is already sorted, the
    # boxes of one image are stored in the order the main loop visits them.
    imgid2det_bbox = {}
    for det in all_det:
        imgid2det_bbox.setdefault(det[0], []).append(det[2])

    # One (G, D) IoU matrix per image that has ground truths.
    imgid2iou = {}
    for imgid, gt_of_img in imgid2gt.items():
        if imgid in imgid2det_bbox:
            bbox_det = np.array(imgid2det_bbox[imgid])
        else:
            bbox_det = np.zeros((0, 4), dtype=np.float32)
        imgid2iou[imgid] = calculate_iou(gt_of_img["bbox"], bbox_det)

    # Column of the IoU matrix belonging to the next detection of each image.
    imgid2idx = {imgid: 0 for imgid in imgid2det_bbox}

    # correct[i] is 1 when the i-th detection (by confidence) is a TP, else 0.
    correct = []
    for det in all_det:
        imgid = det[0]
        if imgid not in imgid2iou:
            # The image has no ground truth at all, so the detection is an FP.
            correct.append(0)
            continue

        idx = imgid2idx[imgid]
        imgid2idx[imgid] = idx + 1

        iou = imgid2iou[imgid]
        gt_idx_of_cat = np.where(imgid2gt[imgid]["category"] == det[1])[0]
        if gt_idx_of_cat.shape[0] == 0:
            # No ground truth of this category in the image.
            correct.append(0)
            continue

        iou_of_det = iou[gt_idx_of_cat, idx]
        best = iou_of_det.argmax()
        if iou_of_det[best] >= iou_thresh:
            correct.append(1)
            # Mark the matched ground truth as used for later detections.
            iou[gt_idx_of_cat[best], :] = -1
        else:
            correct.append(0)

    num_gt = sum(img["bbox"].shape[0] for img in imgid2gt.values())
    precisions = []
    recalls = []
    num_correct = 0
    for num_det, is_correct in enumerate(correct, start=1):
        num_correct += is_correct
        precisions.append(num_correct / num_det)
        # Without any ground truth recall is undefined; report 0.0 rather than
        # raising ZeroDivisionError.
        recalls.append(num_correct / num_gt if num_gt else 0.0)

    # Make precision non-increasing along the confidence ranking.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])

    results = [
        {"conf_thresh": -1, "recall": -1, "precision": -1, "precision_thresh": round(i * 0.1, 1)}
        for i in range(11)  # 0.0 ~ 1.0
    ]

    # Walk from the lowest confidence upwards so that each bucket keeps the
    # lowest confidence reaching it.
    for precision, recall, det in zip(reversed(precisions), reversed(recalls), reversed(all_det)):
        # precision / 0.1 suffers from float error (0.6 / 0.1 == 5.999...), which
        # would drop exact multiples of 0.1 into the bucket below; multiply and
        # add a tiny epsilon before truncating instead.
        bucket = min(int(precision * 10 + 1e-9), 10)
        entry = results[bucket]
        if entry["conf_thresh"] == -1:
            entry["conf_thresh"] = det[3]
            entry["precision"] = precision
            entry["recall"] = recall

    return results
def calculate_fnfp(all_gt: List[List],
                   all_det: List[List],
                   iou_thresh: float = 0.5) -> Tuple[List[int], List[int]]:
    """
    For given IoU thresh, check the correctness of all predictions in an image.

    Within each category, detections are visited in descending confidence
    order and each one is matched to the still-unmatched ground truth with the
    highest IoU, provided that IoU reaches ``iou_thresh``.

    :param all_gt:
        | All ground truth objects from a subset
        | Each row is flat: [category_id, x1, y1, x2, y2]
    :param all_det:
        | All prediction objects from a subset
        | Each row is flat: [category_id, x1, y1, x2, y2, conf]
    :param iou_thresh: IoU threshold
    :return tuple of list of int:
        | (gt_results, det_results)
        | gt_results, list, -1 means FN, otherwise the index in all_det of the matched det
        | det_results, list, 1 means TP, 0 means FP
    """
    gt_arr = np.array(all_gt, dtype=np.float32)
    det_arr = np.array(all_det, dtype=np.float32)

    gt_results = [-1] * gt_arr.shape[0]
    det_results = [0] * det_arr.shape[0]
    if gt_arr.shape[0] == 0 or det_arr.shape[0] == 0:
        # Nothing to match: every gt is an FN and every det is an FP.
        return gt_results, det_results

    # Only categories that have detections can produce matches.
    for category_id in set(det_arr[:, 0].astype(np.int32).tolist()):
        gt_idx_of_cat = np.where(gt_arr[:, 0] == category_id)[0]
        if gt_idx_of_cat.shape[0] == 0:
            continue
        det_idx_of_cat = np.where(det_arr[:, 0] == category_id)[0]
        if det_idx_of_cat.shape[0] == 0:
            continue

        iou = calculate_iou(gt_arr[gt_idx_of_cat, 1:5], det_arr[det_idx_of_cat, 1:5])

        # Visit the detections of this category from most to least confident.
        for det_idx in (-det_arr[det_idx_of_cat, 5]).argsort():
            gt_idx_of_max_iou = iou[:, det_idx].argmax()
            if iou[gt_idx_of_max_iou, det_idx] >= iou_thresh:
                # Map category-local positions back to positions in the inputs;
                # cast to int so the returned lists hold plain Python ints.
                det_idx_of_all_cat = int(det_idx_of_cat[det_idx])
                gt_idx_of_all_cat = int(gt_idx_of_cat[gt_idx_of_max_iou])
                det_results[det_idx_of_all_cat] = 1
                gt_results[gt_idx_of_all_cat] = det_idx_of_all_cat
                # Mark this ground truth as used for later detections.
                iou[gt_idx_of_max_iou, :] = -1

    return gt_results, det_results