team-10/venv/Lib/site-packages/torchvision/models/segmentation/_utils.py
2025-08-02 02:00:33 +02:00

37 lines
1.2 KiB
Python

from collections import OrderedDict
from typing import Dict, Optional
from torch import nn, Tensor
from torch.nn import functional as F
from ...utils import _log_api_usage_once
class _SimpleSegmentationModel(nn.Module):
    """Segmentation model made of a backbone, a main head and an optional auxiliary head.

    The backbone is called on the input and must return a mapping of tensors
    containing the key ``"out"`` (and ``"aux"`` whenever an auxiliary
    classifier is configured). Each selected feature map is passed through its
    head and then bilinearly resized to the size of the input's last two
    dimensions.
    """

    # NOTE(review): presumably listed so TorchScript can treat the optional
    # auxiliary head as a constant when it is ``None`` -- confirm against the
    # TorchScript documentation before changing.
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
        """Store the sub-modules that make up the model.

        Args:
            backbone: Module called on the input; its result is indexed with
                ``"out"`` (and ``"aux"`` when ``aux_classifier`` is set).
            classifier: Head applied to the ``"out"`` features.
            aux_classifier: Optional head applied to the ``"aux"`` features.
                When ``None``, no ``"aux"`` entry is produced by ``forward``.
        """
        super().__init__()
        # Usage hook imported from the package utilities; called once per instance.
        _log_api_usage_once(self)
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and head(s) and resize the results to the input size.

        Args:
            x: Input tensor. Its last two dimensions are used as the target
                size for the outputs (assumes a batched image layout such as
                ``(N, C, H, W)`` -- TODO confirm with callers).

        Returns:
            An ordered mapping with key ``"out"`` holding the resized output of
            ``classifier`` and, only when ``aux_classifier`` is not ``None``,
            key ``"aux"`` holding the resized output of ``aux_classifier``.
        """
        # Remember the target size before ``x`` is rebound to feature maps below.
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        # Kept as an ``is not None`` block (no early return) so the optional
        # head is only referenced inside the guarded branch.
        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result