# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import xml.etree.ElementTree as ET
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv import print_log
from mmdet.datasets import CustomDataset
from PIL import Image

from mmrotate.core import eval_rbbox_map, obb2poly_np, poly2obb_np
from .builder import ROTATED_DATASETS

@ROTATED_DATASETS.register_module()
class HRSCDataset(CustomDataset):
    """HRSC dataset for detection.

    Args:
        ann_file (str): Annotation file path. Each line of the file is
            treated as an image id.
        pipeline (list[dict]): Processing pipeline.
        img_subdir (str): Subdir where images are stored.
            Default: JPEGImages.
        ann_subdir (str): Subdir where annotations are.
            Default: Annotations.
        classwise (bool): Whether to use all classes or only ship.
            Default: False.
        version (str, optional): Angle representation of the returned
            rotated boxes (e.g. 'oc', 'le90'). Defaults to 'oc'.
    """

    CLASSES = None
    HRSC_CLASS = ('ship', )
    HRSC_CLASSES = ('ship', 'aircraft carrier', 'warcraft', 'merchant ship',
                    'Nimitz', 'Enterprise', 'Arleigh Burke', 'WhidbeyIsland',
                    'Perry', 'Sanantonio', 'Ticonderoga', 'Kitty Hawk',
                    'Kuznetsov', 'Abukuma', 'Austen', 'Tarawa', 'Blue Ridge',
                    'Container', 'OXo|--)', 'Car carrier([]==[])',
                    'Hovercraft', 'yacht', 'CntShip(_|.--.--|_]=', 'Cruise',
                    'submarine', 'lute', 'Medical', 'Car carrier(======|',
                    'Ford-class', 'Midway-class', 'Invincible-class')
    # Two-digit ids aligned index-by-index with HRSC_CLASSES
    # (ids '21' and '23' are absent from this tuple).
    HRSC_CLASSES_ID = ('01', '02', '03', '04', '05', '06', '07', '08', '09',
                       '10', '11', '12', '13', '14', '15', '16', '17', '18',
                       '19', '20', '22', '24', '25', '26', '27', '28', '29',
                       '30', '31', '32', '33')

    def __init__(self,
                 ann_file,
                 pipeline,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 classwise=False,
                 version='oc',
                 **kwargs):
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        self.classwise = classwise
        self.version = version
        # NOTE(review): this assigns the *class* attribute, so the chosen
        # class set is shared by every HRSCDataset instance (last one
        # constructed wins). Presumably required because the base class
        # resolves class names from the class attribute -- verify before
        # changing.
        if self.classwise:
            HRSCDataset.CLASSES = self.HRSC_CLASSES
            # Keys are the Class_ID strings matched in the XML:
            # '1' + six zeros + two-digit id, e.g. '100000001' -> 0.
            self.catid2label = {
                ('1' + '0' * 6 + cls_id): i
                for i, cls_id in enumerate(self.HRSC_CLASSES_ID)
            }
        else:
            HRSCDataset.CLASSES = self.HRSC_CLASS
        super(HRSCDataset, self).__init__(ann_file, pipeline, **kwargs)

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of Imageset file; one image id per line.

        Returns:
            list[dict]: One dict per image with keys ``filename``, ``width``,
            ``height`` and ``ann``. ``ann`` holds ``bboxes`` (N, 5) float32,
            ``labels`` (N,) int64, ``polygons`` (N, 8) float32, ``headers``
            (N, 2) int64, plus always-empty ``*_ignore`` counterparts.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = osp.join(self.img_subdir, f'{img_id}.bmp')
            xml_path = osp.join(self.img_prefix, self.ann_subdir,
                                f'{img_id}.xml')
            root = ET.parse(xml_path).getroot()
            width, height = self._read_img_size(root, filename)

            gt_bboxes, gt_labels, gt_polygons, gt_headers = [], [], [], []
            for obj in root.findall('HRSC_Objects/HRSC_Object'):
                parsed = self._parse_hrsc_object(obj)
                if parsed is None:
                    continue
                bbox, label, polygon, head = parsed
                gt_bboxes.append(bbox)
                gt_labels.append(label)
                gt_polygons.append(polygon)
                gt_headers.append(head)

            data_infos.append(
                dict(
                    filename=filename,
                    width=width,
                    height=height,
                    ann=self._pack_ann(gt_bboxes, gt_labels, gt_polygons,
                                       gt_headers)))
        return data_infos

    def _read_img_size(self, root, filename):
        """Return ``(width, height)`` from the XML, else from the image."""
        # findtext returns None for a missing node and '' for an empty one.
        width_text = (root.findtext('Img_SizeWidth') or '').strip()
        height_text = (root.findtext('Img_SizeHeight') or '').strip()
        if width_text and height_text:
            return int(width_text), int(height_text)
        # Size is not recorded in the XML: read it from the image header.
        img_path = osp.join(self.img_prefix, filename)
        with Image.open(img_path) as img:
            width, height = img.size
        return width, height

    def _parse_hrsc_object(self, obj):
        """Parse one ``HRSC_Object`` node.

        Returns:
            tuple | None: ``(bbox, label, polygon, header)``, or ``None`` when
            the object is skipped (class not in ``catid2label`` or a box that
            cannot be converted to ``self.version``).
        """
        if self.classwise:
            label = self.catid2label.get(obj.find('Class_ID').text)
            if label is None:
                return None
        else:
            label = 0
        # Append a dummy score column so the box fits obb2poly_np's input;
        # the matching trailing column is dropped from its output below.
        rbox = np.array([[
            float(obj.find('mbox_cx').text),
            float(obj.find('mbox_cy').text),
            float(obj.find('mbox_w').text),
            float(obj.find('mbox_h').text),
            float(obj.find('mbox_ang').text), 0
        ]],
                        dtype=np.float32)
        polygon = obb2poly_np(rbox, 'le90')[0, :-1].astype(np.float32)
        if self.version != 'le90':
            obb = poly2obb_np(polygon, self.version)
            # NOTE(review): poly2obb_np appears to return None for degenerate
            # polygons; skip them instead of building a malformed bbox array.
            if obb is None:
                return None
            bbox = np.array(obb, dtype=np.float32)
        else:
            bbox = rbox[0, :-1]
        head = np.array([
            int(obj.find('header_x').text),
            int(obj.find('header_y').text)
        ],
                        dtype=np.int64)
        return bbox, label, polygon, head

    @staticmethod
    def _pack_ann(gt_bboxes, gt_labels, gt_polygons, gt_headers):
        """Stack per-object lists into the ``ann`` dict (empty-safe)."""
        if gt_bboxes:
            ann = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32),
                labels=np.array(gt_labels, dtype=np.int64),
                polygons=np.array(gt_polygons, dtype=np.float32),
                headers=np.array(gt_headers, dtype=np.int64))
        else:
            ann = dict(
                bboxes=np.zeros((0, 5), dtype=np.float32),
                labels=np.array([], dtype=np.int64),
                polygons=np.zeros((0, 8), dtype=np.float32),
                headers=np.zeros((0, 2), dtype=np.int64))
        # No ignore-annotations are collected for HRSC; always emit empty
        # arrays so every image info carries the same keys and dtypes.
        ann.update(
            bboxes_ignore=np.zeros((0, 5), dtype=np.float32),
            labels_ignore=np.array([], dtype=np.int64),
            polygons_ignore=np.zeros((0, 8), dtype=np.float32),
            headers_ignore=np.zeros((0, 2), dtype=np.int64))
        return ann

    def _filter_imgs(self):
        """Filter images without ground truths.

        Returns:
            list[int]: Indices into ``self.data_infos`` to keep. Every index
            is kept when ``self.filter_empty_gt`` is False.
        """
        keep_empty = not getattr(self, 'filter_empty_gt', True)
        return [
            i for i, data_info in enumerate(self.data_infos)
            if keep_empty or data_info['ann']['labels'].size > 0
        ]

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None,
                 use_07_metric=True,
                 nproc=4):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for
                evaluating recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000). Currently unused.
            iou_thr (float | Sequence[float]): IoU threshold(s). A scalar or
                any iterable of thresholds is accepted. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.
            use_07_metric (bool): Whether to use the voc07 metric.
            nproc (int): Processes used for computing TP and FP.
                Default: 4.

        Returns:
            OrderedDict: ``mAP`` (mean over thresholds) first, followed by
            one ``APxx`` entry per IoU threshold.

        Raises:
            KeyError: If ``metric`` is not supported.
            NotImplementedError: If ``metric`` is 'recall'.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        if metric == 'recall':
            raise NotImplementedError

        eval_results = OrderedDict()
        # Scalars (including numpy scalars) become a one-element list.
        iou_thrs = [iou_thr] if np.ndim(iou_thr) == 0 else list(iou_thr)
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        mean_aps = []
        for thr in iou_thrs:
            print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}', logger=logger)
            mean_ap, _ = eval_rbbox_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=thr,
                use_07_metric=use_07_metric,
                dataset=self.CLASSES,
                logger=logger,
                nproc=nproc)
            mean_aps.append(mean_ap)
            # round() avoids float truncation, e.g. int(0.57 * 100) == 56.
            eval_results[f'AP{int(round(thr * 100)):02d}'] = round(mean_ap, 3)
        eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        eval_results.move_to_end('mAP', last=False)
        return eval_results
