# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import is_torchdynamo_compiling, logging
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
TransformersKwargs,
)
from ..mistral.modeling_mistral import MistralRMSNorm
from .configuration_mistral3 import Mistral3Config
logger = logging.get_logger(__name__)
class Mistral3RMSNorm(MistralRMSNorm):
pass
class Mistral3PatchMerger(nn.Module):
"""
Learned merging of spatial_merge_size ** 2 patches
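For example, with illustrative values patch_size=14 and spatial_merge_size=2, a 448x448 image yields a
32x32 grid of patch embeddings (1024 tokens); merging 2x2 neighborhoods leaves 16x16 = 256 tokens,
each projected back to hidden_size by a single linear layer.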
"""
def __init__(self, config: Mistral3Config):
super().__init__()
self.config = config
hidden_size = config.vision_config.hidden_size
self.spatial_merge_size = config.spatial_merge_size
self.patch_size = self.config.vision_config.patch_size
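# Each group of spatial_merge_size**2 neighboring patch embeddings is concatenated and projected back to hidden_size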
self.merging_layer = nn.Linear(hidden_size * self.spatial_merge_size**2, hidden_size, bias=False)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
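# Convert image sizes from pixels to the number of patches along each dimension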
image_sizes = [
(image_size[0] // self.patch_size, image_size[1] // self.patch_size) for image_size in image_sizes
]
tokens_per_image = [h * w for h, w in image_sizes]
d = image_features.shape[-1]
permuted_tensor = []
for image_index, image_tokens in enumerate(image_features.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
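# Extract non-overlapping spatial_merge_size x spatial_merge_size blocks of patch embeddings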
grid = torch.nn.functional.unfold(
image_grid, kernel_size=self.spatial_merge_size, stride=self.spatial_merge_size
)
grid = grid.view(d * self.spatial_merge_size**2, -1).t()
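# Each row now holds the concatenated features of one merged block (d * spatial_merge_size**2 values)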
permuted_tensor.append(grid)
image_features = torch.cat(permuted_tensor, dim=0)
image_features = self.merging_layer(image_features)
return image_features
class Mistral3MultiModalProjector(nn.Module):
def __init__(self, config: Mistral3Config):
super().__init__()
self.norm = Mistral3RMSNorm(config.vision_config.hidden_size, eps=config.text_config.rms_norm_eps)
self.patch_merger = Mistral3PatchMerger(config)
# The projector input size is hidden_size multiplied by the number of selected vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size * num_feature_layers,
config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
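# Normalize the vision features, merge neighboring patches, then project into the text embedding space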
image_features = self.norm(image_features)
image_features = self.patch_merger(image_features, image_sizes)
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast):
pass
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
pass
class Mistral3Model(LlavaModel):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, list[int]]`, *optional*):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_sizes (`torch.Tensor`):
Tensor containing the image sizes as returned by the processor.
Returns:
image_features (`tuple[torch.Tensor]`): Tuple of image feature tensors, one per image, each of shape `(image_length, embed_dim)`.
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# This is not memory efficient at all: output_hidden_states=True will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)
downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
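# After patch merging, each image contributes (height // downsample_ratio) * (width // downsample_ratio) tokens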
split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]
image_features = torch.split(image_features.squeeze(0), split_sizes)
return image_features
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, Mistral3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
image_features = torch.cat(image_features, dim=0)
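# Locate the image placeholder tokens, either from the input ids or by comparing embeddings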
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
n_image_tokens = (special_image_mask).sum()
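# Broadcast the token mask over the embedding dimension so it can select full embedding vectors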
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
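# Write the projected image features into the positions of the image placeholder tokens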
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
return Mistral3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
**kwargs,
):
return self.model.get_image_features(
pixel_values=pixel_values,
image_sizes=image_sizes,
vision_feature_layer=vision_feature_layer,
**kwargs,
)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
image_sizes=image_sizes,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
return Mistral3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
__all__ = [
"Mistral3Model",
"Mistral3PreTrainedModel", # noqa
"Mistral3ForConditionalGeneration",
]