# Source code for mmrotate.models.detectors.base
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.models import BaseDetector
from mmrotate.core import imshow_det_rbboxes
from ..builder import ROTATED_DETECTORS
@ROTATED_DETECTORS.register_module()
class RotatedBaseDetector(BaseDetector):
    """Base class for rotated detectors."""

    def __init__(self, init_cfg=None):
        super(RotatedBaseDetector, self).__init__(init_cfg)
        self.fp16_enabled = False

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(226, 43, 138),
                    text_color='white',
                    thickness=2,
                    font_scale=0.25,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None,
                    **kwargs):
        """Draw `result` over `img`.

        Args:
            img (str or np.ndarray): The image to be displayed; anything
                accepted by ``mmcv.imread`` (e.g. a file path or an array).
            result (list[np.ndarray] or tuple): The results to draw over
                `img`, either ``bbox_result`` or
                ``(bbox_result, segm_result)``. ``bbox_result`` holds one
                array per class; their rows are stacked and the last column
                is used as the score. If ``segm_result`` is itself a tuple,
                only its first element (the masks) is used.
            score_thr (float, optional): Minimum score of bboxes (and of the
                masks drawn for them) to be shown. Default: 0.3.
            bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
                Default: (226, 43, 138).
            text_color (str or tuple or :obj:`Color`): Color of texts.
                Default: 'white'.
            thickness (int): Thickness of lines. Default: 2.
            font_scale (float): Font scales of texts. Default: 0.25.
            win_name (str): The window name. Default: ''.
            show (bool): Whether to show the image. Forced to False when
                `out_file` is given. Default: False.
            wait_time (int): Value of waitKey param. Default: 0.
            out_file (str or None): The filename to write the image.
                Default: None.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The image with results drawn, as returned by
            ``imshow_det_rbboxes`` (presumably an ``np.ndarray``). It is
            returned only if neither `show` nor `out_file` is set;
            otherwise None is returned.
        """
        img = mmcv.imread(img)
        # Work on a copy so the caller's array is never modified in place.
        img = img.copy()
        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                # Keep only the first element (the masks).
                segm_result = segm_result[0]
        else:
            bbox_result, segm_result = result, None
        bboxes = np.vstack(bbox_result)
        # bbox_result holds one array per class, so each stacked row's label
        # is the index of the per-class array it came from.
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        # draw segmentation masks
        if segm_result is not None and len(labels) > 0:
            segms = mmcv.concat_list(segm_result)
            inds = np.where(bboxes[:, -1] > score_thr)[0]
            # Fixed seed gives reproducible per-class colors. A private
            # RandomState yields the same sequence as seeding the global RNG
            # but leaves NumPy's global random state untouched.
            rng = np.random.RandomState(42)
            color_masks = [
                rng.randint(0, 256, (1, 3), dtype=np.uint8)
                for _ in range(labels.max() + 1)
            ]
            for i in inds:
                i = int(i)
                color_mask = color_masks[labels[i]]
                # Convert to bool so 0/1 integer masks select pixels instead
                # of being interpreted as integer (fancy) row indices.
                mask = segms[i].astype(bool)
                img[mask] = img[mask] * 0.5 + color_mask * 0.5
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes; keep the returned image instead of relying on
        # the helper drawing into `img` in place.
        img = imshow_det_rbboxes(
            img,
            bboxes,
            labels,
            class_names=self.CLASSES,
            score_thr=score_thr,
            bbox_color=bbox_color,
            text_color=text_color,
            thickness=thickness,
            font_scale=font_scale,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)
        if not (show or out_file):
            return img