Source code for mmrotate.datasets.dota

# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import re
import tempfile
import time
import zipfile
from collections import defaultdict
from functools import partial
from multiprocessing import Pool

import mmcv
import numpy as np
import torch
from mmcv.ops import box_iou_rotated, nms_rotated
from mmcv.utils import print_log
from mmdet.core.evaluation import average_precision
from mmdet.datasets.custom import CustomDataset
from terminaltables import AsciiTable

from mmrotate.core import obb2poly_np, poly2obb_np
from .builder import ROTATED_DATASETS


@ROTATED_DATASETS.register_module()
class DOTADataset(CustomDataset):
    """DOTA dataset for detection.

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representation. Defaults to 'oc'.
        difficulty (int, optional): The difficulty threshold of GT.
            Defaults to 100.
    """
    CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
               'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
               'basketball-court', 'storage-tank', 'soccer-ball-field',
               'roundabout', 'harbor', 'swimming-pool', 'helicopter')

    def __init__(self,
                 ann_file,
                 pipeline,
                 version='oc',
                 difficulty=100,
                 **kwargs):
        self.version = version
        self.difficulty = difficulty

        super(DOTADataset, self).__init__(ann_file, pipeline, **kwargs)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)
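For orientation, a minimal construction sketch follows. The paths and the angle version are placeholder assumptions; in practice the dataset is normally built from a config file through the `ROTATED_DATASETS` registry rather than instantiated directly.

    # Hypothetical paths; any folder of DOTA-style txt files works.
    dataset = DOTADataset(
        ann_file='data/split_ss_dota/trainval/annfiles/',
        pipeline=[],  # the real train/test pipeline goes here
        img_prefix='data/split_ss_dota/trainval/images/',
        version='le90')  # angle convention: 'oc', 'le90' or 'le135'
    print(len(dataset))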
    def load_annotations(self, ann_folder):
        """Load annotations from a folder of DOTA v1 txt files.

        Args:
            ann_folder (str): Folder that contains DOTA v1 annotation txt
                files.
        """
        cls_map = {c: i
                   for i, c in enumerate(self.CLASSES)
                   }  # in mmdet v2.0 label is 0-based
        ann_files = glob.glob(ann_folder + '/*.txt')
        data_infos = []
        if not ann_files:  # test phase
            ann_files = glob.glob(ann_folder + '/*.png')
            for ann_file in ann_files:
                data_info = {}
                img_id = osp.split(ann_file)[1][:-4]
                img_name = img_id + '.png'
                data_info['filename'] = img_name
                data_info['ann'] = {}
                data_info['ann']['bboxes'] = []
                data_info['ann']['labels'] = []
                data_infos.append(data_info)
        else:
            for ann_file in ann_files:
                data_info = {}
                img_id = osp.split(ann_file)[1][:-4]
                img_name = img_id + '.png'
                data_info['filename'] = img_name
                data_info['ann'] = {}
                gt_bboxes = []
                gt_labels = []
                gt_polygons = []
                gt_bboxes_ignore = []
                gt_labels_ignore = []
                gt_polygons_ignore = []

                if os.path.getsize(ann_file) == 0:
                    continue

                with open(ann_file) as f:
                    s = f.readlines()

                for si in s:
                    bbox_info = si.split()
                    poly = np.array(bbox_info[:8], dtype=np.float32)
                    try:
                        x, y, w, h, a = poly2obb_np(poly, self.version)
                    except:  # noqa: E722
                        continue
                    cls_name = bbox_info[8]
                    difficulty = int(bbox_info[9])
                    label = cls_map[cls_name]
                    if difficulty > self.difficulty:
                        continue
                    gt_bboxes.append([x, y, w, h, a])
                    gt_labels.append(label)
                    gt_polygons.append(poly)

                if gt_bboxes:
                    data_info['ann']['bboxes'] = np.array(
                        gt_bboxes, dtype=np.float32)
                    data_info['ann']['labels'] = np.array(
                        gt_labels, dtype=np.int64)
                    data_info['ann']['polygons'] = np.array(
                        gt_polygons, dtype=np.float32)
                else:
                    data_info['ann']['bboxes'] = np.zeros((0, 5),
                                                          dtype=np.float32)
                    data_info['ann']['labels'] = np.array([], dtype=np.int64)
                    data_info['ann']['polygons'] = np.zeros((0, 8),
                                                            dtype=np.float32)

                if gt_polygons_ignore:
                    data_info['ann']['bboxes_ignore'] = np.array(
                        gt_bboxes_ignore, dtype=np.float32)
                    data_info['ann']['labels_ignore'] = np.array(
                        gt_labels_ignore, dtype=np.int64)
                    data_info['ann']['polygons_ignore'] = np.array(
                        gt_polygons_ignore, dtype=np.float32)
                else:
                    data_info['ann']['bboxes_ignore'] = np.zeros(
                        (0, 5), dtype=np.float32)
                    data_info['ann']['labels_ignore'] = np.array(
                        [], dtype=np.int64)
                    data_info['ann']['polygons_ignore'] = np.zeros(
                        (0, 8), dtype=np.float32)

                data_infos.append(data_info)

        self.img_ids = [*map(lambda x: x['filename'][:-4], data_infos)]
        return data_infos
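Each non-empty annotation line holds eight polygon coordinates, a class name, and a difficulty flag. A small standalone sketch of the per-line parsing done above (the sample line is fabricated):

    # 'x1 y1 x2 y2 x3 y3 x4 y4 class difficulty' -- a made-up example line
    line = '10.0 10.0 20.0 10.0 20.0 15.0 10.0 15.0 small-vehicle 0'
    bbox_info = line.split()
    poly = np.array(bbox_info[:8], dtype=np.float32)
    x, y, w, h, a = poly2obb_np(poly, 'oc')  # (cx, cy, width, height, angle)
    # center is (15.0, 12.5); width/height/angle depend on the 'oc' convention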
    def _filter_imgs(self):
        """Filter images without ground truths."""
        valid_inds = []
        for i, data_info in enumerate(self.data_infos):
            if data_info['ann']['labels'].size > 0:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        All set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal numbers used for
                evaluating recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. It must be a float
                when evaluating mAP, and can be a list when evaluating
                recall. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = {}
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            mean_ap, _ = eval_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=self.CLASSES,
                version=self.version,
                logger=logger)
            eval_results['mAP'] = mean_ap
        else:
            raise NotImplementedError

        return eval_results
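A hedged usage sketch, assuming `dataset` is a loaded `DOTADataset` and `results` holds, per image, a list of per-class `(n, 6)` arrays of `[cx, cy, w, h, a, score]`, e.g. as produced by a test run:

    eval_results = dataset.evaluate(results, metric='mAP', iou_thr=0.5)
    print(eval_results['mAP'])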
    def merge_det(self, results):
        """Merge patch bboxes into full-image bboxes.

        Args:
            results (list): Testing results of the dataset.
        """
        collector = defaultdict(list)
        for idx in range(len(self)):
            result = results[idx]
            img_id = self.img_ids[idx]
            splitname = img_id.split('__')
            oriname = splitname[0]
            pattern1 = re.compile(r'__\d+___\d+')
            x_y = re.findall(pattern1, img_id)
            x_y_2 = re.findall(r'\d+', x_y[0])
            x, y = int(x_y_2[0]), int(x_y_2[1])
            new_result = []
            for i, dets in enumerate(result):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                # shift patch-local centers back to full-image coordinates
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
                    [x, y], dtype=np.float32)
                labels = np.zeros((bboxes.shape[0], 1)) + i
                new_result.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1))

            new_result = np.concatenate(new_result, axis=0)
            collector[oriname].append(new_result)

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
        merged_results = mmcv.track_parallel_progress(merge_func,
                                                      list(collector.items()),
                                                      min(4, os.cpu_count()))
        return zip(*merged_results)
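The patch ids produced by the DOTA split tool look like `P0001__1024__0___512` (original name, patch scale, then the x/y offset of the crop); the regex above recovers that offset, as this standalone sketch shows:

    img_id = 'P0001__1024__0___512'  # hypothetical patch id
    oriname = img_id.split('__')[0]                      # 'P0001'
    x_y = re.findall(r'__\d+___\d+', img_id)             # ['__0___512']
    x, y = [int(v) for v in re.findall(r'\d+', x_y[0])]  # (0, 512)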
    def _results2submission(self, results, out_folder=None):
        """Generate the submission files for full images.

        Args:
            results (list): Testing results of the dataset.
            out_folder (str, optional): Folder of the submission.
        """
        print('\nMerging patch bboxes into full image!')
        start_time = time.time()
        id_list, dets_list = self.merge_det(results)
        stop_time = time.time()
        print(f'Used time: {(stop_time - start_time):.1f} s')

        if osp.exists(out_folder):
            raise ValueError(f'The out_folder should be a non-existing '
                             f'path, but {out_folder} already exists')
        os.makedirs(out_folder)

        files = [
            osp.join(out_folder, 'Task1_' + cls + '.txt')
            for cls in self.CLASSES
        ]
        file_objs = [open(f, 'w') for f in files]
        for img_id, dets_per_cls in zip(id_list, dets_list):
            for f, dets in zip(file_objs, dets_per_cls):
                if dets.size == 0:
                    continue
                bboxes = obb2poly_np(dets, self.version)
                for bbox in bboxes:
                    txt_element = [img_id, str(bbox[-1])
                                   ] + [f'{p:.2f}' for p in bbox[:-1]]
                    f.write(' '.join(txt_element) + '\n')

        for f in file_objs:
            f.close()

        target_name = osp.split(out_folder)[-1]
        with zipfile.ZipFile(
                osp.join(out_folder, target_name + '.zip'), 'w',
                zipfile.ZIP_DEFLATED) as t:
            for f in files:
                t.write(f, osp.split(f)[-1])

        return files
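For reference, every line written to a `Task1_<cls>.txt` file above follows the standard DOTA Task1 layout: image id, score, then the eight polygon corner coordinates. A fabricated line from `Task1_plane.txt` could read:

    P0006 0.92 313.25 269.10 351.70 269.10 351.70 307.55 313.25 307.55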
    def format_results(self, results, submission_dir=None, **kwargs):
        """Format the results to the submission text (standard format for
        DOTA evaluation).

        Args:
            results (list): Testing results of the dataset.
            submission_dir (str, optional): The folder that contains
                submission files. If not specified, a temp folder will be
                created. Default: None.

        Returns:
            tuple: (result_files, tmp_dir), result_files is a list of the
                submission file paths, tmp_dir is the temporary directory
                created for saving the files when submission_dir is not
                specified.
        """
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            f'The length of results is not equal to '
            f'the dataset len: {len(results)} != {len(self)}')
        if submission_dir is None:
            tmp_dir = tempfile.TemporaryDirectory()
            # _results2submission requires a non-existing folder, so point
            # it at a subpath of the temporary directory
            submission_dir = osp.join(tmp_dir.name, 'submission')
        else:
            tmp_dir = None
        result_files = self._results2submission(results, submission_dir)
        return result_files, tmp_dir
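End to end, formatting then looks like the following sketch (the output folder is a placeholder and must not exist yet):

    result_files, tmp_dir = dataset.format_results(
        results, submission_dir='work_dirs/dota_submission')
    # 'work_dirs/dota_submission' now holds the Task1_<cls>.txt files plus
    # a zip for the DOTA evaluation server; tmp_dir is None because an
    # explicit folder was given.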
def eval_map(det_results,
             annotations,
             scale_ranges=None,
             iou_thr=0.5,
             dataset=None,
             version='oc',
             logger=None,
             nproc=4):
    """Evaluate mAP of a dataset.

    Args:
        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
            The outer list indicates images, and the inner list indicates
            per-class detected bboxes.
        annotations (list[dict]): Ground truth annotations where each item of
            the list indicates an image. Keys of annotations are:

            - `bboxes`: numpy array of shape (n, 5)
            - `labels`: numpy array of shape (n, )
            - `bboxes_ignore` (optional): numpy array of shape (k, 5)
            - `labels_ignore` (optional): numpy array of shape (k, )
        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
            in the format [(min1, max1), (min2, max2), ...]. A range of
            (32, 64) means the area range between (32**2, 64**2).
            Default: None.
        iou_thr (float): IoU threshold to be considered as matched.
            Default: 0.5.
        dataset (list[str] | str | None): Dataset name or dataset classes,
            there are minor differences in metrics for different datasets,
            e.g. "voc07", "imagenet_det", etc. Default: None.
        version (str, optional): Angle representation. Defaults to 'oc'.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmcv.utils.print_log()` for details. Default: None.
        nproc (int): Processes used for computing TP and FP. Default: 4.

    Returns:
        tuple: (mAP, [dict, dict, ...])
    """
    assert len(det_results) == len(annotations)

    num_imgs = len(det_results)
    num_scales = len(scale_ranges) if scale_ranges is not None else 1
    num_classes = len(det_results[0])  # positive class num
    area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
                   if scale_ranges is not None else None)

    pool = Pool(nproc)
    eval_results = []
    for i in range(num_classes):
        # get gt and det bboxes of this class
        cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
            det_results, annotations, i)

        # compute tp and fp for each image with multiple processes
        tpfp = pool.starmap(
            tpfp_default,
            zip(cls_dets, cls_gts, cls_gts_ignore,
                [iou_thr for _ in range(num_imgs)],
                [area_ranges for _ in range(num_imgs)]))
        tp, fp = tuple(zip(*tpfp))
        # calculate gt number of each scale
        # ignored gts or gts beyond the specific scale are not counted
        num_gts = np.zeros(num_scales, dtype=int)
        for _, bbox in enumerate(cls_gts):
            if area_ranges is None:
                num_gts[0] += bbox.shape[0]
            else:
                gt_areas = (bbox[:, 2] - bbox[:, 0]) * (
                    bbox[:, 3] - bbox[:, 1])
                for k, (min_area, max_area) in enumerate(area_ranges):
                    num_gts[k] += np.sum((gt_areas >= min_area)
                                         & (gt_areas < max_area))
        # sort all det bboxes by score, also sort tp and fp
        cls_dets = np.vstack(cls_dets)
        num_dets = cls_dets.shape[0]
        sort_inds = np.argsort(-cls_dets[:, -1])
        tp = np.hstack(tp)[:, sort_inds]
        fp = np.hstack(fp)[:, sort_inds]
        # calculate recall and precision with tp and fp
        tp = np.cumsum(tp, axis=1)
        fp = np.cumsum(fp, axis=1)
        eps = np.finfo(np.float32).eps
        recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
        precisions = tp / np.maximum((tp + fp), eps)
        # calculate AP
        if scale_ranges is None:
            recalls = recalls[0, :]
            precisions = precisions[0, :]
            num_gts = num_gts.item()
        mode = 'area' if dataset != 'voc07' else '11points'
        ap = average_precision(recalls, precisions, mode)
        eval_results.append({
            'num_gts': num_gts,
            'num_dets': num_dets,
            'recall': recalls,
            'precision': precisions,
            'ap': ap
        })
    pool.close()
    if scale_ranges is not None:
        # shape (num_classes, num_scales)
        all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
        all_num_gts = np.vstack(
            [cls_result['num_gts'] for cls_result in eval_results])
        mean_ap = []
        for i in range(num_scales):
            if np.any(all_num_gts[:, i] > 0):
                mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
            else:
                mean_ap.append(0.0)
    else:
        aps = []
        for cls_result in eval_results:
            if cls_result['num_gts'] > 0:
                aps.append(cls_result['ap'])
        mean_ap = np.array(aps).mean().item() if aps else 0.0

    print_map_summary(
        mean_ap, eval_results, dataset, area_ranges, logger=logger)

    return mean_ap, eval_results


def print_map_summary(mean_ap,
                      results,
                      dataset=None,
                      scale_ranges=None,
                      logger=None):
    """Print mAP and results of each class.

    A table will be printed to show the gts/dets/recall/AP of each class and
    the mAP.

    Args:
        mean_ap (float): Calculated from `eval_map()`.
        results (list[dict]): Calculated from `eval_map()`.
        dataset (list[str] | str | None): Dataset name or dataset classes.
        scale_ranges (list[tuple] | None): Range of scales to be evaluated.
        logger (logging.Logger | str | None): The way to print the mAP
            summary. See `mmcv.utils.print_log()` for details. Default: None.
    """

    if logger == 'silent':
        return

    if isinstance(results[0]['ap'], np.ndarray):
        num_scales = len(results[0]['ap'])
    else:
        num_scales = 1

    if scale_ranges is not None:
        assert len(scale_ranges) == num_scales

    num_classes = len(results)

    recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
    aps = np.zeros((num_scales, num_classes), dtype=np.float32)
    num_gts = np.zeros((num_scales, num_classes), dtype=int)
    for i, cls_result in enumerate(results):
        if cls_result['recall'].size > 0:
            recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
        aps[:, i] = cls_result['ap']
        num_gts[:, i] = cls_result['num_gts']

    if dataset is None:
        label_names = [str(i) for i in range(num_classes)]
    else:
        label_names = dataset

    if not isinstance(mean_ap, list):
        mean_ap = [mean_ap]

    header = ['class', 'gts', 'dets', 'recall', 'ap']
    for i in range(num_scales):
        if scale_ranges is not None:
            print_log(f'Scale range {scale_ranges[i]}', logger=logger)
        table_data = [header]
        for j in range(num_classes):
            row_data = [
                label_names[j], num_gts[i, j], results[j]['num_dets'],
                f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
            ]
            table_data.append(row_data)
        table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
        table = AsciiTable(table_data)
        table.inner_footing_row_border = True
        print_log('\n' + table.table, logger=logger)


def tpfp_default(det_bboxes,
                 gt_bboxes,
                 gt_bboxes_ignore=None,
                 iou_thr=0.5,
                 area_ranges=None):
    """Check if detected bboxes are true positive or false positive.

    Args:
        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6)
            in the format [cx, cy, w, h, a, score].
        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5).
        gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, of
            shape (k, 5). Default: None.
        iou_thr (float): IoU threshold to be considered as matched.
            Default: 0.5.
        area_ranges (list[tuple] | None): Range of bbox areas to be
            evaluated, in the format [(min1, max1), (min2, max2), ...].
            Default: None.

    Returns:
        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
            each array is (num_scales, m).
    """
    # an indicator of ignored gts
    det_bboxes = np.array(det_bboxes)
    gt_ignore_inds = np.concatenate(
        (np.zeros(gt_bboxes.shape[0], dtype=bool),
         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
    # stack gt_bboxes and gt_bboxes_ignore for convenience
    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))

    num_dets = det_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    if area_ranges is None:
        area_ranges = [(None, None)]
    num_scales = len(area_ranges)
    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp
    # of a certain scale
    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
    fp = np.zeros((num_scales, num_dets), dtype=np.float32)

    # if there is no gt bboxes in this image, then all det bboxes
    # within area range are false positives
    if gt_bboxes.shape[0] == 0:
        if area_ranges == [(None, None)]:
            fp[...] = 1
        else:
            raise NotImplementedError
        return tp, fp

    ious = box_iou_rotated(
        torch.from_numpy(det_bboxes).float(),
        torch.from_numpy(gt_bboxes).float()).numpy()
    # for each det, the max iou with all gts
    ious_max = ious.max(axis=1)
    # for each det, which gt overlaps most with it
    ious_argmax = ious.argmax(axis=1)
    # sort all dets in descending order by scores
    sort_inds = np.argsort(-det_bboxes[:, -1])
    for k, (min_area, max_area) in enumerate(area_ranges):
        gt_covered = np.zeros(num_gts, dtype=bool)
        # if no area range is specified, gt_area_ignore is all False
        if min_area is None:
            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
        else:
            raise NotImplementedError
        for i in sort_inds:
            if ious_max[i] >= iou_thr:
                matched_gt = ious_argmax[i]
                if not (gt_ignore_inds[matched_gt]
                        or gt_area_ignore[matched_gt]):
                    if not gt_covered[matched_gt]:
                        gt_covered[matched_gt] = True
                        tp[k, i] = 1
                    else:
                        fp[k, i] = 1
                # otherwise ignore this detected bbox, tp = 0, fp = 0
            elif min_area is None:
                fp[k, i] = 1
            else:
                bbox = det_bboxes[i, :5]
                area = bbox[2] * bbox[3]
                if area >= min_area and area < max_area:
                    fp[k, i] = 1
    return tp, fp


def get_cls_results(det_results, annotations, class_id):
    """Get det results and gt information of a certain class.

    Args:
        det_results (list[list]): Same as `eval_map()`.
        annotations (list[dict]): Same as `eval_map()`.
        class_id (int): ID of a specific class.

    Returns:
        tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt
            bboxes
    """
    cls_dets = [img_res[class_id] for img_res in det_results]

    cls_gts = []
    cls_gts_ignore = []
    for ann in annotations:
        gt_inds = ann['labels'] == class_id
        cls_gts.append(ann['bboxes'][gt_inds, :])

        if ann.get('labels_ignore', None) is not None:
            ignore_inds = ann['labels_ignore'] == class_id
            cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
        else:
            # empty (0, 5) placeholder so it stacks with (n, 5) gt bboxes
            cls_gts_ignore.append(torch.zeros((0, 5), dtype=torch.float64))

    return cls_dets, cls_gts, cls_gts_ignore


def _merge_func(info, CLASSES, iou_thr):
    """Merge patch bboxes of one image into full-image bboxes.

    Args:
        info (tuple): (img_id, list of per-patch detection arrays).
        CLASSES (list): Label categories.
        iou_thr (float): IoU threshold of the rotated NMS.
    """
    img_id, label_dets = info
    label_dets = np.concatenate(label_dets, axis=0)

    labels, dets = label_dets[:, 0], label_dets[:, 1:]

    big_img_results = []
    for i in range(len(CLASSES)):
        if len(dets[labels == i]) == 0:
            big_img_results.append(dets[labels == i])
        else:
            try:
                cls_dets = torch.from_numpy(dets[labels == i]).cuda()
            except:  # noqa: E722
                cls_dets = torch.from_numpy(dets[labels == i])
            nms_dets, keep_inds = nms_rotated(cls_dets[:, :5],
                                              cls_dets[:, -1], iou_thr)
            big_img_results.append(nms_dets.cpu().numpy())
    return img_id, big_img_results