"""
Implements the Generalized R-CNN framework
"""
|
||
|
|
||
|
import warnings
|
||
|
from collections import OrderedDict
|
||
|
from typing import Dict, List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
from torch import nn, Tensor
|
||
|
|
||
|
from ...utils import _log_api_usage_once
|
||
|
|
||
|
|
||
|
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module): called on the batched image tensor produced by ``transform``
            to compute the features. If it returns a single Tensor, that Tensor is wrapped
            in an ``OrderedDict`` under the key ``"0"``; any other return value is passed
            on unchanged.
        rpn (nn.Module): takes the transformed images, the features and the (optional)
            targets and returns the proposals together with a dict of proposal losses.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """Return ``losses`` when the module is in training mode, ``detections`` otherwise."""
        if self.training:
            return losses

        return detections

    def _check_degenerate_boxes(self, targets: List[Dict[str, Tensor]]) -> None:
        """
        Assert that every box in every target has positive height and width.

        Args:
            targets (list[Dict[str, Tensor]]): targets as returned by ``self.transform``;
                each one must provide a 2-D ``"boxes"`` tensor.

        Raises:
            AssertionError: via ``torch._assert`` for the first target that contains a box
                whose columns ``2:`` are not strictly greater than its columns ``:2``.
        """
        for target_idx, target in enumerate(targets):
            boxes = target["boxes"]
            # Element-wise comparison of the last two coordinates against the first two
            # (presumably (x2, y2) vs. (x1, y1) -- the error message speaks of width/height).
            degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
            if degenerate_boxes.any():
                # print the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
                    f" Found invalid box {degen_bb} for target at index {target_idx}.",
                )

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[str, Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[Dict[str, Tensor]] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
                When scripted, it always returns the ``(losses, detections)`` tuple.

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # Validate the raw targets before the transform touches them.
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(False, f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Remember the pre-transform sizes so that postprocess can map detections back.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            self._check_degenerate_boxes(targets)

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Normalise a single feature map to the mapping form used downstream.
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)
|