team-10/venv/Lib/site-packages/transformers/models/fuyu/modeling_fuyu.py
2025-08-02 02:00:33 +02:00

387 lines
17 KiB
Python

# coding=utf-8
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Fuyu model."""
from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModel
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_fuyu import FuyuConfig
# Module-level logger obtained from the library's logging utilities, named after this module.
logger = logging.get_logger(__name__)
@auto_docstring
class FuyuPreTrainedModel(PreTrainedModel):
    # Shared base for the Fuyu models: binds `FuyuConfig` and declares the attention / checkpointing
    # capabilities and device-placement hints advertised to `PreTrainedModel`.
    config: FuyuConfig
    base_model_prefix = "fuyu"
    supports_gradient_checkpointing = True
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize the parameters of `module` in place.

        Weights of `nn.Linear` and `nn.Embedding` modules are drawn from a normal distribution with mean 0
        and standard deviation `config.initializer_range`. Afterwards, a linear layer's bias (if any) is set
        to zero, and an embedding's `padding_idx` row (if set) is set to zero. Other module types are left
        unchanged.
        """
        init_std = self.config.initializer_range
        is_linear = isinstance(module, nn.Linear)
        if not (is_linear or isinstance(module, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            # Keep the padding token's embedding at zero.
            module.weight.data[module.padding_idx].zero_()
@auto_docstring(
    custom_intro="""
The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.
"""
)
class FuyuModel(FuyuPreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}

    def __init__(self, config: FuyuConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.text_config.vocab_size
        # Text backbone instantiated from `config.text_config` through the auto-model registry.
        self.language_model = AutoModel.from_config(config.text_config)
        # Projects one flattened image patch (patch_size * patch_size * num_channels values) to the hidden size.
        self.vision_embed_tokens = nn.Linear(
            config.patch_size * config.patch_size * config.num_channels, config.hidden_size
        )
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the language backbone."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the language backbone."""
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Replace the language backbone."""
        self.language_model = decoder

    def get_decoder(self):
        """Return the language backbone."""
        return self.language_model

    def gather_continuous_embeddings(
        self,
        word_embeddings: torch.Tensor,
        continuous_embeddings: list[torch.Tensor],
        image_patch_input_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Place `continuous_embeddings` into a copy of `word_embeddings` at the positions given by
        `image_patch_input_indices`. Different batch elements can have different numbers of continuous embeddings.

        Args:
            word_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Tensor of word embeddings. It is not modified; a clone is returned.
            continuous_embeddings (`list[torch.FloatTensor]` of length `batch_size`):
                One tensor per batch element, each of shape `(num_image_embeddings, hidden_size)`.
                `num_image_embeddings` must be at least the number of non-negative indices in
                `image_patch_input_indices` for that batch element. Fewer indices are allowed, e.g. when images
                got truncated.
            image_patch_input_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                For each sequence position, the index into that batch element's continuous embeddings to copy
                there, or a negative value to keep the word embedding.

        Returns:
            `torch.FloatTensor` with the same shape as `word_embeddings`.

        Raises:
            ValueError: if the batch sizes of `word_embeddings` and `continuous_embeddings` differ, or if a batch
                element has more non-negative indices than continuous embeddings.
        """
        if not (word_embeddings.shape[0] == len(continuous_embeddings)):
            raise ValueError(
                f"Batch sizes must match! Got {len(continuous_embeddings)=} and {word_embeddings.shape[0]=}"
            )

        output_embeddings = word_embeddings.clone()
        for batch_idx in range(word_embeddings.shape[0]):
            # First, find the positions of all the non-negative values in image_patch_input_indices, those are the
            # positions in word_embeddings that we want to replace with content from continuous_embeddings.
            dst_indices = torch.nonzero(image_patch_input_indices[batch_idx] >= 0, as_tuple=True)[0]
            # Next look up those indices in image_patch_input_indices to find the indices in continuous_embeddings that we
            # want to use to replace the values in word_embeddings.
            src_indices = image_patch_input_indices[batch_idx][dst_indices]
            # Check if we have more indices than embeddings. Note that we could have fewer indices if images got truncated.
            if src_indices.shape[0] > continuous_embeddings[batch_idx].shape[0]:
                raise ValueError(
                    f"Number of continuous embeddings {continuous_embeddings[batch_idx].shape=} does not match "
                    f"number of continuous token ids {src_indices.shape=} in batch element {batch_idx}."
                )
            output_embeddings[batch_idx, dst_indices] = continuous_embeddings[batch_idx][src_indices].to(
                output_embeddings.device
            )
        return output_embeddings

    def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs):
        """
        Encodes flattened image patches into continuous embeddings that can be forwarded to the language model.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size * patch_size * num_channels)`):
                Flattened image patches, as passed to `forward` through `image_patches`. Each entry along the first
                dimension is cast to the projection's dtype, projected by `vision_embed_tokens`, and has a leading
                size-1 dimension squeezed away.
            kwargs:
                Accepted for API compatibility and ignored.

        Returns:
            `list[torch.FloatTensor]`: the projected patch embeddings, one tensor per entry of `pixel_values`.
        """
        patch_embeddings = [
            self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
            for patch in pixel_values
        ]
        return patch_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        image_patches: torch.Tensor = None,  # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
        image_patches_indices: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
            Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
            hidden size of the model.
        image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Tensor of indices of the image patches in the input_ids tensor. Accepted for compatibility but not read
            here: image embeddings are written to the positions of `config.image_token_id` tokens instead.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            _, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            _, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if position_ids is None:
            # Positions continue after the tokens already stored in the cache (0 when there is no cache).
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        if image_patches is not None:
            patch_embeddings = self.get_image_features(image_patches)
            patch_embeddings = torch.cat(patch_embeddings, dim=0)
            if input_ids is None:
                # Without token ids, locate image tokens by comparing every embedding vector with the embedding
                # of `config.image_token_id`.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                special_image_mask = special_image_mask.all(-1)
            else:
                special_image_mask = input_ids == self.config.image_token_id
            # Broadcast the per-token mask over the hidden dimension so it matches `inputs_embeds`.
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
            # `masked_scatter` fills the masked positions, in order, with consecutive values of `patch_embeddings`.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            return_dict=return_dict,
            **kwargs,
        )

        return outputs
@auto_docstring(
    custom_intro="""
Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.
"""
)
class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
    # Regex key renames (e.g. `language_model.model.*` -> `model.language_model.*`), presumably applied to
    # checkpoint keys when loading weights saved with a different module layout.
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_embed_tokens": "model.vision_embed_tokens",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: FuyuConfig):
        super().__init__(config)
        self.model = FuyuModel(config)
        # Projects hidden states to vocabulary logits (no bias).
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the underlying language backbone."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the underlying language backbone."""
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Replace the language backbone of the wrapped `FuyuModel`."""
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Return the language backbone of the wrapped `FuyuModel`."""
        return self.model.get_decoder()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        image_patches: torch.Tensor = None,  # [batch_size, num_total_patches, patch_size_ x patch_size x num_channels ]
        image_patches_indices: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        image_patches (`torch.FloatTensor` of shape `(batch_size, num_total_patches, patch_size_ x patch_size x num_channels)`, *optional*):
            Image patches to be used as continuous embeddings. The patches are flattened and then projected to the
            hidden size of the model.
        image_patches_indices (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Tensor of indices of the image patches in the input_ids tensor.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Examples:

        ```python
        >>> from transformers import FuyuProcessor, FuyuForCausalLM
        >>> from PIL import Image
        >>> import requests

        >>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
        >>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")

        >>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> prompt = "Generate a coco-style caption.\n"

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=7)
        >>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
        >>> print(generation_text[0])
        A blue bus parked on the side of a road.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # `return_dict` is handled by `@can_return_tuple`; the inner model is always asked for a dict output.

        outputs = self.model(
            input_ids=input_ids,
            image_patches=image_patches,
            image_patches_indices=image_patches_indices,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            return_dict=True,
            # don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        image_patches=None,
        image_patches_indices=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            image_patches=image_patches,
            image_patches_indices=image_patches_indices,
            **kwargs,
        )

        # Image patches must reach the model on the prefill step and are dropped afterwards. `past_key_values` may
        # already be an (empty) cache object on the prefill step, so a non-None cache alone does not mean we are
        # decoding; `cache_position` (left in `kwargs` for the parent call) is checked first.
        if self._is_decoding_step(past_key_values, kwargs.get("cache_position")):
            model_inputs["image_patches_indices"] = None
            model_inputs["image_patches"] = None

        return model_inputs

    @staticmethod
    def _is_decoding_step(past_key_values, cache_position) -> bool:
        """Return True when generation is past the prefill step, i.e. image inputs were already consumed."""
        if cache_position is not None:
            # The first position of the current chunk is 0 only on the prefill step.
            return bool(cache_position[0] != 0)
        if past_key_values is None:
            return False
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if get_seq_length is None:
            # Cache object without a length query: keep the previous behaviour (any cache means decoding).
            return True
        return get_seq_length() > 0
# Public names exported by `from ... import *` on this module.
__all__ = [
    "FuyuForCausalLM",
    "FuyuPreTrainedModel",
    "FuyuModel",
]