216 lines
9.6 KiB
Python
216 lines
9.6 KiB
Python
from typing import Optional, Union
|
|
|
|
from ..generation import GenerationConfig
|
|
from ..utils import add_end_docstrings, is_torch_available, is_vision_available, logging
|
|
from .base import Pipeline, build_pipeline_init_args
|
|
|
|
|
|
if is_vision_available():
|
|
from PIL import Image
|
|
|
|
from ..image_utils import load_image
|
|
|
|
if is_torch_available():
|
|
from ..models.auto.modeling_auto import MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
|
|
from .pt_utils import KeyDataset
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
|
|
class VisualQuestionAnsweringPipeline(Pipeline):
|
|
"""
|
|
Visual Question Answering pipeline using a `AutoModelForVisualQuestionAnswering`. This pipeline is currently only
|
|
available in PyTorch.
|
|
|
|
Unless the model you're using explicitly sets these generation parameters in its configuration files
|
|
(`generation_config.json`), the following default values will be used:
|
|
- max_new_tokens: 256
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import pipeline
|
|
|
|
>>> oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
|
|
>>> image_url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/lena.png"
|
|
>>> oracle(question="What is she wearing ?", image=image_url)
|
|
[{'score': 0.948, 'answer': 'hat'}, {'score': 0.009, 'answer': 'fedora'}, {'score': 0.003, 'answer': 'clothes'}, {'score': 0.003, 'answer': 'sun hat'}, {'score': 0.002, 'answer': 'nothing'}]
|
|
|
|
>>> oracle(question="What is she wearing ?", image=image_url, top_k=1)
|
|
[{'score': 0.948, 'answer': 'hat'}]
|
|
|
|
>>> oracle(question="Is this a person ?", image=image_url, top_k=1)
|
|
[{'score': 0.993, 'answer': 'yes'}]
|
|
|
|
>>> oracle(question="Is this a man ?", image=image_url, top_k=1)
|
|
[{'score': 0.996, 'answer': 'no'}]
|
|
```
|
|
|
|
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
|
|
|
This visual question answering pipeline can currently be loaded from [`pipeline`] using the following task
|
|
identifiers: `"visual-question-answering", "vqa"`.
|
|
|
|
The models that this pipeline can use are models that have been fine-tuned on a visual question answering task. See
|
|
the up-to-date list of available models on
|
|
[huggingface.co/models](https://huggingface.co/models?filter=visual-question-answering).
|
|
"""
|
|
|
|
_load_processor = False
|
|
_load_image_processor = True
|
|
_load_feature_extractor = False
|
|
_load_tokenizer = True
|
|
|
|
_pipeline_calls_generate = True
|
|
# Make sure the docstring is updated when the default generation config is changed
|
|
_default_generation_config = GenerationConfig(
|
|
max_new_tokens=256,
|
|
)
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self.check_model_type(MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES)
|
|
|
|
def _sanitize_parameters(self, top_k=None, padding=None, truncation=None, timeout=None, **kwargs):
|
|
preprocess_params, postprocess_params = {}, {}
|
|
if padding is not None:
|
|
preprocess_params["padding"] = padding
|
|
if truncation is not None:
|
|
preprocess_params["truncation"] = truncation
|
|
if timeout is not None:
|
|
preprocess_params["timeout"] = timeout
|
|
if top_k is not None:
|
|
postprocess_params["top_k"] = top_k
|
|
|
|
forward_params = {}
|
|
if getattr(self, "assistant_model", None) is not None:
|
|
forward_params["assistant_model"] = self.assistant_model
|
|
if getattr(self, "assistant_tokenizer", None) is not None:
|
|
forward_params["tokenizer"] = self.tokenizer
|
|
forward_params["assistant_tokenizer"] = self.assistant_tokenizer
|
|
|
|
return preprocess_params, forward_params, postprocess_params
|
|
|
|
def __call__(
|
|
self,
|
|
image: Union["Image.Image", str, list["Image.Image"], list[str], "KeyDataset"],
|
|
question: Optional[Union[str, list[str]]] = None,
|
|
**kwargs,
|
|
):
|
|
r"""
|
|
Answers open-ended questions about images. The pipeline accepts several types of inputs which are detailed
|
|
below:
|
|
|
|
- `pipeline(image=image, question=question)`
|
|
- `pipeline({"image": image, "question": question})`
|
|
- `pipeline([{"image": image, "question": question}])`
|
|
- `pipeline([{"image": image, "question": question}, {"image": image, "question": question}])`
|
|
|
|
Args:
|
|
image (`str`, `list[str]`, `PIL.Image`, `list[PIL.Image]` or `KeyDataset`):
|
|
The pipeline handles three types of images:
|
|
|
|
- A string containing a http link pointing to an image
|
|
- A string containing a local path to an image
|
|
- An image loaded in PIL directly
|
|
|
|
The pipeline accepts either a single image or a batch of images. If given a single image, it can be
|
|
broadcasted to multiple questions.
|
|
For dataset: the passed in dataset must be of type `transformers.pipelines.pt_utils.KeyDataset`
|
|
Example:
|
|
```python
|
|
>>> from transformers.pipelines.pt_utils import KeyDataset
|
|
>>> from datasets import load_dataset
|
|
|
|
>>> dataset = load_dataset("detection-datasets/coco")
|
|
>>> oracle(image=KeyDataset(dataset, "image"), question="What's in this image?")
|
|
|
|
```
|
|
question (`str`, `list[str]`):
|
|
The question(s) asked. If given a single question, it can be broadcasted to multiple images.
|
|
If multiple images and questions are given, each and every question will be broadcasted to all images
|
|
(same effect as a Cartesian product)
|
|
top_k (`int`, *optional*, defaults to 5):
|
|
The number of top labels that will be returned by the pipeline. If the provided number is higher than
|
|
the number of labels available in the model configuration, it will default to the number of labels.
|
|
timeout (`float`, *optional*, defaults to None):
|
|
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
|
|
the call may block forever.
|
|
Return:
|
|
A dictionary or a list of dictionaries containing the result. The dictionaries contain the following keys:
|
|
|
|
- **label** (`str`) -- The label identified by the model.
|
|
- **score** (`int`) -- The score attributed by the model for that label.
|
|
"""
|
|
is_dataset = isinstance(image, KeyDataset)
|
|
is_image_batch = isinstance(image, list) and all(isinstance(item, (Image.Image, str)) for item in image)
|
|
is_question_batch = isinstance(question, list) and all(isinstance(item, str) for item in question)
|
|
|
|
if isinstance(image, (Image.Image, str)) and isinstance(question, str):
|
|
inputs = {"image": image, "question": question}
|
|
elif (is_image_batch or is_dataset) and isinstance(question, str):
|
|
inputs = [{"image": im, "question": question} for im in image]
|
|
elif isinstance(image, (Image.Image, str)) and is_question_batch:
|
|
inputs = [{"image": image, "question": q} for q in question]
|
|
elif (is_image_batch or is_dataset) and is_question_batch:
|
|
question_image_pairs = []
|
|
for q in question:
|
|
for im in image:
|
|
question_image_pairs.append({"image": im, "question": q})
|
|
inputs = question_image_pairs
|
|
else:
|
|
"""
|
|
Supports the following format
|
|
- {"image": image, "question": question}
|
|
- [{"image": image, "question": question}]
|
|
- Generator and datasets
|
|
"""
|
|
inputs = image
|
|
results = super().__call__(inputs, **kwargs)
|
|
return results
|
|
|
|
def preprocess(self, inputs, padding=False, truncation=False, timeout=None):
|
|
image = load_image(inputs["image"], timeout=timeout)
|
|
model_inputs = self.tokenizer(
|
|
inputs["question"],
|
|
return_tensors=self.framework,
|
|
padding=padding,
|
|
truncation=truncation,
|
|
)
|
|
image_features = self.image_processor(images=image, return_tensors=self.framework)
|
|
if self.framework == "pt":
|
|
image_features = image_features.to(self.torch_dtype)
|
|
model_inputs.update(image_features)
|
|
return model_inputs
|
|
|
|
def _forward(self, model_inputs, **generate_kwargs):
|
|
if self.model.can_generate():
|
|
# User-defined `generation_config` passed to the pipeline call take precedence
|
|
if "generation_config" not in generate_kwargs:
|
|
generate_kwargs["generation_config"] = self.generation_config
|
|
|
|
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
|
|
else:
|
|
model_outputs = self.model(**model_inputs)
|
|
return model_outputs
|
|
|
|
def postprocess(self, model_outputs, top_k=5):
|
|
if self.model.can_generate():
|
|
return [
|
|
{"answer": self.tokenizer.decode(output_ids, skip_special_tokens=True).strip()}
|
|
for output_ids in model_outputs
|
|
]
|
|
else:
|
|
if top_k > self.model.config.num_labels:
|
|
top_k = self.model.config.num_labels
|
|
|
|
if self.framework == "pt":
|
|
probs = model_outputs.logits.sigmoid()[0]
|
|
scores, ids = probs.topk(top_k)
|
|
else:
|
|
raise ValueError(f"Unsupported framework: {self.framework}")
|
|
|
|
scores = scores.tolist()
|
|
ids = ids.tolist()
|
|
return [{"score": score, "answer": self.model.config.id2label[_id]} for score, _id in zip(scores, ids)]
|