353 lines
15 KiB
Python
353 lines
15 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache
|
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
|
from ...processing_utils import Unpack
|
|
from ...utils import is_torchdynamo_compiling, logging
|
|
from ..llava.modeling_llava import (
|
|
LlavaCausalLMOutputWithPast,
|
|
LlavaForConditionalGeneration,
|
|
LlavaModel,
|
|
LlavaModelOutputWithPast,
|
|
LlavaPreTrainedModel,
|
|
TransformersKwargs,
|
|
)
|
|
from ..mistral.modeling_mistral import MistralRMSNorm
|
|
from .configuration_mistral3 import Mistral3Config
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class Mistral3RMSNorm(MistralRMSNorm):
    # Reuses MistralRMSNorm unchanged; only re-exposed under a Mistral3-specific name.
    pass
|
|
|
|
|
|
class Mistral3PatchMerger(nn.Module):
    """
    Merge every non-overlapping `spatial_merge_size x spatial_merge_size` window of vision patches into a
    single token through a learned, bias-free linear projection.
    """

    def __init__(self, config: Mistral3Config):
        super().__init__()
        self.config = config

        vision_width = config.vision_config.hidden_size
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = self.config.vision_config.patch_size
        # Input: one flattened window of spatial_merge_size**2 patch vectors; output keeps the vision width.
        self.merging_layer = nn.Linear(vision_width * self.spatial_merge_size**2, vision_width, bias=False)

    def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            image_features: patch vectors of all images concatenated along dim 0; the last dim is the
                feature width.
            image_sizes: per-image `(height, width)` in pixels, divided by `patch_size` to get the patch grid.

        Returns:
            Merged tokens of all images, concatenated along dim 0, projected by `merging_layer`.
        """
        window = self.spatial_merge_size
        # Pixel sizes -> patch-grid sizes (rows, cols).
        grid_shapes = [(size[0] // self.patch_size, size[1] // self.patch_size) for size in image_sizes]
        channels = image_features.shape[-1]

        per_image_tokens = image_features.split([rows * cols for rows, cols in grid_shapes])
        merged_chunks = []
        for (rows, cols), tokens in zip(grid_shapes, per_image_tokens):
            # (rows * cols, C) -> (1, C, rows, cols) so unfold can slide over the spatial grid.
            spatial = tokens.view(rows, cols, channels).permute(2, 0, 1).unsqueeze(0)
            windows = torch.nn.functional.unfold(spatial, kernel_size=window, stride=window)
            # unfold gives (1, C * window**2, num_windows); turn it into one row per window.
            merged_chunks.append(windows.view(channels * window**2, -1).t())

        return self.merging_layer(torch.cat(merged_chunks, dim=0))
|
|
|
|
|
|
class Mistral3MultiModalProjector(nn.Module):
    """
    Map vision-tower features into the text model's embedding space: RMS-normalize, merge spatial patch
    windows, then apply a two-layer MLP (`linear_1` -> activation -> `linear_2`).
    """

    def __init__(self, config: Mistral3Config):
        super().__init__()
        vision_width = config.vision_config.hidden_size
        text_width = config.text_config.hidden_size
        use_bias = config.multimodal_projector_bias

        self.norm = Mistral3RMSNorm(vision_width, eps=config.text_config.rms_norm_eps)
        self.patch_merger = Mistral3PatchMerger(config)

        # Features from several vision layers are concatenated on the last dim, widening linear_1's input.
        # NOTE(review): `norm` and `patch_merger` are sized for a single layer (vision hidden_size), so a
        # list-valued `vision_feature_layer` does not appear to line up with them — confirm before relying on it.
        if isinstance(config.vision_feature_layer, int):
            layer_count = 1
        else:
            layer_count = len(config.vision_feature_layer)

        self.linear_1 = nn.Linear(vision_width * layer_count, text_width, bias=use_bias)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_width, text_width, bias=use_bias)

    def forward(self, image_features: torch.Tensor, image_sizes: torch.Tensor):
        """Normalize, patch-merge (using `image_sizes`), and project `image_features` to the text width."""
        merged = self.patch_merger(self.norm(image_features), image_sizes)
        return self.linear_2(self.act(self.linear_1(merged)))
|
|
|
|
|
|
class Mistral3CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    # Same fields as LlavaCausalLMOutputWithPast; only re-exposed under a Mistral3-specific name.
    pass
|
|
|
|
|
|
class Mistral3ModelOutputWithPast(LlavaModelOutputWithPast):
    # Same fields as LlavaModelOutputWithPast; only re-exposed under a Mistral3-specific name.
    pass
|
|
|
|
|
|
class Mistral3PreTrainedModel(LlavaPreTrainedModel):
    # Inherits everything from LlavaPreTrainedModel unchanged; only re-exposed under a Mistral3-specific name.
    pass
|
|
|
|
|
|
class Mistral3Model(LlavaModel):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        **kwargs,
    ):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
            image_sizes (`torch.Tensor`):
                Per-image `(height, width)` sizes as returned by the processor. Required: they drive both the
                patch merging in the projector and the per-image split of the result.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features. Defaults to `config.vision_feature_layer`.
            **kwargs:
                Forwarded to the vision tower after dropping entries whose value is `None`.
        Returns:
            `tuple[torch.Tensor]`: One tensor per image, each of shape `(num_image_tokens, embed_dim)`, where
            `num_image_tokens = (height // downsample_ratio) * (width // downsample_ratio)` and
            `downsample_ratio = vision_tower.patch_size * config.spatial_merge_size`.
        """
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )

        # Drop options explicitly set to None before forwarding them to the vision tower.
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
        image_outputs = self.vision_tower(pixel_values, image_sizes=image_sizes, output_hidden_states=True, **kwargs)
        # If we have one vision feature layer, return the corresponding hidden states,
        # otherwise, select the hidden states of each feature layer and concatenate them
        if isinstance(vision_feature_layer, int):
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
        else:
            hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
            selected_image_feature = torch.cat(hs_pool, dim=-1)

        # Remove the leading size-1 dim before projecting (assumes the vision tower packs all images' patches
        # into a single sequence — TODO confirm against the vision tower).
        image_features = self.multi_modal_projector(selected_image_feature.squeeze(0), image_sizes)

        # Each image contributes (H // ratio) * (W // ratio) merged tokens, which matches the floor division
        # applied by the patch grid and the non-overlapping unfold in the patch merger.
        downsample_ratio = self.vision_tower.patch_size * self.config.spatial_merge_size
        split_sizes = [(height // downsample_ratio) * (width // downsample_ratio) for height, width in image_sizes]

        # Flatten to (total_tokens, embed_dim) instead of squeezing dim 0: when only one merged token exists,
        # `squeeze(0)` on a 2D (1, embed_dim) tensor collapses it to 1D and `torch.split` would then split
        # along the embedding dimension and fail.
        image_features = image_features.reshape(-1, image_features.shape[-1])
        image_features = torch.split(image_features, split_sizes)
        return image_features

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        image_sizes: torch.Tensor = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, Mistral3ModelOutputWithPast]:
        # Embed the text (unless embeddings are given), overwrite the positions holding
        # `config.image_token_id` with projected image features, then run the language model.
        # NOTE(review): `return_dict` is resolved below but a `Mistral3ModelOutputWithPast` is always
        # returned; confirm whether a wrapper around this method is expected to honor it.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                image_sizes=image_sizes,
            )
            # One row per image token, across all images.
            image_features = torch.cat(image_features, dim=0)

            if input_ids is None:
                # Without ids, locate image positions by matching the embedding of `image_token_id`.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                special_image_mask = special_image_mask.all(-1)
            else:
                special_image_mask = input_ids == self.config.image_token_id

            n_image_tokens = (special_image_mask).sum()
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                # `image_features` is 2D (num_image_tokens, hidden), so its token count is shape[0];
                # multiplying by shape[1] would report elements instead of tokens.
                n_image_features = image_features.shape[0]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {int(n_image_tokens)}, features {n_image_features}"
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        return Mistral3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )
|
|
|
|
|
|
class Mistral3ForConditionalGeneration(LlavaForConditionalGeneration):
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_sizes: torch.Tensor,
        vision_feature_layer: Optional[Union[int, list[int]]] = None,
        **kwargs,
    ):
        """Delegate to `self.model.get_image_features`, forwarding every argument unchanged."""
        delegate = self.model.get_image_features
        return delegate(
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            vision_feature_layer=vision_feature_layer,
            **kwargs,
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Mistral3CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration

        >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
        >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")

        >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is the image?The image depicts two cats lying on a pink blanket."
        ```"""
        # Fill unset flags from the config.
        # NOTE(review): `return_dict` is resolved but the dataclass output is always returned — confirm intent.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Run the multimodal backbone; `kwargs` is forwarded here and again to the loss function below.
        backbone_out = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            image_sizes=image_sizes,
            **kwargs,
        )

        last_hidden = backbone_out[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
        # An int keeps the trailing `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        if isinstance(logits_to_keep, int):
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(last_hidden[:, kept_positions, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return Mistral3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
            image_hidden_states=backbone_out.image_hidden_states,
        )
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    "Mistral3Model",
    "Mistral3PreTrainedModel",  # noqa
    "Mistral3ForConditionalGeneration",
]
|