Shortcuts

Source code for mmrotate.apis.inference

# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose

from mmrotate.core import get_multiscale_patch, merge_results, slide_window


def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """Inference patches with the detector.

    Split huge image(s) into patches and inference them with the detector.
    Finally, merge patch results on one huge image by nms.

    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray): Either an image file or loaded image.
        sizes (list): The sizes of patches.
        steps (list): The steps between two patches.
        ratios (list): Image resizing ratios for multi-scale detecting.
        merge_iou_thr (float): IoU threshold for merging results.
        bs (int): Batch size, must be greater than or equal to 1.

    Returns:
        list[np.ndarray]: Detection results.
    """
    assert bs >= 1, 'The batch size must greater than or equal to 1'
    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # Work on a private deep copy of the test pipeline. A shallow
    # ``cfg.copy()`` shares the nested pipeline dicts with ``model.cfg``,
    # so editing them would silently rewrite the model's own config.
    pipeline = copy.deepcopy(cfg.data.test.pipeline)
    # set loading pipeline type so each sample is cropped to its window
    pipeline[0].type = 'LoadPatchFromImage'
    pipeline = replace_ImageToTensor(pipeline)
    test_pipeline = Compose(pipeline)

    if not isinstance(img, np.ndarray):
        img = mmcv.imread(img)
    height, width = img.shape[:2]
    sizes, steps = get_multiscale_patch(sizes, steps, ratios)
    windows = slide_window(width, height, sizes, steps)

    # prepare patch data: one pipeline output per sliding window
    patch_datas = [
        test_pipeline(dict(img=img, win=window.tolist()))
        for window in windows
    ]

    # Device checks are loop-invariant; do them once instead of per batch.
    on_gpu = next(model.parameters()).is_cuda
    if not on_gpu:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    results = []
    for start in range(0, len(patch_datas), bs):
        data = patch_datas[start:start + bs]
        data = collate(data, samples_per_gpu=len(data))
        # just get the actual data from DataContainer
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [batch_img.data[0] for batch_img in data['img']]
        if on_gpu:
            # scatter to specified GPU
            data = scatter(data, [device])[0]
        # forward the model
        with torch.no_grad():
            results.extend(model(return_loss=False, rescale=True, **data))

    # windows[:, :2] are the patch offsets used to shift results back onto
    # the full image before merging.
    results = merge_results(
        results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
    return results
Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
v0.1.1
v0.1.0
main
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.