# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import queue
import statistics
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from functools import partial
from typing import Optional, Union
import torch
import torch.nn as nn
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream
from torch.profiler import profile, schedule, tensorboard_trace_handler
from tqdm import tqdm
from ..configuration_utils import PretrainedConfig
from ..generation.configuration_utils import GenerationConfig
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
class RequestStatus(Enum):
"""Status of a generation request through its lifecycle."""
PENDING = "pending"
PREFILLING = "prefilling"
PREFILLING_SPLIT = "prefilling_split"
SPLIT_PENDING_REMAINDER = "split_pending_remainder"
DECODING = "decoding"
FINISHED = "finished"
FAILED = "failed"
# Set up the module logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
@dataclass
class GenerationOutput:
"""Tracks the output of a generation request.
Attributes:
request_id (str): The ID of the generation request.
prompt_ids (list[int]): The IDs of the prompt tokens.
generated_tokens (list[int]): The generated tokens.
logprobs (list[float]): The log probabilities of the generated tokens.
error (Optional[str]): Any error message associated with the request. When None, the request was successful.
status (RequestStatus): The status of the request when the output was emitted.
next_token (Optional[str]): The last decoded token chunk, populated in streaming mode.
"""
request_id: str
prompt_ids: list[int] = field(default_factory=list)
generated_tokens: list[int] = field(default_factory=list)
logprobs: list[float] = field(default_factory=list)
error: Optional[str] = None
status: RequestStatus = RequestStatus.PENDING
created_time: float = field(default_factory=time.time)
next_token: Optional[str] = None
@dataclass
class RequestState:
"""Tracks the state of a generation request through its lifecycle.
Attributes:
status (RequestStatus): can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
prompt_ids (Optional[list[int]]): The prompt tokens currently being processed (a chunk of the full prompt for split requests).
full_prompt_ids (Optional[list[int]]): The full prompt token IDs.
remaining_prompt_ids (list[int]): Prompt tokens still to be prefilled when the request was split.
static_outputs (list[int]): The tokens generated so far.
"""
# Required fields
request_id: str
prompt_ids: Optional[list[int]] = None # the one being processed
full_prompt_ids: Optional[list[int]] = None # the full prompt
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests
static_outputs: list[int] = field(default_factory=list)
allocated_blocks: list[int] = field(default_factory=list)
position_offset: int = 0 # Current position in the sequence for position_ids
status: RequestStatus = RequestStatus.PENDING
max_new_tokens: int = 20
eos_token_id: int = -1
created_time: float = field(default_factory=time.time)
error: Optional[str] = None
next_token: Optional[str] = None
def current_len(self) -> int:
"""Get the current length of the sequence (prompt + generated tokens)."""
return self.position_offset
def generated_len(self) -> int:
"""Get the number of tokens generated so far."""
return len(self.static_outputs)
@traced
def update_with_token(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.
Args:
token_id: The token ID to add to the output sequence
Returns:
bool: True if the request is now complete, False otherwise
"""
# Only update if we're in decoding state
if self.status != RequestStatus.DECODING:
return False
is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
is_max_len = self.generated_len() >= self.max_new_tokens
# Only add the token if we're not finishing due to max length
# (EOS tokens should still be added to the output)
if not (is_max_len and not is_eos):
self.static_outputs.extend([token_id])
if is_eos or is_max_len:
self.status = RequestStatus.FINISHED
return True
return False
def __repr__(self):
return f"RequestState(\n\trequest_id={self.request_id},\n\tstatus={self.status},\n\tout_tokens={self.generated_len()},\n\tquery_length={len(self.prompt_ids)}, \n\tremaining_tokens={len(self.remaining_prompt_ids)}, \n\tkv_length={self.position_offset}\n\tfull_prompt_lenght={len(self.full_prompt_ids)},\n\tallocated_blocks={self.allocated_blocks},\n\tgenerated_tokens={self.static_outputs}\n)"
def to_generation_output(self):
"""Convert the request state to a GenerationOutput object."""
return GenerationOutput(
request_id=self.request_id,
prompt_ids=self.full_prompt_ids,
status=self.status,
generated_tokens=self.static_outputs,
logprobs=[],
error=self.error,
next_token=self.next_token,
)
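# Illustrative lifecycle sketch (not part of the library; values are made up):
#
#   state = RequestState(request_id="req_0", prompt_ids=[1, 2, 3], full_prompt_ids=[1, 2, 3],
#                        max_new_tokens=2, eos_token_id=2)
#   state.status = RequestStatus.DECODING
#   state.update_with_token(7)    # False, static_outputs == [7]
#   state.update_with_token(9)    # False, static_outputs == [7, 9]
#   state.update_with_token(11)   # True: max_new_tokens reached, 11 is dropped, status == FINISHED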
@attach_tracer()
class PagedAttentionCache:
def __init__(
self,
config: PretrainedConfig,
generation_config: GenerationConfig,
device: torch.device,
dtype: torch.dtype = torch.float16,
layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
initial_prompt_shapes: Optional[list[list[int]]] = None,
tp_size: Optional[int] = None,
) -> None:
"""Initialize a paged attention cache for efficient memory usage.
Args:
config: Model configuration
generation_config: Generation configuration containing cache parameters
device: Device for the cache tensors
dtype: Data type for the cache tensors
layer_device_map: Optional mapping of layer indices to devices
initial_prompt_shapes: Optional sample prompts to help calculate optimal cache size
tp_size: Optional tensor parallel world size; the number of key/value heads per device is divided by it
"""
# Extract model dimensions
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.num_hidden_layers = config.num_hidden_layers
# Calculate optimal block size and number if not provided
num_blocks = getattr(generation_config, "num_blocks", None)
block_size = getattr(generation_config, "block_size", None)
if num_blocks is None or block_size is None:
logger.info("Calculating optimal block size and number...")
num_blocks, block_size = compute_optimal_blocks(
device, config, generation_config, initial_prompt_shapes or [], dtype, median_prefill_length=200
)
logger.info(f"Using calculated num_blocks={num_blocks}, block_size={block_size}")
self.block_size = block_size
self.num_blocks = num_blocks
num_key_value_heads = self.num_key_value_heads
if tp_size is not None and tp_size > 1:
if num_key_value_heads % tp_size != 0:
raise ValueError(
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
)
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
self.dtype = dtype
self.device = device
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
for idx in range(config.num_hidden_layers):
layer_device = layer_device_map[idx] if layer_device_map is not None else device
new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
# Block management data structures
self._free_blocks = deque(range(num_blocks))
self._block_tables: dict[str, list[int]] = {}
@traced
def allocate_blocks(self, n_blocks: int, request_id: str) -> Optional[list[int]]:
"""Allocates n_blocks for a given request_id. Returns the newly allocated block indices, or None if there are not enough free blocks."""
if len(self._free_blocks) < n_blocks:
return None
allocated = []
for _ in range(n_blocks):
allocated.append(self._free_blocks.popleft())
if request_id not in self._block_tables:
self._block_tables[request_id] = []
self._block_tables[request_id].extend(allocated)
return allocated
@traced
def free_blocks(self, request_id: str) -> None:
"""Frees all blocks associated with a request_id."""
if request_id in self._block_tables:
blocks_to_free = self._block_tables.pop(request_id)
self._free_blocks.extend(blocks_to_free)
else:
logger.warning(f"Attempted to free blocks for non-existent request_id: {request_id}")
def get_num_free_blocks(self) -> int:
"""Returns the number of free blocks available."""
return len(self._free_blocks)
def get_block_table(self, request_id: str) -> list[int]:
"""Returns the block table for a request."""
return self._block_tables.get(request_id, [])
@traced
def _get_physical_indices(self, state: RequestState, logical_indices: list[int]) -> list[int]:
"""
Maps logical sequence indices to physical cache indices using the block table, using PyTorch.
Args:
state: The request state; its request_id is used to look up the block table.
logical_indices: A list of logical indices.
Returns:
A list of physical indices.
Raises:
ValueError: If no block table is found for the request ID.
IndexError: If a logical index maps to a block index that is out of bounds.
"""
request_id = state.request_id
block_table = self._block_tables.get(request_id)
if not block_table:
raise ValueError(f"No block table found for request {request_id}")
block_size = self.block_size
physical_indices = []
for idx in logical_indices:
block_idx = idx // block_size
block_offset = idx % block_size
if block_idx >= len(block_table):
raise IndexError(
f"Logical index {idx} maps to block index {block_idx} which is out of bounds "
f"for request {request_id}"
)
physical_block_num = block_table[block_idx]
physical_index = physical_block_num * block_size + block_offset
physical_indices.append(physical_index)
return physical_indices
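# Worked example (illustrative values): with block_size=4 and block table [7, 2] for a request,
# logical index 5 maps to block_idx = 5 // 4 = 1 and block_offset = 5 % 4 = 1, so the physical
# index is block_table[1] * 4 + 1 = 2 * 4 + 1 = 9.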
@traced
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
read_index,
write_index,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# Reshape cache for easier indexing
total_slots = self.num_blocks * self.block_size
k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
k_cache_flat[:, write_index, :] = key_states[0]
v_cache_flat[:, write_index, :] = value_states[0]
return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :]
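# Shape note: each per-layer cache tensor is laid out as
# (num_key_value_heads, num_blocks, block_size, head_dim); `update` views it as
# (num_key_value_heads, num_blocks * block_size, head_dim) so `write_index` and `read_index`
# address the flat slot positions produced by `_get_physical_indices`.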
class Scheduler(ABC):
"""
Abstract base class for scheduling requests in the continuous batch processor.
It is expected that cache allocation and scheduling logic will be implemented in subclasses.
"""
def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
self.active_requests: dict[str, RequestState] = {}
self.waiting_requests: dict[str, RequestState] = {}
self.waiting_requests_order: deque[str] = deque()
self.cache = cache
self.retain_cache_on_finish = retain_cache_on_finish
@abstractmethod
def add_waiting_request(self, state: RequestState):
"""Add a request to the waiting list."""
pass
@abstractmethod
def schedule_batch(self, token_budget: int) -> list[RequestState]:
pass
@traced
def has_pending_requests(self) -> bool:
"""Check if there are requests ready to be processed."""
return bool(self.active_requests or self.waiting_requests)
@abstractmethod
def finish_request(self, request_id: str, evict_from_cache: bool = True):
"""Finish processing a request and free its allocated blocks."""
pass
@traced
def get_active_request_static_outputs(self, request_id: str) -> list[int]:
if request_id in self.active_requests:
return self.active_requests[request_id].static_outputs
return []
@attach_tracer()
class FIFOScheduler(Scheduler):
@traced
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int):
# 1. we check whether the slots already allocated (occupancy) cover the requested length
# 2. if they do not, we allocate enough extra blocks to cover it
current_len = state.current_len()
occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len
if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0):
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
if not allocated:
return False
state.allocated_blocks.extend(allocated)
return True
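# Worked example (illustrative values): with block_size=16, one block already allocated and
# current_len=10, occupancy = 1 * 16 - 10 = 6 free slots. Scheduling 20 new tokens gives
# blocks_needed = ((20 - 6 + 1) // 16) + 1 = 1, i.e. one extra block (6 + 16 = 22 >= 20 slots).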
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
):
"""Prepare a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
self.active_requests[state.request_id] = state
state.status = RequestStatus.PREFILLING
request_ids_to_remove_from_waiting.add(state.request_id)
elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
state.status = RequestStatus.PREFILLING
state.prompt_ids = state.remaining_prompt_ids
state.remaining_prompt_ids = []
else:
# Need to split the request
if state.status == RequestStatus.PENDING:
self.active_requests[state.request_id] = state
state.status = RequestStatus.PREFILLING_SPLIT
request_ids_to_remove_from_waiting.add(state.request_id)
elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
state.status = RequestStatus.PREFILLING_SPLIT
state.remaining_prompt_ids = request_tokens[token_budget:]
state.prompt_ids = request_tokens[:token_budget]
@traced
def add_waiting_request(self, state: RequestState):
"""Add a request to the waiting list."""
if self.retain_cache_on_finish and state.request_id in self.active_requests:
old_state = self.active_requests.pop(state.request_id)
state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :]
state.allocated_blocks = old_state.allocated_blocks
state.position_offset = old_state.position_offset
self.waiting_requests[state.request_id] = state
self.waiting_requests_order.append(state.request_id)
@traced
def schedule_batch(self, token_budget: int) -> list[RequestState]:
priority_states: list[RequestState] = []
second_priority_states: list[RequestState] = []
scheduled_requests = []
for state in self.active_requests.values():
if state.status == RequestStatus.DECODING:
priority_states.append(state)
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
second_priority_states.append(state)
# Add waiting requests to second priority
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])
candidates = priority_states + second_priority_states
request_ids_to_remove_from_waiting = set()
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
token_budget -= request_len
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
if token_budget == 0:
break
self.waiting_requests_order = deque(
[req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
)
return scheduled_requests
@traced
def finish_request(self, request_id: str, evict_from_cache: bool = True):
if evict_from_cache:
self.cache.free_blocks(request_id)
if request_id in self.active_requests:
del self.active_requests[request_id]
@attach_tracer()
class PrefillFirstScheduler(Scheduler):
@traced
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int):
# 1. we check whether the slots already allocated (occupancy) cover the requested length
# 2. if they do not, we allocate enough extra blocks to cover it
current_len = state.current_len()
occupancy = len(state.allocated_blocks) * self.cache.block_size - current_len
if occupancy < len_next_tokens or (len(state.allocated_blocks) == 0):
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
if not allocated:
return False
state.allocated_blocks.extend(allocated)
return True
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
):
"""Prepare a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
self.active_requests[state.request_id] = state
state.status = RequestStatus.PREFILLING
request_ids_to_remove_from_waiting.add(state.request_id)
elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
state.status = RequestStatus.PREFILLING
state.prompt_ids = state.remaining_prompt_ids
state.remaining_prompt_ids = []
else:
# Need to split the request
if state.status == RequestStatus.PENDING:
self.active_requests[state.request_id] = state
state.status = RequestStatus.PREFILLING_SPLIT
request_ids_to_remove_from_waiting.add(state.request_id)
elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
state.status = RequestStatus.PREFILLING_SPLIT
state.remaining_prompt_ids = request_tokens[token_budget:]
state.prompt_ids = request_tokens[:token_budget]
@traced
def add_waiting_request(self, state: RequestState):
"""Add a request to the waiting list."""
if self.retain_cache_on_finish and state.request_id in self.active_requests:
old_state = self.active_requests.pop(state.request_id)
state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error?
state.allocated_blocks = old_state.allocated_blocks
state.position_offset = old_state.position_offset
self.waiting_requests[state.request_id] = state
self.waiting_requests_order.append(state.request_id)
@traced
def schedule_batch(self, token_budget: int) -> list[RequestState]:
priority_states: list[RequestState] = []
second_priority_states: list[RequestState] = []
scheduled_requests = []
for state in self.active_requests.values():
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
priority_states.append(state)
elif state.status == RequestStatus.DECODING:
second_priority_states.append(state)
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])
candidates = priority_states + second_priority_states
request_ids_to_remove_from_waiting = set()
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
token_budget -= request_len
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
if token_budget == 0:
break
self.waiting_requests_order = deque(
[req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
)
return scheduled_requests
@traced
def finish_request(self, request_id: str, evict_from_cache: bool = True):
if evict_from_cache:
self.cache.free_blocks(request_id)
if request_id in self.active_requests:
del self.active_requests[request_id]
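# Note: PrefillFirstScheduler differs from FIFOScheduler only in its priorities: pending split
# prefills (SPLIT_PENDING_REMAINDER) are scheduled before decoding requests, whereas FIFOScheduler
# serves decoding requests first.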
@traced(standalone=True)
def compute_optimal_blocks(
device: torch.device,
config: PretrainedConfig,
generation_config: GenerationConfig,
inputs: list[list[int]],
dtype: torch.dtype = torch.bfloat16,
safety_margin: float = 0.9,
median_prefill_length: Optional[int] = None,
):
"""Calculate optimal number and size of blocks for the KV cache.
Args:
device: The device where the model runs
config: The model configuration
generation_config: The generation configuration
inputs: Sample input sequences to estimate memory requirements
dtype: Data type for cache tensors
safety_margin: Fraction of available memory to use
median_prefill_length: Override for median prefill length calculation
Returns:
Tuple of (num_blocks, block_size)
"""
# Extract model dimensions
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
num_hidden_layers = getattr(config, "num_hidden_layers", 40)
# Get available device memory
if device.type == "cuda":
device_properties = torch.cuda.get_device_properties(device)
total_memory = device_properties.total_memory
allocated_memory = torch.cuda.memory_allocated(device)
reserved_memory = torch.cuda.memory_reserved(device)
available_memory = total_memory - max(allocated_memory, reserved_memory)
elif device.type == "mps":
logger.warning("MPS memory estimation is approximate. Using conservative defaults.")
return 2048, 256
else:
logger.warning(f"Unsupported device type {device.type} for optimal block calculation. Using defaults.")
return 32, 128
# Apply safety margin
available_memory = int(available_memory * safety_margin)
if available_memory <= 0:
logger.warning("Not enough available memory. Using minimum configuration.")
return 8, 128 # Minimum viable configuration
# Calculate memory per token
dtype_size = torch.tensor([], dtype=dtype).element_size()
memory_per_token = 2 * num_kv_heads * head_dim * dtype_size * num_hidden_layers # For K and V caches
# Estimate sequence length requirements
tokens_to_generate = getattr(generation_config, "max_new_tokens", None) or 20
if median_prefill_length is None and inputs:
non_empty_inputs = [len(seq) for seq in inputs if seq]
median_prefill_length = int(statistics.median(non_empty_inputs)) if non_empty_inputs else 64
elif median_prefill_length is None:
median_prefill_length = 64 # Reasonable default if no inputs provided
# Total sequence length including generated tokens
seq_length = median_prefill_length + tokens_to_generate
# Calculate block parameters
MIN_BLOCK_SIZE = 16
# Estimate number of concurrent sequences
per_sequence_memory = seq_length * memory_per_token
max_concurrent_sequences = max(1, int(available_memory // per_sequence_memory))
# Total tokens that can fit in memory
total_tokens = available_memory // memory_per_token
# Calculate block size (rounded to power of 2)
initial_block_size = max(MIN_BLOCK_SIZE, total_tokens // (max_concurrent_sequences * 2))
block_size = 1 << (initial_block_size - 1).bit_length() # Round up to the next power of 2
# Calculate number of blocks
num_blocks = max(1, total_tokens // block_size)
logger.info(
f"Optimal cache: {num_blocks} blocks of size {block_size} "
f"(can handle ~{num_blocks * block_size // seq_length} sequences of length {seq_length})"
)
return int(num_blocks), int(block_size)
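# Worked example (illustrative numbers): num_kv_heads=8, head_dim=128, num_hidden_layers=32 and
# bfloat16 (2 bytes) give memory_per_token = 2 * 8 * 128 * 2 * 32 = 131072 bytes (128 KiB).
# With 8 GiB available after the safety margin, total_tokens = 65536; with seq_length = 200 + 20 = 220,
# max_concurrent_sequences = 297, initial_block_size = max(16, 65536 // (297 * 2)) = 110,
# block_size rounds up to 128 and num_blocks = 65536 // 128 = 512.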
@dataclass
class PagedAttentionArgs:
input_ids: torch.Tensor
attention_mask: torch.Tensor
position_ids: torch.Tensor
cumulative_seqlens_q: torch.Tensor
cumulative_seqlens_k: torch.Tensor
max_seqlen_q: int
max_seqlen_k: int
write_index: torch.Tensor
read_index: torch.Tensor
logits_indices: torch.Tensor
block_tables: dict[str, list[int]]
cache: PagedAttentionCache
use_cache: bool = False
@traced
def create_document_mask(cumulative_seqlens_q, cumulative_seqlens_k):
# Number of documents
valid_docs_q = cumulative_seqlens_q[1:] > cumulative_seqlens_q[:-1]
valid_docs_k = cumulative_seqlens_k[1:] > cumulative_seqlens_k[:-1]
num_valid_docs = min(valid_docs_q.sum(), valid_docs_k.sum())
# Trim to valid docs
cumulative_seqlens_q = cumulative_seqlens_q[: num_valid_docs + 1]
cumulative_seqlens_k = cumulative_seqlens_k[: num_valid_docs + 1]
total_q = cumulative_seqlens_q[-1]
total_k = cumulative_seqlens_k[-1]
q_indices = torch.arange(total_q, device=cumulative_seqlens_q.device)
k_indices = torch.arange(total_k, device=cumulative_seqlens_k.device)
q_doc_ids = torch.bucketize(q_indices, cumulative_seqlens_q[1:], right=True)
k_doc_ids = torch.bucketize(k_indices, cumulative_seqlens_k[1:], right=False)
doc_mask = q_doc_ids[:, None] == k_doc_ids[None, :]
# apply a causal mask only to documents that are still prefilling (more than one query token)
is_causal = ~(cumulative_seqlens_q[1:] - cumulative_seqlens_q[:-1] == 1) * cumulative_seqlens_q[1:]
apply_causal = torch.bucketize(q_indices, is_causal, right=True)[:, None] == k_doc_ids
# TODO don't apply on prefill splitting
causal_mask = torch.triu(torch.ones(total_q, total_k, device=q_doc_ids.device), diagonal=1).bool()
doc_mask.masked_fill_((apply_causal & causal_mask), False)
return doc_mask
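# Illustrative example: cumulative_seqlens_q = [0, 3, 4] and cumulative_seqlens_k = [0, 3, 7]
# describe two packed requests. The first has 3 query tokens over 3 keys (prefill, causal mask
# applied); the second has 1 query token over 4 keys (decode, full row visible). Each query token
# may only attend to keys belonging to its own document.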
# Continuous Batch Processor (Internal Logic)
@attach_tracer()
class ContinuousBatchProcessor:
def __init__(
self,
cache: PagedAttentionCache,
config: PretrainedConfig,
generation_config: GenerationConfig,
input_queue: queue.Queue,
output_queue: queue.Queue,
stop_event: threading.Event,
model_device: torch.device,
model_dtype: torch.dtype,
scheduler: Scheduler,
streaming: bool = False,
manual_eviction: bool = False,
):
"""Initialize the continuous batch processor.
Args:
cache: The paged attention cache to use
config: The model configuration
generation_config: The generation configuration
input_queue: Queue for incoming requests
output_queue: Queue for outgoing results
stop_event: Event to signal processing should stop
model_device: Device for model inputs/outputs
model_dtype: Data type for model inputs/outputs
scheduler: Scheduler deciding which requests enter each batch
streaming: Whether to stream tokens as they're generated
manual_eviction: Whether finished requests must be evicted from the cache explicitly
"""
self.cache = cache
self.config = config
self.generation_config = generation_config
self.input_queue = input_queue
self.output_queue = output_queue
self.stop_event = stop_event
self.model_device = model_device
self.model_dtype = model_dtype
self.scheduler = scheduler
self.streaming = streaming
self.manual_eviction = manual_eviction
self.requests_in_batch: list[RequestState] = []
# Get batch size parameters from generation config
self._configure_batch_parameters()
# Set up metrics collector
self.metrics = ContinuousBatchProcessorMetrics(self.max_batch_tokens)
self.setup_static_tensors()
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced(standalone=True)
def setup_static_tensors(self):
T = self.max_batch_tokens
max_token_budget = self.cache.num_blocks * self.cache.block_size
tensor_metadata = {"dtype": torch.int32, "device": self.model_device}
self.tensor_metadata = tensor_metadata
self.input_ids = torch.zeros((1, T), **tensor_metadata)
self.position_ids = torch.zeros((1, T), **tensor_metadata)
self.attention_mask = torch.zeros(
(1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
)
self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata)
self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata)
self.write_index = torch.zeros((T,), **tensor_metadata)
self.read_index = torch.zeros((max_token_budget,), **tensor_metadata)
self.logits_indices = torch.full((T,), -1, **tensor_metadata)
self.max_seqlen_q = 0
self.max_seqlen_k = 0
self.output_ids = torch.full((1, T), -1, **tensor_metadata)
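# Sizing note: T = max_batch_tokens bounds the number of query tokens per forward pass, while
# max_token_budget = num_blocks * block_size bounds the number of addressable KV slots, hence the
# width of `read_index` and of the attention mask.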
@traced
@torch.no_grad()
def reset_static_tensors(self):
"""Reset static tensors for the next batch."""
self.input_ids.zero_()
self.position_ids.zero_()
self.attention_mask.fill_(torch.finfo(self.model_dtype).min)
self.cumulative_seqlens_q.zero_()
self.cumulative_seqlens_k.zero_()
self.write_index.fill_(-1)
self.read_index.fill_(-1)
self.logits_indices.fill_(-1)
self.max_seqlen_q = 0
self.max_seqlen_k = 0
self.output_ids.zero_()
def get_model_kwargs(self) -> dict:
"""Get model keyword arguments for the current batch."""
# torch.set_printoptions(threshold=100000,linewidth=10000)
return {
"input_ids": self.input_ids,
"position_ids": self.position_ids,
"attention_mask": self.attention_mask,
"cumulative_seqlens_q": self.cumulative_seqlens_q,
"cumulative_seqlens_k": self.cumulative_seqlens_k,
"write_index": self.write_index,
"read_index": self.read_index,
"logits_indices": self.logits_indices,
"max_seqlen_q": self.max_seqlen_q,
"max_seqlen_k": self.max_seqlen_k,
"block_tables": self.cache._block_tables,
"cache": self.cache,
"use_cache": False,
}
def __repr__(self):
return (
f"ContinuousBatchProcessor(input_queue={self.input_queue}, output_queue={self.output_queue}, active_requests={self.scheduler.active_requests}, waiting_requests={self.scheduler.waiting_requests})"
+ self.get_model_kwargs().__repr__()
)
@traced(standalone=True)
def _configure_batch_parameters(self):
"""Set up batch processing parameters based on generation config."""
# Calculate total cache capacity
total_cache_tokens = self.cache.num_blocks * self.cache.block_size
# Get or calculate max tokens per batch
user_batch_tokens = getattr(self.generation_config, "max_batch_tokens", None)
if user_batch_tokens is not None:
self.max_batch_tokens = user_batch_tokens
else:
# Default to 1/8 of total cache capacity, adjusted for context
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
recommended_batch_size = min(total_cache_tokens // 8, self.max_context_len)
self.max_batch_tokens = max(64, recommended_batch_size)
# Context length and EOS token
self.max_context_len = getattr(self.generation_config, "max_position_embeddings", 2048)
@traced
def _get_new_requests(self):
"""Pull new requests from the input queue and add to waiting list."""
while not self.input_queue.empty():
try:
state = self.input_queue.get_nowait()
if state is None: # Sentinel value
continue
self.scheduler.add_waiting_request(state)
except queue.Empty:
break
except Exception as e:
logger.error(f"Error processing new request: {e}", exc_info=True)
state: RequestState = locals().get("state")
if state is not None:
self._handle_request_error(e, state)
@traced
def _handle_request_error(self, error, state: RequestState):
"""Handle general request processing error."""
state.status = RequestStatus.FAILED
state.error = str(error)
# Include any generated tokens if this is an active request
if isinstance(state.request_id, str):
state.static_outputs = self.scheduler.get_active_request_static_outputs(state.request_id)
else:
state.static_outputs = []
self.metrics.record_request_completion(state.created_time, state.request_id)
self.output_queue.put(state.to_generation_output())
@traced
def prepare_next_batch(self):
"""Prepare tensors and metadata for the next model forward pass."""
# Get new requests from the queue
self._get_new_requests()
if not self.scheduler.has_pending_requests():
return None
self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests))
self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens)
if not self.requests_in_batch:
return None
# Get the request objects for this batch
self.reset_static_tensors()
position_ids = []
input_ids = []
read_index = []
write_index = []
cumulative_seqlens_q = [0]
cumulative_seqlens_k = [0]
logits_indices = []
self.metrics.record_batch_metrics(self.requests_in_batch)
for state in self.requests_in_batch:
next_input_ids = state.prompt_ids
input_ids.extend(next_input_ids)
past_length = state.position_offset
query_length = len(next_input_ids)
key_length = query_length + past_length
cache_index = list(range(key_length))
positions_to_add = cache_index[past_length:]
read_indices = self.cache._get_physical_indices(state, cache_index)
write_indices = read_indices[-query_length:]
position_ids.extend(positions_to_add)
read_index.extend(read_indices)
write_index.extend(write_indices)
cumulative_seqlens_q.append(cumulative_seqlens_q[-1] + query_length)
cumulative_seqlens_k.append(cumulative_seqlens_k[-1] + key_length)
if len(state.remaining_prompt_ids) == 0:
logits_indices.append(cumulative_seqlens_q[-1] - 1)
self.max_seqlen_q = max(self.max_seqlen_q, query_length)
self.max_seqlen_k = max(self.max_seqlen_k, key_length)
state.position_offset += query_length
logger.info(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}"
)
self._build_tensors(
input_ids,
position_ids,
read_index,
write_index,
cumulative_seqlens_q,
cumulative_seqlens_k,
logits_indices,
)
self.metrics.record_kv_cache_memory_metrics(self.cache)
@traced
def _build_tensors(
self,
input_ids,
position_ids,
read_index,
write_index,
cumulative_seqlens_q,
cumulative_seqlens_k,
logits_indices,
):
to_tensor = partial(torch.tensor, **self.tensor_metadata)
self.input_ids[:, : len(input_ids)] = to_tensor(input_ids)
self.position_ids[:, : len(position_ids)] = to_tensor(position_ids)
self.write_index[: len(write_index)] = to_tensor(write_index)
self.read_index[: len(read_index)] = to_tensor(read_index)
self.cumulative_seqlens_q[: len(cumulative_seqlens_q)] = to_tensor(cumulative_seqlens_q)
self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k)
self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
min_value = torch.finfo(self.model_dtype).min
if self.config._attn_implementation != "paged_attention": # `is_causal` is set to True in the paged attention call
for i in range(len(cumulative_seqlens_q) - 1):
if (
cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
< cumulative_seqlens_k[i + 1] - cumulative_seqlens_k[i]
and cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i] >= 1
):
diagonal = (
cumulative_seqlens_k[i + 1] - (cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]) + 1
)
diagonal = diagonal - cumulative_seqlens_k[i]
else:
diagonal = 1
query_range = slice(cumulative_seqlens_q[i], cumulative_seqlens_q[i + 1])
key_range = slice(cumulative_seqlens_k[i], cumulative_seqlens_k[i + 1])
mask = torch.triu(
torch.full(
self.attention_mask[..., query_range, key_range].shape,
min_value,
dtype=self.model_dtype,
device=self.model_device,
),
diagonal=diagonal,
)
self.attention_mask[..., query_range, key_range] = mask
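# Worked example (illustrative values): a chunk with query_length=4 and key_length=10 (6 cached
# tokens) gives diagonal = 10 - 4 + 1 = 7 within its (4, 10) block, so query row 0 sees keys 0..6
# (the 6 cached tokens plus its own position) and row 3 sees all 10 keys.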
@traced
def _sync(self):
return self.output_ids.tolist()[0] # should be the only sync we do
@traced
def _maybe_send_output(self, state: RequestState, token: int):
"""Send output to the queue based on streaming mode and request state."""
if self.streaming:
state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@traced
def update_batch(self):
"""Update request states based on generated tokens."""
out_tokens = self._sync()
finished_request_ids = []
for i, state in enumerate(self.requests_in_batch):
req_id = state.request_id
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
if state.update_with_token(token):
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
finished_request_ids.append(req_id)
self._maybe_send_output(state, token)
elif state.status == RequestStatus.PREFILLING_SPLIT:
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
@traced
def has_pending_requests(self) -> bool:
"""Check if there are any active or waiting requests."""
return self.scheduler.has_pending_requests()
@traced
def handle_batch_error(self, error):
"""Handle errors during batch processing."""
failed_reqs = self.requests_in_batch
for req in failed_reqs:
self._handle_request_error(error, req)
self.scheduler.finish_request(req.request_id)
@traced
def fail_all_requests(self, error):
"""Fail all active requests with the given error.
Args:
error: The error to report in the failure message
"""
for state in self.scheduler.active_requests.values():
self._handle_request_error(error, state)
self.scheduler.finish_request(state.request_id)
# Also fail any requests in the waiting queue
for req_id in list(self.scheduler.waiting_requests.keys()):
state = self.scheduler.waiting_requests.pop(req_id)
self._handle_request_error(error, state)
# Clear the ordering queue
self.scheduler.waiting_requests_order.clear()
SCHEDULER_MAPPING = {
"fifo": FIFOScheduler,
"prefill_first": PrefillFirstScheduler,
}
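# Illustrative scheduler selection (the scheduler name is read off the generation config via
# hasattr/getattr in `_run_generation_loop`; unknown names fall back to FIFO):
#
#   generation_config = GenerationConfig(max_new_tokens=32)
#   generation_config.scheduler = "prefill_first"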
# Manager Class (User Interface)
@attach_tracer()
class ContinuousBatchingManager:
"""Manager for handling continuous batching of generation requests.
This class provides the user interface for submitting generation requests,
retrieving results, and managing the background generation thread.
"""
def __init__(
self,
model,
generation_config: GenerationConfig,
manual_eviction: bool = False,
max_queue_size=0,
streaming: bool = True,
):
"""Initialize the continuous batching manager.
Args:
model: The language model for generation
generation_config: Configuration for generation parameters
manual_eviction: Whether finished requests must be evicted from the cache explicitly
max_queue_size: Maximum size of the request queue (0 = unlimited)
streaming: Whether to stream tokens as they are generated
"""
self.model = model
self.generation_config = generation_config
self.input_queue = queue.Queue(maxsize=max_queue_size)
self.output_queue = queue.Queue()
self.stop_event = threading.Event()
self.streaming = streaming
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
self._generation_thread = None
self._request_counter = 0
self._request_lock = threading.Lock()
self.model.generation_config.top_p = None
self.do_sample = getattr(generation_config, "do_sample", True)
generation_config = model.generation_config if generation_config is None else generation_config
self.logit_processor = self.model._get_logits_processor(generation_config)
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self.decode_stream = DecodeStream(skip_special_tokens=True)
@traced
def start(self):
"""Start the background generation thread."""
if self._generation_thread is not None and self._generation_thread.is_alive():
logger.warning("Manager thread is already running.")
return
self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()
logger.info("Continuous batching manager started.")
def is_running(self):
"""Check if the background generation thread is running."""
return self._generation_thread is not None and self._generation_thread.is_alive()
def stop(self, block: bool = False, timeout: Optional[float] = None):
"""Signal the background thread to stop.
Args:
block: Whether to wait for the thread to stop
timeout: Maximum time to wait for the thread to stop
"""
if self._generation_thread is None:
logger.warning("Manager not started.")
return
if not self.stop_event.is_set():
self.stop_event.set()
logger.info("Stopping continuous batching manager...")
if block:
self.join(timeout)
def join(self, timeout: Optional[float] = None):
"""Wait for the background thread to finish.
Args:
timeout: Maximum time to wait for the thread to stop
"""
if self._generation_thread is not None:
self._generation_thread.join(timeout=timeout)
if self._generation_thread.is_alive():
logger.warning("Generation thread did not exit after join timeout.")
else:
logger.info("Continuous Batching Manager stopped.")
self._generation_thread = None
def add_request(
self, input_ids: list[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None
) -> str:
"""Add a new generation request to the queue.
Args:
input_ids: Input token IDs to use as prompt
request_id: Optional custom request ID (auto-generated if None)
max_new_tokens: Optional per-request cap on newly generated tokens (defaults to the generation config value)
Returns:
str: The request ID
"""
if request_id is None:
with self._request_lock:
request_id = f"req_{self._request_counter}"
self._request_counter += 1
max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens
state = RequestState(
request_id=request_id,
prompt_ids=list(input_ids),
full_prompt_ids=list(input_ids),
max_new_tokens=max_new_tokens,
eos_token_id=self.generation_config.eos_token_id,
)
# Use block=True with timeout to handle backpressure if queue is full
self.input_queue.put(state, block=True, timeout=10) # XXX: pass timeout as fn arg?
logger.debug(f"Added request {request_id} to queue.")
return request_id
def add_requests(self, inputs: list[list[int]], **kwargs):
for i, input_ids in enumerate(inputs):
# Assign a predictable request ID for ordering results later
req_id = f"batch_req_{i}"
self.add_request(input_ids, request_id=req_id, **kwargs)
def get_result(self, timeout=None) -> Optional[GenerationOutput]:
"""Retrieve one result from the output queue.
Args:
timeout: Maximum time to wait for a result
Returns:
Optional[GenerationOutput]: The generation output, or None if the timeout expired
"""
if self._generation_thread is None and self.output_queue.empty():
return None
try:
result = self.output_queue.get(block=True, timeout=timeout)
logger.debug(f"Retrieved result for request {result.request_id}")
return result
except queue.Empty:
return None
def __iter__(self):
"""Iterate over results as they become available."""
while (
self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty()
):
result = self.get_result(timeout=0.1) # short timeout so the loop re-checks whether the generation thread is still alive
if result is not None:
yield result
@traced
def warmup(self, batch_processor):
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
self._generation_step(batch_processor)
@traced
# @torch.compile
def _generation_step(self, batch_processor: ContinuousBatchProcessor):
"""Perform a single generation step. This is cuda graphed"""
batch_data = batch_processor.get_model_kwargs()
with torch.no_grad():
logits = self._model_forward(batch_data)
if self.log_prob_generation:
batch_processor.output_probs.copy_(logits) # TODO
probs = self._process_logit(batch_data, logits)
self._sample(batch_processor, probs)
@traced(span_name="model_forward")
def _model_forward(self, batch_data):
return self.model(**batch_data).logits
@traced(span_name="logit_processing")
def _process_logit(self, batch_data, logits):
# Pass continuous batching context to logits processor if it supports it. TODO we should find a way to make this a little bit cleaner!
if hasattr(self.logit_processor, "set_continuous_batching_context"):
self.logit_processor.set_continuous_batching_context(
batch_data["logits_indices"], batch_data["cumulative_seqlens_q"]
)
return self.logit_processor(batch_data["input_ids"], logits)
@traced(span_name="sampling")
def _sample(self, batch_processor: ContinuousBatchProcessor, probs):
if self.do_sample: # sample
probs = nn.functional.softmax(probs, dim=-1)
next_tokens = torch.multinomial(probs[0], num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)
batch_processor.output_ids.copy_(next_tokens)
def _run_generation_loop(self):
"""Main processing loop running in the background thread."""
batch_processor = None
try:
paged_attention_cache = PagedAttentionCache(
self.model.config,
self.generation_config,
self.model.device,
self.model.dtype,
tp_size=getattr(self.model, "tp_size", None),
)
scheduler = None
if hasattr(self.generation_config, "scheduler"):
scheduler = SCHEDULER_MAPPING.get(self.generation_config.scheduler)
if scheduler is None:
logger.warning(f"Scheduler '{scheduler}' not found. Defaulting to FIFO.")
scheduler = FIFOScheduler
else:
# Default to fifo
scheduler = FIFOScheduler
batch_processor = ContinuousBatchProcessor(
paged_attention_cache,
self.model.config,
self.generation_config,
self.input_queue,
self.output_queue,
self.stop_event,
self.model.device,
self.model.dtype,
scheduler(paged_attention_cache, self.manual_eviction),
self.streaming,
self.manual_eviction,
)
self.batch_processor = batch_processor
is_first = True
if self.profile:
tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
trace_handler = tensorboard_trace_handler(
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
)
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with profile(
activities=activities,
schedule=tracing_schedule,
on_trace_ready=trace_handler,
record_shapes=False,
with_stack=True,
) as prof:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
prof.step()
else:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
except Exception as e:
logger.error(f"Error in generation loop: {e}", exc_info=True)
self._handle_critical_error(e, batch_processor)
finally:
logger.info("Generation loop finished.")
@traced(span_name="generation_loop")
def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor, is_first: bool = False):
if torch.cuda.is_available():
torch.cuda.synchronize()
batch_processor.prepare_next_batch()
if torch.cuda.is_available() and self.use_cuda_graph:
if is_first:
self.warmup(batch_processor)
elif hasattr(self, "graph"):
try:
self._graph_replay()
except Exception as e:
logger.error(f"Model forward pass failed: {e}", exc_info=True)
batch_processor.handle_batch_error(e)
return
else:
self._generation_step(batch_processor)
else:
self._generation_step(batch_processor)
if torch.cuda.is_available():
torch.cuda.synchronize()
batch_processor.update_batch()
@traced(span_name="graph_replay")
def _graph_replay(self):
self.graph.replay()
@traced
def _handle_critical_error(self, error, batch_processor: Optional[ContinuousBatchProcessor]):
"""Handle critical errors that terminate the generation loop."""
# Signal stop
self.stop_event.set()
# Fail pending requests in input queue
try:
while True:
req_data = self.input_queue.get_nowait()
if batch_processor is not None:
batch_processor._handle_request_error(error, req_data)
except queue.Empty:
pass
# Fail active requests
if batch_processor is not None:
batch_processor.fail_all_requests(error)
@traced
def evict_request_from_cache(self, request_id: str):
"""Evict a request from the cache. It is assumed that the request is already finished."""
if not self.manual_eviction:
raise RuntimeError("Manual eviction is not enabled for this manager.")
if self.batch_processor is not None:
self.batch_processor.scheduler.finish_request(request_id)
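# Usage sketch (illustrative; assumes `model` is a causal LM with ContinuousMixin and `tok` is its
# Hugging Face tokenizer):
#
#   manager = model.init_continuous_batching(streaming=False)
#   manager.start()
#   req_id = manager.add_request(tok("Hello")["input_ids"], max_new_tokens=16)
#   for output in manager:
#       if output.request_id == req_id and output.status == RequestStatus.FINISHED:
#           print(tok.decode(output.generated_tokens))
#           break
#   manager.stop(block=True)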
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""
def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
manual_eviction: bool = False,
max_queue_size: int = 0,
streaming: bool = False,
) -> ContinuousBatchingManager:
"""Initialize a manager for continuous batching inference.
Args:
generation_config: Custom generation configuration
manual_eviction: Whether finished requests must be evicted from the cache explicitly via `evict_request_from_cache`
max_queue_size: Maximum size of the input request queue
streaming: Whether to stream tokens as they are generated
Returns:
`ContinuousBatchingManager`: The manager instance to add requests and retrieve results.
"""
if not hasattr(self, "config") or not hasattr(self, "device") or not hasattr(self, "dtype"):
raise AttributeError("Model must have 'config', 'device', and 'dtype' attributes.")
gen_config = generation_config if generation_config is not None else self.generation_config
if gen_config is None:
raise ValueError("A GenerationConfig must be provided or set in the model.")
if gen_config.eos_token_id is None:
logger.warning("`eos_token_id` not set in GenerationConfig. Setting to -1 (disabled).")
gen_config.eos_token_id = -1
# Create and return the manager
return ContinuousBatchingManager(
model=self,
generation_config=gen_config,
manual_eviction=manual_eviction,
max_queue_size=max_queue_size,
streaming=streaming,
)
@traced
@torch.inference_mode()
def generate_batch(
self,
inputs: list[list[int]],
generation_config: Optional[GenerationConfig] = None,
progress_bar: bool = True,
**kwargs,
) -> dict[str, GenerationOutput]:
"""Generate sequences for a batch of prompts using continuous batching.
Args:
inputs: List of input token sequences (prompts)
generation_config: Optional generation configuration
**kwargs: Additional generation parameters
Returns:
`dict[str, GenerationOutput]`: A mapping from request id ("batch_req_<i>") to the corresponding
generation output, containing only the requests that finished successfully.
"""
if not inputs:
return []
# Initialize manager with the batch inputs
manager = self.init_continuous_batching(generation_config=generation_config)
manager.start()
results = {}
num_requests = len(inputs)
try:
from tqdm.contrib.logging import logging_redirect_tqdm
with logging_redirect_tqdm([logger]):
with tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar:
manager.add_requests(inputs, **kwargs)
finished_count = 0
while finished_count < num_requests:
result = manager.get_result(timeout=1)
if result:
req_id = result.request_id
if result.status == RequestStatus.FINISHED:
results[req_id] = result
finished_count += 1
pbar.update(1)
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")
break
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
manager.stop(block=True, timeout=5.0)
return results
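# Usage sketch (illustrative; `model` and `tok` as above). `generate_batch` assigns request ids of
# the form "batch_req_<i>" and returns a dict mapping them to GenerationOutput objects:
#
#   prompts = [tok(p)["input_ids"] for p in ["Hello", "How are you?"]]
#   results = model.generate_batch(prompts, max_new_tokens=32)
#   for req_id, out in results.items():
#       print(req_id, tok.decode(out.generated_tokens))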