Adding all project files

This commit is contained in:
Martina Burlando 2025-08-02 02:00:33 +02:00
parent 6c9e127bdc
commit cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions

View file

@@ -0,0 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.3.14"

View file

@@ -0,0 +1,131 @@
from __future__ import annotations
import sys
import os
import ctypes
import functools
import pathlib
from typing import (
Any,
Callable,
List,
Union,
Optional,
TYPE_CHECKING,
TypeVar,
Generic,
)
from typing_extensions import TypeAlias
# Load the library
def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
"""Platform independent shared library loader"""
# Search for the library under base_path using the names "libllama" (the default
# for llama.cpp) and "llama" (the default name for this repo)
lib_paths: List[pathlib.Path] = []
# Determine the file extension based on the platform
if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
lib_paths += [
base_path / f"lib{lib_base_name}.so",
]
elif sys.platform == "darwin":
lib_paths += [
base_path / f"lib{lib_base_name}.so",
base_path / f"lib{lib_base_name}.dylib",
]
elif sys.platform == "win32":
lib_paths += [
base_path / f"{lib_base_name}.dll",
base_path / f"lib{lib_base_name}.dll",
]
else:
raise RuntimeError("Unsupported platform")
cdll_args = dict() # type: ignore
# Add the library directory to the DLL search path on Windows (if needed)
if sys.platform == "win32":
os.add_dll_directory(str(base_path))
os.environ["PATH"] = str(base_path) + os.pathsep + os.environ["PATH"]
if sys.platform == "win32" and sys.version_info >= (3, 8):
os.add_dll_directory(str(base_path))
if "CUDA_PATH" in os.environ:
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
if "HIP_PATH" in os.environ:
os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib"))
cdll_args["winmode"] = ctypes.RTLD_GLOBAL
# Try to load the shared library, handling potential errors
for lib_path in lib_paths:
if lib_path.exists():
try:
return ctypes.CDLL(str(lib_path), **cdll_args) # type: ignore
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{lib_path}': {e}")
raise FileNotFoundError(
f"Shared library with base name '{lib_base_name}' not found"
)
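As a usage sketch (not part of this diff), the loader is pointed at a directory containing the compiled library; the base name "llama" and the lib directory below are assumptions:

```python
# Minimal sketch: resolve a lib/ directory next to the calling module and load
# libllama.so / libllama.dylib / llama.dll from it (the paths are placeholders).
import pathlib

_base_path = pathlib.Path(__file__).parent / "lib"
_lib = load_shared_library("llama", _base_path)  # returns a ctypes.CDLL handle
```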
# ctypes sane type hint helpers
#
# - Generic Pointer and Array types
# - PointerOrRef type with a type hinted byref function
#
# NOTE: Only use these for static type checking, not for runtime checks;
# no good will come of that
if TYPE_CHECKING:
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
class CtypesRef(Generic[CtypesCData]):
pass
CtypesPointerOrRef: TypeAlias = Union[
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
]
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
F = TypeVar("F", bound=Callable[..., Any])
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
"""Decorator for defining ctypes functions with type hints"""
def ctypes_function(
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
):
def decorator(f: F) -> F:
if enabled:
func = getattr(lib, name)
func.argtypes = argtypes
func.restype = restype
functools.wraps(f)(func)
return func
else:
return f
return decorator
return ctypes_function
def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
"""Type-annotated version of ctypes.byref"""
...
byref = _byref if TYPE_CHECKING else ctypes.byref
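Continuing the sketch above (illustrative only, not part of the diff), a binding is declared by decorating a stub with the C function's name, argument types, and return type; `llama_max_devices` is used here simply because its llama.h signature is small, and `_lib` and `ctypes` come from the earlier sketch and this file's imports:

```python
# Sketch: declare a typed binding against the _lib handle loaded above.
ctypes_function = ctypes_function_for_shared_library(_lib)

@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int:
    ...
```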

View file

@@ -0,0 +1,12 @@
"""Internal module use at your own risk
This module provides a minimal interface for working with ggml tensors from llama-cpp-python
"""
import os
import pathlib
import llama_cpp._ctypes_extensions as ctypes_ext
libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)

View file

@@ -0,0 +1,850 @@
from __future__ import annotations
import os
import ctypes
from typing import (
Dict,
List,
Tuple,
Optional,
Sequence,
Callable,
Union,
)
from dataclasses import dataclass, field
from contextlib import ExitStack
import numpy as np
import numpy.typing as npt
from .llama_types import *
from .llama_grammar import LlamaGrammar
from ._utils import suppress_stdout_stderr
import llama_cpp.llama_cpp as llama_cpp
# Python wrappers over llama.h structs
class LlamaModel:
"""Intermediate Python wrapper for a llama.cpp llama_model.
NOTE: For stability it's recommended you use the Llama class instead."""
def __init__(
self,
*,
path_model: str,
params: llama_cpp.llama_model_params,
verbose: bool = True,
):
self.path_model = path_model
self.params = params
self.verbose = verbose
self._exit_stack = ExitStack()
model = None
if not os.path.exists(path_model):
raise ValueError(f"Model path does not exist: {path_model}")
with suppress_stdout_stderr(disable=verbose):
model = llama_cpp.llama_model_load_from_file(
self.path_model.encode("utf-8"), self.params
)
if model is None:
raise ValueError(f"Failed to load model from file: {path_model}")
vocab = llama_cpp.llama_model_get_vocab(model)
if vocab is None:
raise ValueError(f"Failed to get vocab from model: {path_model}")
self.model = model
self.vocab = vocab
self.sampler = None # LlamaModel doesn't use samplers, but some cleanup code expects this attribute
def free_model():
if self.model is None:
return
llama_cpp.llama_model_free(self.model)
self.model = None
self._exit_stack.callback(free_model)
def close(self):
if self.sampler is not None:
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
for i, _ in reversed(self.custom_samplers):
llama_cpp.llama_sampler_chain_remove(self.sampler, i)
self.custom_samplers.clear()
self._exit_stack.close()
def __del__(self):
self.close()
def vocab_type(self) -> int:
return llama_cpp.llama_vocab_type(self.vocab)
def n_vocab(self) -> int:
return llama_cpp.llama_vocab_n_tokens(self.vocab)
def n_ctx_train(self) -> int:
return llama_cpp.llama_model_n_ctx_train(self.model)
def n_embd(self) -> int:
return llama_cpp.llama_model_n_embd(self.model)
def rope_freq_scale_train(self) -> float:
return llama_cpp.llama_model_rope_freq_scale_train(self.model)
def desc(self) -> str:
buf = ctypes.create_string_buffer(1024)
llama_cpp.llama_model_desc(self.model, buf, 1024)
return buf.value.decode("utf-8")
def size(self) -> int:
return llama_cpp.llama_model_size(self.model)
def n_params(self) -> int:
return llama_cpp.llama_model_n_params(self.model)
def get_tensor(self, name: str) -> ctypes.c_void_p:
raise NotImplementedError("get_tensor is not implemented in llama.cpp")
# Vocab
def token_get_text(self, token: int) -> str:
return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")
def token_get_score(self, token: int) -> float:
return llama_cpp.llama_vocab_get_score(self.vocab, token)
def token_get_attr(self, token: int) -> int:
return llama_cpp.llama_vocab_get_attr(self.vocab, token)
# Special tokens
def token_bos(self) -> int:
return llama_cpp.llama_vocab_bos(self.vocab)
def token_eos(self) -> int:
return llama_cpp.llama_vocab_eos(self.vocab)
def token_cls(self) -> int:
return llama_cpp.llama_vocab_cls(self.vocab)
def token_sep(self) -> int:
return llama_cpp.llama_vocab_sep(self.vocab)
def token_nl(self) -> int:
return llama_cpp.llama_vocab_nl(self.vocab)
def token_prefix(self) -> int:
return llama_cpp.llama_vocab_fim_pre(self.vocab)
def token_middle(self) -> int:
return llama_cpp.llama_vocab_fim_mid(self.vocab)
def token_suffix(self) -> int:
return llama_cpp.llama_vocab_fim_suf(self.vocab)
def token_eot(self) -> int:
return llama_cpp.llama_vocab_eot(self.vocab)
def add_bos_token(self) -> bool:
return llama_cpp.llama_vocab_get_add_bos(self.vocab)
def add_eos_token(self) -> bool:
return llama_cpp.llama_vocab_get_add_eos(self.vocab)
# Tokenization
def tokenize(self, text: bytes, add_bos: bool, special: bool):
n_ctx = self.n_ctx_train()
tokens = (llama_cpp.llama_token * n_ctx)()
n_tokens = llama_cpp.llama_tokenize(
self.vocab, text, len(text), tokens, n_ctx, add_bos, special
)
if n_tokens < 0:
n_tokens = abs(n_tokens)
tokens = (llama_cpp.llama_token * n_tokens)()
n_tokens = llama_cpp.llama_tokenize(
self.vocab, text, len(text), tokens, n_tokens, add_bos, special
)
if n_tokens < 0:
raise RuntimeError(
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
)
return list(tokens[:n_tokens])
def token_to_piece(self, token: int, special: bool = False) -> bytes:
buf = ctypes.create_string_buffer(32)
llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
return bytes(buf)
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
output = b""
size = 32
buffer = (ctypes.c_char * size)()
for token in tokens:
n = llama_cpp.llama_token_to_piece(
self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
)
assert n <= size
output += bytes(buffer[:n])
# NOTE: Llama 1 models automatically added a space at the start of the prompt;
# this strips that leading space when the first token is a beginning-of-sentence token
return (
output[1:]
if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
else output
)
# Extra
def metadata(self) -> Dict[str, str]:
metadata: Dict[str, str] = {}
buffer_size = 1024
buffer = ctypes.create_string_buffer(buffer_size)
# zero the buffer
buffer.value = b"\0" * buffer_size
# iterate over model keys
for i in range(llama_cpp.llama_model_meta_count(self.model)):
nbytes = llama_cpp.llama_model_meta_key_by_index(
self.model, i, buffer, buffer_size
)
if nbytes > buffer_size:
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_key_by_index(
self.model, i, buffer, buffer_size
)
key = buffer.value.decode("utf-8")
nbytes = llama_cpp.llama_model_meta_val_str_by_index(
self.model, i, buffer, buffer_size
)
if nbytes > buffer_size:
buffer_size = nbytes + 1
buffer = ctypes.create_string_buffer(buffer_size)
nbytes = llama_cpp.llama_model_meta_val_str_by_index(
self.model, i, buffer, buffer_size
)
value = buffer.value.decode("utf-8")
metadata[key] = value
return metadata
@staticmethod
def default_params():
"""Get the default llama_model_params."""
return llama_cpp.llama_model_default_params()
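A hedged usage sketch of the low-level wrapper (the GGUF path is a placeholder, and the higher-level Llama class remains the recommended entry point):

```python
# Sketch only: load a model with default params and round-trip a prompt
# through tokenize/detokenize. "./models/example.gguf" is a placeholder path.
params = LlamaModel.default_params()
model = LlamaModel(path_model="./models/example.gguf", params=params)
tokens = model.tokenize(b"Hello, world", add_bos=True, special=False)
print(model.n_vocab(), tokens, model.detokenize(tokens))
```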
class LlamaContext:
"""Intermediate Python wrapper for a llama.cpp llama_context.
NOTE: For stability it's recommended you use the Llama class instead."""
def __init__(
self,
*,
model: LlamaModel,
params: llama_cpp.llama_context_params,
verbose: bool = True,
):
self.model = model
self.params = params
self.verbose = verbose
self._exit_stack = ExitStack()
ctx = llama_cpp.llama_init_from_model(self.model.model, self.params)
if ctx is None:
raise ValueError("Failed to create llama_context")
self.ctx = ctx
self.memory = llama_cpp.llama_get_memory(self.ctx)
self.sampler = None # LlamaContext doesn't manage samplers directly, but some cleanup code expects this attribute
def free_ctx():
if self.ctx is None:
return
llama_cpp.llama_free(self.ctx)
self.ctx = None
self._exit_stack.callback(free_ctx)
def close(self):
self._exit_stack.close()
def __del__(self):
self.close()
def n_ctx(self) -> int:
return llama_cpp.llama_n_ctx(self.ctx)
def pooling_type(self) -> int:
return llama_cpp.llama_pooling_type(self.ctx)
def kv_cache_clear(self):
llama_cpp.llama_memory_clear(self.memory, True)
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1)
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1)
def kv_cache_seq_keep(self, seq_id: int):
llama_cpp.llama_memory_seq_keep(self.memory, seq_id)
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift)
def get_state_size(self) -> int:
return llama_cpp.llama_state_get_size(self.ctx)
# TODO: copy_state_data
# TODO: set_state_data
# TODO: llama_load_session_file
# TODO: llama_save_session_file
def decode(self, batch: LlamaBatch):
return_code = llama_cpp.llama_decode(
self.ctx,
batch.batch,
)
if return_code != 0:
raise RuntimeError(f"llama_decode returned {return_code}")
def encode(self, batch: LlamaBatch):
return_code = llama_cpp.llama_encode(
self.ctx,
batch.batch,
)
if return_code != 0:
raise RuntimeError(f"llama_encode returned {return_code}")
def set_n_threads(self, n_threads: int, n_threads_batch: int):
llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
def get_logits(self):
return llama_cpp.llama_get_logits(self.ctx)
def get_logits_ith(self, i: int):
return llama_cpp.llama_get_logits_ith(self.ctx, i)
def get_embeddings(self):
return llama_cpp.llama_get_embeddings(self.ctx)
def get_embeddings_ith(self, i: int):
return llama_cpp.llama_get_embeddings_ith(self.ctx, i)
def get_embeddings_seq(self, seq_id: int):
return llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id)
# Sampling functions - deprecated, use LlamaSampler instead
def set_rng_seed(self, seed: int):
raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
def sample_repetition_penalties(
self,
candidates: "_LlamaTokenDataArray",
last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
penalty_last_n: int,
penalty_repeat: float,
penalty_freq: float,
penalty_present: float,
):
raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
def sample_typical(
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
):
raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
def sample_token_mirostat(
self,
candidates: "_LlamaTokenDataArray",
tau: float,
eta: float,
m: int,
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
def sample_token_mirostat_v2(
self,
candidates: "_LlamaTokenDataArray",
tau: float,
eta: float,
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
) -> int:
raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
# Grammar
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
def reset_timings(self):
llama_cpp.llama_perf_context_reset(self.ctx)
def print_timings(self):
llama_cpp.llama_perf_context_print(self.ctx)
# Utility functions
@staticmethod
def default_params():
"""Get the default llama_context_params."""
return llama_cpp.llama_context_default_params()
class LlamaBatch:
def __init__(
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
):
self._n_tokens = n_tokens
self.embd = embd
self.n_seq_max = n_seq_max
self.verbose = verbose
self._exit_stack = ExitStack()
batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)
if batch is None:
raise ValueError("Failed to create llama_batch")
self.batch = batch
self.sampler = None # LlamaBatch doesn't use samplers, but some cleanup code expects this attribute
def free_batch():
if self.batch is None:
return
llama_cpp.llama_batch_free(self.batch)
self.batch = None
self._exit_stack.callback(free_batch)
def close(self):
self._exit_stack.close()
def __del__(self):
self.close()
def n_tokens(self) -> int:
return self.batch.n_tokens
def reset(self):
self.batch.n_tokens = 0
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
n_tokens = len(batch)
self.batch.n_tokens = n_tokens
for i in range(n_tokens):
self.batch.token[i] = batch[i]
self.batch.pos[i] = n_past + i
self.batch.seq_id[i][0] = 0
self.batch.n_seq_id[i] = 1
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
n_tokens = len(batch)
n_tokens0 = self.batch.n_tokens
self.batch.n_tokens += n_tokens
for i in range(n_tokens):
j = n_tokens0 + i
self.batch.token[j] = batch[i]
self.batch.pos[j] = i
self.batch.seq_id[j][0] = seq_id
self.batch.n_seq_id[j] = 1
self.batch.logits[j] = logits_all
self.batch.logits[n_tokens - 1] = True
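Continuing the model sketch above, a context and a batch can be combined to run a single decode and read the logits for the last position (illustrative only):

```python
# Sketch only: decode the tokens from the LlamaModel example above.
ctx = LlamaContext(model=model, params=LlamaContext.default_params())
batch = LlamaBatch(n_tokens=ctx.n_ctx(), embd=0, n_seq_max=1)
batch.set_batch(tokens, n_past=0, logits_all=False)
ctx.decode(batch)
logits = ctx.get_logits_ith(batch.n_tokens() - 1)  # pointer to n_vocab floats
```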
class LlamaTokenDataArray:
def __init__(self, *, n_vocab: int):
self.n_vocab = n_vocab
self.candidates_data = np.recarray(
(self.n_vocab,),
dtype=np.dtype(
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
),
)
self.candidates = llama_cpp.llama_token_data_array(
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
size=self.n_vocab,
sorted=False,
)
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
self.sampler = None # LlamaTokenDataArray doesn't use samplers, but some cleanup code expects this attribute
def copy_logits(self, logits: npt.NDArray[np.single]):
self.candidates_data.id[:] = self.default_candidates_data_id
self.candidates_data.logit[:] = logits
self.candidates_data.p[:] = self.default_candidates_data_p
self.candidates.sorted = False
self.candidates.size = self.n_vocab
# Embedding functions
def normalize_embedding(embedding):
norm = float(np.linalg.norm(embedding))
if norm == 0.0:
return embedding
return [v / norm for v in embedding]
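A tiny worked example of the helper above:

```python
# A [3, 4] vector has L2 norm 5, so it normalizes to [0.6, 0.8].
print(normalize_embedding([3.0, 4.0]))  # -> [0.6, 0.8]
```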
# Python wrappers over common/sampling structs
@dataclass
class LlamaSamplingParams:
n_prev: int = 64
n_probs: int = 0
top_k: int = 40
top_p: float = 0.95
min_p: float = 0.05
tfs_z: float = 1.00
typical_p: float = 1.00
temp: float = 0.80
penalty_last_n: int = 64
penalty_repeat: float = 1.0
penalty_freq: float = 0.00
penalty_present: float = 0.00
mirostat: int = 0
mirostat_tau: float = 5.00
mirostat_eta: float = 0.10
penalize_nl: bool = True
grammar: str = ""
cfg_negative_prompt: str = ""
cfg_scale: float = 1.00
logit_bias: dict[int, float] = field(default_factory=dict)
@dataclass
class LlamaSamplingContext:
params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
grammar: Optional[LlamaGrammar] = None
# NOTE: Missing parsed_grammar
prev: list[int] = field(default_factory=list)
cur: list[llama_cpp.llama_token_data] = field(default_factory=list)
def reset(self):
self.prev = []
self.cur = []
if self.grammar is not None:
self.grammar.reset()
def cp(self):
return LlamaSamplingContext(
params=self.params,
mirostat_mu=self.mirostat_mu,
grammar=self.grammar,
prev=self.prev.copy(),
cur=self.cur.copy(),
)
def last(self) -> Optional[int]:
if len(self.prev) > 0:
return self.prev[-1]
else:
return None
def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
def sample(
self,
ctx_main: LlamaContext,
idx: int = 0,
logits_array: Optional[npt.NDArray[np.single]] = None,
):
# This method is deprecated in favor of using LlamaSampler directly
raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
self.prev.append(id)
class CustomSampler:
def __init__(
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
):
self.apply_func = apply_func
def apply_wrapper(
sampler: llama_cpp.llama_sampler_p,
cur_p: llama_cpp.llama_token_data_array_p,
):
self.apply_func(cur_p)
def free_wrapper(sampler: llama_cpp.llama_sampler_p):
pass
sampler_i = llama_cpp.llama_sampler_i()
sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
self._apply_wrapper_ref = apply_wrapper
sampler_i.name = llama_cpp.llama_sampler_i_name(0)
sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
sampler_i.free = llama_cpp.llama_sampler_i_free(0)
self.sampler = llama_cpp.llama_sampler()
self.sampler.iface = ctypes.pointer(sampler_i)
self.sampler.ctx = None
def get_sampler(self) -> llama_cpp.llama_sampler_p:
return ctypes.pointer(self.sampler)
class LlamaSampler:
def __init__(self):
params = llama_cpp.llama_sampler_chain_default_params()
self.sampler = llama_cpp.llama_sampler_chain_init(params)
self.custom_samplers: List[Tuple[int, CustomSampler]] = []
self._exit_stack = ExitStack()
def free_sampler():
if self.sampler is not None:
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
for i, _ in reversed(self.custom_samplers):
llama_cpp.llama_sampler_chain_remove(self.sampler, i)
llama_cpp.llama_sampler_free(self.sampler)
self.sampler = None
self._exit_stack.callback(free_sampler)
def close(self):
self._exit_stack.close()
def __del__(self):
self.close()
def add_greedy(self):
sampler = llama_cpp.llama_sampler_init_greedy()
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_dist(self, seed: int):
sampler = llama_cpp.llama_sampler_init_dist(seed)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_softmax(self):
sampler = llama_cpp.llama_sampler_init_softmax()
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_top_k(self, k: int):
sampler = llama_cpp.llama_sampler_init_top_k(k)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_top_p(self, p: float, min_keep: int = 1):
sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_min_p(self, p: float, min_keep: int = 1):
sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_typical(self, p: float, min_keep: int = 1):
sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_temp(self, temp: float):
sampler = llama_cpp.llama_sampler_init_temp(temp)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_temp_ext(self, t: float, delta: float, exponent: float):
sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_xtc(self, p: float, t: float, min_keep: int, seed: int):
sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_top_n_sigma(self, n: float):
sampler = llama_cpp.llama_sampler_init_top_n_sigma(n)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_mirostat_v2(self, seed: int, tau: float, eta: float):
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
sampler = llama_cpp.llama_sampler_init_grammar(
model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_grammar_lazy_patterns(
self,
model: LlamaModel,
grammar: LlamaGrammar,
trigger_patterns: List[str],
trigger_tokens: List[int]
):
# Convert patterns to C array
pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
for i, pattern in enumerate(trigger_patterns):
pattern_ptrs[i] = pattern.encode("utf-8")
# Convert tokens to C array
token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)
sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
model.vocab,
grammar._grammar.encode("utf-8"),
grammar._root.encode("utf-8"),
pattern_ptrs,
len(trigger_patterns),
token_array,
len(trigger_tokens)
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_penalties(
self,
penalty_last_n: int,
penalty_repeat: float,
penalty_freq: float,
penalty_present: float,
):
sampler = llama_cpp.llama_sampler_init_penalties(
penalty_last_n,
penalty_repeat,
penalty_freq,
penalty_present,
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_dry(
self,
model: LlamaModel,
n_ctx_train: int,
dry_multiplier: float,
dry_base: float,
dry_allowed_length: int,
dry_penalty_last_n: int,
seq_breakers: List[str]
):
# Convert seq_breakers to C array
breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
for i, breaker in enumerate(seq_breakers):
breaker_ptrs[i] = breaker.encode("utf-8")
sampler = llama_cpp.llama_sampler_init_dry(
model.vocab,
n_ctx_train,
dry_multiplier,
dry_base,
dry_allowed_length,
dry_penalty_last_n,
breaker_ptrs,
len(seq_breakers)
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_logit_bias(
self,
n_vocab: int,
logit_bias: Dict[int, float]
):
# Convert logit_bias dict to C array
bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
for i, (token, bias) in enumerate(logit_bias.items()):
bias_array[i].token = token
bias_array[i].bias = bias
sampler = llama_cpp.llama_sampler_init_logit_bias(
n_vocab,
len(logit_bias),
bias_array
)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_infill(self, model: LlamaModel):
sampler = llama_cpp.llama_sampler_init_infill(model.vocab)
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
def add_custom(
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
):
custom_sampler = CustomSampler(apply_func)
sampler = custom_sampler.get_sampler()
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
self.custom_samplers.append(
(llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
)
def get_seed(self) -> int:
return llama_cpp.llama_sampler_get_seed(self.sampler)
def sample(self, ctx: LlamaContext, idx: int = -1) -> int:
return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)
def accept(self, token: int):
llama_cpp.llama_sampler_accept(self.sampler, token)
def reset(self):
llama_cpp.llama_sampler_reset(self.sampler)
def clone(self):
# NOTE: Custom samplers cannot be cloned due to Python callback limitations
if self.custom_samplers:
raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")
cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
# Create a new wrapper around the cloned sampler
new_sampler = LlamaSampler.__new__(LlamaSampler)
new_sampler.sampler = cloned_sampler
new_sampler.custom_samplers = []
new_sampler._exit_stack = ExitStack()
def free_sampler():
if new_sampler.sampler is not None:
llama_cpp.llama_sampler_free(new_sampler.sampler)
new_sampler.sampler = None
new_sampler._exit_stack.callback(free_sampler)
return new_sampler
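As an illustrative sketch (reusing the ctx from the batch example above), a typical chain applies top-k, top-p, and temperature before the final distribution sampler:

```python
# Sketch only: build a sampler chain and draw the next token from the
# already-decoded context.
sampler = LlamaSampler()
sampler.add_top_k(40)
sampler.add_top_p(0.95, min_keep=1)
sampler.add_temp(0.8)
sampler.add_dist(seed=1234)
next_token = sampler.sample(ctx, idx=-1)
sampler.accept(next_token)
sampler.close()
```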

View file

@@ -0,0 +1,47 @@
import sys
import ctypes
import logging
import llama_cpp
# enum ggml_log_level {
# GGML_LOG_LEVEL_NONE = 0,
# GGML_LOG_LEVEL_INFO = 1,
# GGML_LOG_LEVEL_WARN = 2,
# GGML_LOG_LEVEL_ERROR = 3,
# GGML_LOG_LEVEL_DEBUG = 4,
# GGML_LOG_LEVEL_CONT = 5, // continue previous log
# };
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
0: logging.CRITICAL,
1: logging.INFO,
2: logging.WARNING,
3: logging.ERROR,
4: logging.DEBUG,
5: logging.DEBUG,
}
logger = logging.getLogger("llama-cpp-python")
_last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
@llama_cpp.llama_log_callback
def llama_log_callback(
level: int,
text: bytes,
user_data: ctypes.c_void_p,
):
# TODO: Correctly implement continue previous log
global _last_log_level
log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
_last_log_level = log_level
llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
def set_verbose(verbose: bool):
logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
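A minimal sketch of how this module is meant to be used (the filtering behaviour follows the callback above):

```python
# Sketch: once llama_log_set has installed the callback above, ggml/llama.cpp
# log lines are filtered by this module's logger level.
set_verbose(False)  # only error-level messages reach stderr
set_verbose(True)   # pass info/debug messages through as well
```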

View file

@@ -0,0 +1,78 @@
import os
import sys
from typing import Any, Dict
# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
outnull_file = open(os.devnull, "w")
errnull_file = open(os.devnull, "w")
STDOUT_FILENO = 1
STDERR_FILENO = 2
class suppress_stdout_stderr(object):
# NOTE: these must be "saved" here to avoid exceptions when using
# this context manager inside of a __del__ method
sys = sys
os = os
def __init__(self, disable: bool = True):
self.disable = disable
# Oddly enough this works better than the contextlib version
def __enter__(self):
if self.disable:
return self
self.old_stdout_fileno_undup = STDOUT_FILENO
self.old_stderr_fileno_undup = STDERR_FILENO
self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)
self.old_stdout = self.sys.stdout
self.old_stderr = self.sys.stderr
self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)
self.sys.stdout = outnull_file
self.sys.stderr = errnull_file
return self
def __exit__(self, *_):
if self.disable:
return
# Check if sys.stdout and sys.stderr have fileno method
self.sys.stdout = self.old_stdout
self.sys.stderr = self.old_stderr
self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
self.os.close(self.old_stdout_fileno)
self.os.close(self.old_stderr_fileno)
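A short sketch of the context manager; the noisy call below is a hypothetical stand-in for something like a native model load:

```python
# Sketch: silence C-level writes to fd 1/2 for the duration of the block.
# disable=True turns the guard into a no-op, which is how verbose=True is honored.
with suppress_stdout_stderr(disable=False):
    noisy_native_call()  # hypothetical function that prints from C code
```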
class MetaSingleton(type):
"""
Metaclass for implementing the Singleton pattern.
"""
_instances: Dict[type, Any] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Singleton(object, metaclass=MetaSingleton):
"""
Base class for implementing the Singleton pattern.
"""
def __init__(self):
super(Singleton, self).__init__()

Binary files not shown (10 files).

File diff suppressed because it is too large

View file

@@ -0,0 +1,155 @@
import sys
from abc import ABC, abstractmethod
from typing import (
Optional,
Sequence,
Tuple,
)
from collections import OrderedDict
import diskcache
import llama_cpp.llama
from .llama_types import *
class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""
def __init__(self, capacity_bytes: int = (2 << 30)):
self.capacity_bytes = capacity_bytes
@property
@abstractmethod
def cache_size(self) -> int:
raise NotImplementedError
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
pass
@abstractmethod
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
raise NotImplementedError
@abstractmethod
def __contains__(self, key: Sequence[int]) -> bool:
raise NotImplementedError
@abstractmethod
def __setitem__(
self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
) -> None:
raise NotImplementedError
class LlamaRAMCache(BaseLlamaCache):
"""Cache for a llama.cpp model using RAM."""
def __init__(self, capacity_bytes: int = (2 << 30)):
super().__init__(capacity_bytes)
self.capacity_bytes = capacity_bytes
self.cache_state: OrderedDict[
Tuple[int, ...], "llama_cpp.llama.LlamaState"
] = OrderedDict()
@property
def cache_size(self):
return sum([state.llama_state_size for state in self.cache_state.values()])
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
min_len = 0
min_key = None
keys = (
(k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
for k in self.cache_state.keys()
)
for k, prefix_len in keys:
if prefix_len > min_len:
min_len = prefix_len
min_key = k
return min_key
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
key = tuple(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
value = self.cache_state[_key]
self.cache_state.move_to_end(_key)
return value
def __contains__(self, key: Sequence[int]) -> bool:
return self._find_longest_prefix_key(tuple(key)) is not None
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
key = tuple(key)
if key in self.cache_state:
del self.cache_state[key]
self.cache_state[key] = value
while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
self.cache_state.popitem(last=False)
# Alias for backwards compatibility
LlamaCache = LlamaRAMCache
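A small sketch of the prefix-matching behaviour; DummyState is a stand-in for a real llama_cpp.llama.LlamaState, which is what callers actually store:

```python
# Sketch only: lookups return the state saved under the longest matching
# token prefix of the requested key.
class DummyState:
    llama_state_size = 1024  # bytes, used for capacity accounting

cache = LlamaRAMCache(capacity_bytes=2 << 20)
cache[(1, 2, 3)] = DummyState()
print((1, 2, 3, 4) in cache)  # True: (1, 2, 3) is a prefix of the lookup key
```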
class LlamaDiskCache(BaseLlamaCache):
"""Cache for a llama.cpp model using disk."""
def __init__(
self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
):
super().__init__(capacity_bytes)
self.cache = diskcache.Cache(cache_dir)
@property
def cache_size(self):
return int(self.cache.volume()) # type: ignore
def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
min_len = 0
min_key: Optional[Tuple[int, ...]] = None
for k in self.cache.iterkeys(): # type: ignore
prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
if prefix_len > min_len:
min_len = prefix_len
min_key = k # type: ignore
return min_key
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
key = tuple(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore
# NOTE: This puts an integer as key in cache, which breaks,
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
# self.cache.push(_key, side="front") # type: ignore
return value
def __contains__(self, key: Sequence[int]) -> bool:
return self._find_longest_prefix_key(tuple(key)) is not None
def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
key = tuple(key)
if key in self.cache:
print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
del self.cache[key]
self.cache[key] = value
print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
key_to_remove = next(iter(self.cache))
del self.cache[key_to_remove]
print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -0,0 +1,953 @@
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
# flake8: noqa
from pathlib import Path
from itertools import groupby
from typing import (
Any,
Set,
List,
Optional,
Tuple,
Union,
)
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"
class LlamaGrammar:
def __init__(self, *args, _grammar: str, **kwargs):
self._grammar = _grammar
self._root = LLAMA_GRAMMAR_DEFAULT_ROOT
@classmethod
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
return cls(_grammar=grammar)
@classmethod
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
try:
with open(file) as f:
grammar = f.read()
except Exception as err:
raise Exception(
f"{cls.from_file.__name__}: error reading grammar file: {err}"
)
if grammar:
return cls.from_string(grammar, verbose=verbose)
raise ValueError(
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
)
@classmethod
def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
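An illustrative way to construct a grammar from an inline GBNF string (the JSON_GBNF constant below can be passed the same way):

```python
# Sketch: a two-word grammar built from a string; from_file and
# from_json_schema follow the same pattern.
yes_no_grammar = LlamaGrammar.from_string(r'''
root ::= "yes" | "no"
''')
```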
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
ARITHMETIC_GBNF = r"""
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""
C_GBNF = r"""
root ::= (declaration)*
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
parameter ::= dataType identifier
statement ::=
( dataType identifier ws "=" ws expression ";" ) |
( identifier ws "=" ws expression ";" ) |
( identifier ws "(" argList? ")" ";" ) |
( "return" ws expression ";" ) |
( "while" "(" condition ")" "{" statement* "}" ) |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
( singleLineComment ) |
( multiLineComment )
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression
condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"
argList ::= expression ("," ws expression)*
number ::= [0-9]+
singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
ws ::= ([ \t\n]+)
"""
CHESS_GBNF = r"""
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JAPANESE_GBNF = r"""
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JSON_ARR_GBNF = r"""
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays
root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws
arr ::=
"[\n" ws (
value
(",\n" ws value)*
)? "]"
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JSON_GBNF = r"""
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\\x7F\x00-\x1F] |
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= | " " | "\n" [ \t]{0,20}
"""
LIST_GBNF = r"""
root ::= item+
# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
import json
import re
from typing import List, Optional
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
def _build_repetition(
item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
):
if not separator_rule:
if min_items == 0 and max_items == 1:
return f"{item_rule}?"
elif min_items == 1 and max_items is None:
return f"{item_rule}+"
result = ""
if min_items > 0:
if item_rule_is_literal and separator_rule is None:
result = '"' + (item_rule[1:-1] * min_items) + '"'
else:
result = (f" {separator_rule} " if separator_rule else " ").join(
[item_rule] * min_items
)
def opt_repetitions(up_to_n, prefix_with_sep=False):
"""
- n=4, no sep: '(a (a (a (a)?)?)?)?'
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
"""
content = (
f"{separator_rule} {item_rule}"
if prefix_with_sep and separator_rule
else item_rule
)
if up_to_n == 0:
return ""
elif up_to_n == 1:
return f"({content})?"
elif separator_rule and not prefix_with_sep:
return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
else:
return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)
if min_items > 0 and max_items != min_items:
result += " "
if max_items is not None:
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
else:
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
if min_items == 0 and separator_rule:
result = f"({item_rule} {item_operator}*)?"
else:
result += f"{item_operator}*"
return result
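A few worked examples of the repetition helper, which the converter uses for {m,n}-style quantifiers and array bounds:

```python
# 0..1 items collapses to "?", 1..inf to "+", and bounded counts expand into
# nested optional groups.
print(_build_repetition("[0-9]", 0, 1))     # [0-9]?
print(_build_repetition("[0-9]", 1, None))  # [0-9]+
print(_build_repetition("[0-9]", 2, 4))     # matches 2 to 4 digits
```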
class BuiltinRule:
def __init__(self, content: str, deps: list = None):
self.content = content
self.deps = deps or []
_up_to_15_digits = _build_repetition("[0-9]", 0, 15)
PRIMITIVE_RULES = {
"boolean": BuiltinRule('("true" | "false") space', []),
"decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
"integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
"number": BuiltinRule(
'("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
["integral-part", "decimal-part"],
),
"integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
"value": BuiltinRule(
"object | array | string | number | boolean | null",
["object", "array", "string", "number", "boolean", "null"],
),
"object": BuiltinRule(
'"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
["string", "value"],
),
"array": BuiltinRule(
'"[" space ( value ("," space value)* )? "]" space', ["value"]
),
"uuid": BuiltinRule(
r'"\"" '
+ ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
+ r' "\"" space',
[],
),
"char": BuiltinRule(
r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
[],
),
"string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
"null": BuiltinRule('"null" space', []),
}
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
"date": BuiltinRule(
'[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
[],
),
"time": BuiltinRule(
'([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
[],
),
"date-time": BuiltinRule('date "T" time', ["date", "time"]),
"date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
"time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
"date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
}
DOTALL = "[\\U00000000-\\U0010FFFF]"
DOT = "[^\\x0A\\x0D]"
RESERVED_NAMES = set(
["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]
)
NON_LITERAL_SET = set("|.()[]{}*+?")
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
class SchemaConverter:
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {
"space": SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
)
return f'"{escaped}"'
def not_literal(
self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
) -> str:
"""
not_literal('a') -> '[^a]'
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
"""
assert len(literal) > 0, "Empty literal not supported"
def recurse(i: int):
c = literal[i]
if maybe_escaped_underscores and c == "_":
yield f"[^{c}\\\\]"
yield " | "
yield f'"\\\\"? "{c}"'
else:
yield f"[^{c}]"
if i < len(literal) - 1:
yield " | "
yield self._format_literal(c)
yield " ("
yield from recurse(i + 1)
yield ")?"
return "".join(("(", *recurse(0), ")"))
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
key = esc_name
else:
i = 0
while (
f"{esc_name}{i}" in self._rules
and self._rules[f"{esc_name}{i}"] != rule
):
i += 1
key = f"{esc_name}{i}"
self._rules[key] = rule
return key
def resolve_refs(self, schema: dict, url: str):
"""
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
"""
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get("$ref")
if ref is not None and ref not in self._refs:
if ref.startswith("https://"):
assert (
self._allow_fetch
), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
import requests
frag_split = ref.split("#")
base_url = frag_split[0]
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(
requests.get(ref).json(), base_url
)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == "":
return target
elif ref.startswith("#/"):
target = schema
ref = f"{url}{ref}"
n["$ref"] = ref
else:
raise ValueError(f"Unsupported ref {ref}")
for sel in ref.split("#")[-1].split("/")[1:]:
assert (
target is not None and sel in target
), f"Error resolving ref {ref}: {sel} not in {target}"
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return " | ".join(
(
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
)
)
def _visit_pattern(self, pattern, name):
"""
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
"""
assert pattern.startswith("^") and pattern.endswith(
"$"
), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: Tuple[str, bool]) -> str:
(txt, is_literal) = s
return '"' + txt + '"' if is_literal else txt
def transform() -> Tuple[str, bool]:
"""
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
"""
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially,
# since GBNF's syntax is luckily very close to regular expressions!)
seq: list[Tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = DOTALL
else:
# Accept any character... except \n and \r line break chars (\x0A and \x0D)
rule = DOT
return self._add_rule(f"dot", rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in groupby(seq, lambda x: x[1]):
if is_literal:
ret.append(("".join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (" ".join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == ".":
seq.append((get_dot(), False))
i += 1
elif c == "(":
i += 1
if i < length:
assert (
pattern[i] != "?"
), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f"({to_rule(transform())})", False))
elif c == ")":
i += 1
assert (
start > 0 and pattern[start - 1] == "("
), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
return join_seq()
elif c == "[":
square_brackets = c
i += 1
while i < length and pattern[i] != "]":
if pattern[i] == "\\":
square_brackets += pattern[i : i + 2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert (
i < length
), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
square_brackets += "]"
i += 1
seq.append((square_brackets, False))
elif c == "|":
seq.append(("|", False))
i += 1
elif c in ("*", "+", "?"):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == "{":
curly_brackets = c
i += 1
while i < length and pattern[i] != "}":
curly_brackets += pattern[i]
i += 1
assert (
i < length
), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
curly_brackets += "}"
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(
f"Invalid quantifier {curly_brackets} in /{pattern}/"
)
(sub, sub_is_literal) = seq[-1]
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (
_build_repetition(
f'"{sub}"' if sub_is_literal else sub,
min_times,
max_times,
item_rule_is_literal=sub_is_literal,
),
False,
)
else:
literal = ""
while i < length:
if pattern[i] == "\\" and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i : i + 2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and (
i == length - 1
or literal == ""
or pattern[i + 1] == "."
or pattern[i + 1] not in NON_LITERAL_SET
):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
(
to_rule(transform())
if self._raw_pattern
else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
),
)
def _resolve_ref(self, ref):
ref_name = ref.split("/")[-1]
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
def visit(self, schema, name):
schema_type = schema.get("type")
schema_format = schema.get("format")
rule_name = name + "-" if name in RESERVED_NAMES else name or "root"
if (ref := schema.get("$ref")) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
elif "oneOf" in schema or "anyOf" in schema:
return self._add_rule(
rule_name,
self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
)
elif isinstance(schema_type, list):
return self._add_rule(
rule_name,
self._generate_union_rule(name, [{"type": t} for t in schema_type]),
)
elif "const" in schema:
return self._add_rule(
rule_name, self._generate_constant_rule(schema["const"])
)
elif "enum" in schema:
rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule)
elif schema_type in (None, "object") and (
"properties" in schema
or (
"additionalProperties" in schema
and schema["additionalProperties"] is not True
)
):
required = set(schema.get("required", []))
properties = list(schema.get("properties", {}).items())
return self._add_rule(
rule_name,
self._build_object_rule(
properties, required, name, schema.get("additionalProperties")
),
)
elif schema_type in (None, "object") and "allOf" in schema:
required = set()
properties = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get("$ref")) is not None:
comp_schema = self._refs[ref]
if "properties" in comp_schema:
for prop_name, prop_schema in comp_schema["properties"].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
for t in schema["allOf"]:
if "anyOf" in t:
for tt in t["anyOf"]:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
return self._add_rule(
rule_name,
self._build_object_rule(
properties, required, hybrid_name, additional_properties=[]
),
)
elif schema_type in (None, "array") and (
"items" in schema or "prefixItems" in schema
):
items = schema.get("items") or schema["prefixItems"]
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space '
+ ' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)
)
+ ' "]" space',
)
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(
rule_name,
'"[" space '
+ _build_repetition(
item_rule_name, min_items, max_items, separator_rule='"," space'
)
+ ' "]" space',
)
elif schema_type in (None, "string") and "pattern" in schema:
return self._visit_pattern(schema["pattern"], rule_name)
elif schema_type in (None, "string") and re.match(
r"^uuid[1-5]?$", schema_format or ""
):
return self._add_primitive(
"root" if rule_name == "root" else schema_format,
PRIMITIVE_RULES["uuid"],
)
elif (
schema_type in (None, "string")
and f"{schema_format}-string" in STRING_FORMAT_RULES
):
prim_name = f"{schema_format}-string"
return self._add_rule(
rule_name,
self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
)
elif schema_type == "string" and (
"minLength" in schema or "maxLength" in schema
):
char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
min_len = schema.get("minLength", 0)
max_len = schema.get("maxLength")
return self._add_rule(
rule_name,
r'"\"" '
+ _build_repetition(char_rule, min_len, max_len)
+ r' "\"" space',
)
elif (schema_type == "object") or (len(schema) == 0):
return self._add_rule(
rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
)
else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_primitive(
"root" if rule_name == "root" else schema_type,
PRIMITIVE_RULES[schema_type],
)
def _add_primitive(self, name: str, rule: BuiltinRule):
n = self._add_rule(name, rule.content)
for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f"Rule {dep} not known"
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n
def _build_object_rule(
self,
properties: List[Tuple[str, Any]],
required: Set[str],
name: str,
additional_properties: Union[bool, Any],
):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [
kv[0]
for _, kv in sorted(
enumerate(properties),
key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
)
]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
)
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties == True or isinstance(additional_properties, dict):
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit(
{} if additional_properties == True else additional_properties,
f"{sub_name}-value",
)
prop_kv_rule_names["*"] = self._add_rule(
f"{sub_name}-kv",
self._add_primitive("string", PRIMITIVE_RULES["string"])
+ f' ":" space {value_rule}',
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += " ("
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
if k == "*":
res = self._add_rule(
f'{name}{"-" if name else ""}additional-kvs',
f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
)
elif first_is_optional:
res = f'( "," space {kv_rule_name} )?'
else:
res = kv_rule_name
if len(rest) > 0:
res += " " + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True),
)
return res
rule += " | ".join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += " )"
rule += " )?"
rule += ' "}" space'
return rule
def format_grammar(self):
return "\n".join(
f"{name} ::= {rule}"
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
prop_order = prop_order or []
schema = json.loads(schema)
prop_order = {name: idx for idx, name in enumerate(prop_order)}
converter = SchemaConverter(
prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False
)
schema = converter.resolve_refs(schema, "stdin")
converter.visit(schema, "")
return converter.format_grammar()
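if __name__ == "__main__":
    # Hedged usage sketch: convert a small, illustrative JSON Schema into a
    # GBNF grammar string using the module-level helper defined above. The
    # schema content below is made up purely for demonstration.
    _example_schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name"],
        }
    )
    print(json_schema_to_gbnf(_example_schema, prop_order=["name", "age"]))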

View file

@ -0,0 +1,64 @@
import abc
from typing import Any
import numpy as np
import numpy.typing as npt
class LlamaDraftModel(abc.ABC):
@abc.abstractmethod
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
raise NotImplementedError()
class LlamaPromptLookupDecoding(LlamaDraftModel):
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
self.max_ngram_size = max_ngram_size
self.num_pred_tokens = num_pred_tokens
@staticmethod
def find_candidate_pred_tokens(
input_ids: npt.NDArray[np.intc],
max_ngram_size: int,
num_pred_tokens: int,
):
input_length = input_ids.shape[0]
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
# Convert ngram to an array for comparison
ngram_array = input_ids[-ngram_size:]
# Find where the windows match the ngram
matches = np.all(windows == ngram_array, axis=1)
# Get the indices of matches
match_indices = np.nonzero(matches)[0]
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + num_pred_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
return input_ids[start_idx:end_idx]
# If no match is found, return an empty array
return np.array([], dtype=np.intc)
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
return self.find_candidate_pred_tokens(
input_ids=input_ids,
max_ngram_size=self.max_ngram_size,
num_pred_tokens=self.num_pred_tokens,
)
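if __name__ == "__main__":
    # Hedged usage sketch with made-up token ids: the trailing 2-gram (2, 3)
    # also occurs earlier in the sequence, so the tokens that followed that
    # earlier occurrence ([4, 2, 3]) are proposed as draft tokens.
    draft_model = LlamaPromptLookupDecoding(max_ngram_size=2, num_pred_tokens=3)
    toy_ids = np.array([1, 2, 3, 4, 2, 3], dtype=np.intc)
    print(draft_model(toy_ids))  # -> [4 2 3]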

View file

@ -0,0 +1,120 @@
from __future__ import annotations
import abc
from typing import (
List,
Optional,
Any,
)
import llama_cpp
from llama_cpp.llama_types import List
class BaseLlamaTokenizer(abc.ABC):
@abc.abstractmethod
def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]:
"""Tokenize the text into tokens.
Args:
text: The utf-8 encoded string to tokenize.
add_bos: Whether to add a beginning of sequence token.
special: Whether to tokenize special tokens.
"""
raise NotImplementedError
@abc.abstractmethod
def detokenize(
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
"""Detokenize the tokens into text.
Args:
tokens: The list of tokens to detokenize.
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
special: Whether to detokenize special tokens.
"""
raise NotImplementedError
class LlamaTokenizer(BaseLlamaTokenizer):
def __init__(self, llama: llama_cpp.Llama):
self._model = llama._model # type: ignore
def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]:
return self._model.tokenize(text, add_bos=add_bos, special=special)
def detokenize(
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
return self._model.detokenize(tokens, special=special)
def encode(
self, text: str, add_bos: bool = True, special: bool = True
) -> List[int]:
return self.tokenize(
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
)
def decode(self, tokens: List[int]) -> str:
return self.detokenize(tokens).decode("utf-8", errors="ignore")
@classmethod
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
class LlamaHFTokenizer(BaseLlamaTokenizer):
def __init__(self, hf_tokenizer: Any):
self.hf_tokenizer = hf_tokenizer
def tokenize(
self, text: bytes, add_bos: bool = True, special: bool = True
) -> List[int]:
return self.hf_tokenizer.encode(
text.decode("utf-8", errors="ignore"), add_special_tokens=special
)
def detokenize(
self,
tokens: List[int],
prev_tokens: Optional[List[int]] = None,
special: bool = False,
) -> bytes:
skip_special_tokens = not special
if prev_tokens is not None:
text = self.hf_tokenizer.decode(
prev_tokens + tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")
prev_text = self.hf_tokenizer.decode(
prev_tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")
return text[len(prev_text) :]
else:
return self.hf_tokenizer.decode(
tokens, skip_special_tokens=skip_special_tokens
).encode("utf-8", errors="ignore")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
try:
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `transformers` library is required to use the `HFTokenizer`."
"You can install it with `pip install transformers`."
)
hf_tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path
)
return cls(hf_tokenizer)
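if __name__ == "__main__":
    # Hedged sketch of the detokenize() offset behaviour using a trivial
    # stand-in tokenizer; _ToyHF is illustrative only and not part of the API.
    class _ToyHF:
        _vocab = {1: "Hello", 2: " world", 3: "!"}
        def encode(self, text, add_special_tokens=True):
            return list(self._vocab)
        def decode(self, tokens, skip_special_tokens=True):
            return "".join(self._vocab[t] for t in tokens)
    toy_tokenizer = LlamaHFTokenizer(_ToyHF())
    # With prev_tokens supplied, only the newly decoded suffix is returned.
    print(toy_tokenizer.detokenize([3], prev_tokens=[1, 2]))  # b'!'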

View file

@ -0,0 +1,316 @@
"""Types and request signatures for OpenAI compatibility
NOTE: These types may change to match the OpenAI OpenAPI specification.
Based on the OpenAI OpenAPI specification:
https://github.com/openai/openai-openapi/blob/master/openapi.yaml
"""
from typing import Any, List, Optional, Dict, Union
from typing_extensions import TypedDict, NotRequired, Literal
# NOTE: Defining this correctly using annotations seems to break pydantic validation.
# This is a workaround until we can figure out how to do this correctly
# JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]
class EmbeddingUsage(TypedDict):
prompt_tokens: int
total_tokens: int
class Embedding(TypedDict):
index: int
object: str
embedding: Union[List[float], List[List[float]]]
class CreateEmbeddingResponse(TypedDict):
object: Literal["list"]
model: str
data: List[Embedding]
usage: EmbeddingUsage
class CompletionLogprobs(TypedDict):
text_offset: List[int]
token_logprobs: List[Optional[float]]
tokens: List[str]
top_logprobs: List[Optional[Dict[str, float]]]
class CompletionChoice(TypedDict):
text: str
index: int
logprobs: Optional[CompletionLogprobs]
finish_reason: Optional[Literal["stop", "length"]]
class CompletionUsage(TypedDict):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CreateCompletionResponse(TypedDict):
id: str
object: Literal["text_completion"]
created: int
model: str
choices: List[CompletionChoice]
usage: NotRequired[CompletionUsage]
class ChatCompletionResponseFunctionCall(TypedDict):
name: str
arguments: str
class ChatCompletionResponseMessage(TypedDict):
content: Optional[str]
tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
role: Literal["assistant", "function"] # NOTE: "function" may be incorrect here
function_call: NotRequired[ChatCompletionResponseFunctionCall] # DEPRECATED
class ChatCompletionFunction(TypedDict):
name: str
description: NotRequired[str]
parameters: Dict[str, JsonType] # TODO: make this more specific
class ChatCompletionTopLogprobToken(TypedDict):
token: str
logprob: float
bytes: Optional[List[int]]
class ChatCompletionLogprobToken(ChatCompletionTopLogprobToken):
token: str
logprob: float
bytes: Optional[List[int]]
top_logprobs: List[ChatCompletionTopLogprobToken]
class ChatCompletionLogprobs(TypedDict):
content: Optional[List[ChatCompletionLogprobToken]]
refusal: Optional[List[ChatCompletionLogprobToken]]
class ChatCompletionResponseChoice(TypedDict):
index: int
message: "ChatCompletionResponseMessage"
logprobs: Optional[ChatCompletionLogprobs]
finish_reason: Optional[str]
class CreateChatCompletionResponse(TypedDict):
id: str
object: Literal["chat.completion"]
created: int
model: str
choices: List["ChatCompletionResponseChoice"]
usage: CompletionUsage
class ChatCompletionMessageToolCallChunkFunction(TypedDict):
name: Optional[str]
arguments: str
class ChatCompletionMessageToolCallChunk(TypedDict):
index: int
id: NotRequired[str]
type: Literal["function"]
function: ChatCompletionMessageToolCallChunkFunction
class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
pass
class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
name: str
arguments: str
class ChatCompletionStreamResponseDelta(TypedDict):
content: NotRequired[Optional[str]]
function_call: NotRequired[
Optional[ChatCompletionStreamResponseDeltaFunctionCall]
] # DEPRECATED
tool_calls: NotRequired[Optional[List[ChatCompletionMessageToolCallChunk]]]
role: NotRequired[Optional[Literal["system", "user", "assistant", "tool"]]]
class ChatCompletionStreamResponseChoice(TypedDict):
index: int
delta: Union[
ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
]
finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
class CreateChatCompletionStreamResponse(TypedDict):
id: str
model: str
object: Literal["chat.completion.chunk"]
created: int
choices: List[ChatCompletionStreamResponseChoice]
class ChatCompletionFunctions(TypedDict):
name: str
description: NotRequired[str]
parameters: Dict[str, JsonType] # TODO: make this more specific
class ChatCompletionFunctionCallOption(TypedDict):
name: str
class ChatCompletionRequestResponseFormat(TypedDict):
type: Literal["text", "json_object"]
schema: NotRequired[
JsonType
] # https://docs.endpoints.anyscale.com/guides/json_mode/
class ChatCompletionRequestMessageContentPartText(TypedDict):
type: Literal["text"]
text: str
class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
url: str
detail: NotRequired[Literal["auto", "low", "high"]]
class ChatCompletionRequestMessageContentPartImage(TypedDict):
type: Literal["image_url"]
image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]
ChatCompletionRequestMessageContentPart = Union[
ChatCompletionRequestMessageContentPartText,
ChatCompletionRequestMessageContentPartImage,
]
class ChatCompletionRequestSystemMessage(TypedDict):
role: Literal["system"]
content: Optional[str]
class ChatCompletionRequestUserMessage(TypedDict):
role: Literal["user"]
content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]
class ChatCompletionMessageToolCallFunction(TypedDict):
name: str
arguments: str
class ChatCompletionMessageToolCall(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionMessageToolCallFunction
ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]
class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
name: str
arguments: str
class ChatCompletionRequestAssistantMessage(TypedDict):
role: Literal["assistant"]
content: NotRequired[str]
tool_calls: NotRequired[ChatCompletionMessageToolCalls]
function_call: NotRequired[
ChatCompletionRequestAssistantMessageFunctionCall
] # DEPRECATED
class ChatCompletionRequestToolMessage(TypedDict):
role: Literal["tool"]
content: Optional[str]
tool_call_id: str
class ChatCompletionRequestFunctionMessage(TypedDict):
role: Literal["function"]
content: Optional[str]
name: str
ChatCompletionRequestMessage = Union[
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestUserMessage,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestFunctionMessage,
]
class ChatCompletionRequestFunctionCallOption(TypedDict):
name: str
ChatCompletionRequestFunctionCall = Union[
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
]
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
class ChatCompletionToolFunction(TypedDict):
name: str
description: NotRequired[str]
parameters: ChatCompletionFunctionParameters
class ChatCompletionTool(TypedDict):
type: Literal["function"]
function: ChatCompletionToolFunction
class ChatCompletionNamedToolChoiceFunction(TypedDict):
name: str
class ChatCompletionNamedToolChoice(TypedDict):
type: Literal["function"]
function: ChatCompletionNamedToolChoiceFunction
ChatCompletionToolChoiceOption = Union[
Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
]
# NOTE: The following type names are not part of the OpenAI OpenAPI specification
# and will be removed in a future major release.
EmbeddingData = Embedding
CompletionChunk = CreateCompletionResponse
Completion = CreateCompletionResponse
CreateCompletionStreamResponse = CreateCompletionResponse
ChatCompletionMessage = ChatCompletionResponseMessage
ChatCompletionChoice = ChatCompletionResponseChoice
ChatCompletion = CreateChatCompletionResponse
ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
ChatCompletionChunk = CreateChatCompletionStreamResponse
ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
ChatCompletionResponseFunction = ChatCompletionFunction
ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
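if __name__ == "__main__":
    # Hedged sketch: TypedDicts are plain dicts at runtime, so a chat request
    # body can be assembled literally while still being checked statically.
    # The message contents are illustrative only.
    example_messages: List[ChatCompletionRequestMessage] = [
        ChatCompletionRequestSystemMessage(
            role="system", content="You are a helpful assistant."
        ),
        ChatCompletionRequestUserMessage(
            role="user", content="What is the capital of France?"
        ),
    ]
    print(example_messages)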

View file

@ -0,0 +1,158 @@
from __future__ import annotations
import os
from ctypes import (
c_bool,
c_char_p,
c_int,
c_uint8,
c_float,
c_void_p,
POINTER,
_Pointer, # type: ignore
Structure,
)
import pathlib
from typing import (
Union,
NewType,
Optional,
TYPE_CHECKING,
)
import llama_cpp.llama_cpp as llama_cpp
from llama_cpp._ctypes_extensions import (
load_shared_library,
ctypes_function_for_shared_library,
)
if TYPE_CHECKING:
from llama_cpp._ctypes_extensions import (
CtypesArray,
)
# Specify the base name of the shared library to load
_libllava_base_name = "llava"
_libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
_libllava_base_path = (
    pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
    if _libllava_override_path is None
    else pathlib.Path()
)
# Load the library
_libllava = load_shared_library(_libllava_base_name, _libllava_base_path)
ctypes_function = ctypes_function_for_shared_library(_libllava)
################################################
# llava.h
################################################
# struct clip_ctx;
clip_ctx_p = NewType("clip_ctx_p", int)
clip_ctx_p_ctypes = c_void_p
# struct llava_image_embed {
# float * embed;
# int n_image_pos;
# };
class llava_image_embed(Structure):
_fields_ = [
("embed", POINTER(c_float)),
("n_image_pos", c_int),
]
# /** sanity check for clip <-> llava embed size match */
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
@ctypes_function(
"llava_validate_embed_size",
[llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
c_bool,
)
def llava_validate_embed_size(
ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
) -> bool:
...
# /** build an image embed from image file bytes */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
@ctypes_function(
"llava_image_embed_make_with_bytes",
[clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_bytes(
ctx_clip: clip_ctx_p,
n_threads: Union[c_int, int],
image_bytes: CtypesArray[c_uint8],
image_bytes_length: Union[c_int, int],
/,
) -> "_Pointer[llava_image_embed]":
...
# /** build an image embed from a path to an image filename */
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
@ctypes_function(
"llava_image_embed_make_with_filename",
[clip_ctx_p_ctypes, c_int, c_char_p],
POINTER(llava_image_embed),
)
def llava_image_embed_make_with_filename(
ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
) -> "_Pointer[llava_image_embed]":
...
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
# /** free an embedding made with llava_image_embed_make_* */
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
...
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
@ctypes_function(
"llava_eval_image_embed",
[
llama_cpp.llama_context_p_ctypes,
POINTER(llava_image_embed),
c_int,
POINTER(c_int),
],
c_bool,
)
def llava_eval_image_embed(
ctx_llama: llama_cpp.llama_context_p,
embed: "_Pointer[llava_image_embed]",
n_batch: Union[c_int, int],
n_past: "_Pointer[c_int]",
/,
) -> bool:
...
################################################
# clip.h
################################################
# /** load mmproj model */
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
def clip_model_load(
fname: bytes, verbosity: Union[c_int, int], /
) -> Optional[clip_ctx_p]:
...
# /** free mmproj model */
# CLIP_API void clip_free(struct clip_ctx * ctx);
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
def clip_free(ctx: clip_ctx_p, /):
...
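if __name__ == "__main__":
    # Hedged usage sketch of the call order exposed above: load the clip
    # (mmproj) model, build an image embedding, then free both. The paths are
    # placeholders, so this block is a no-op unless matching files exist.
    _mmproj_path = "./models/mmproj.gguf"
    _image_path = "./image.jpg"
    if os.path.exists(_mmproj_path) and os.path.exists(_image_path):
        ctx_clip = clip_model_load(_mmproj_path.encode("utf-8"), 1)
        if ctx_clip is not None:
            embed = llava_image_embed_make_with_filename(
                ctx_clip, os.cpu_count() or 1, _image_path.encode("utf-8")
            )
            if embed:
                print("n_image_pos:", embed.contents.n_image_pos)
                llava_image_embed_free(embed)
            clip_free(ctx_clip)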

View file

@ -0,0 +1,280 @@
from __future__ import annotations
import os
from ctypes import (
c_bool,
c_char_p,
c_int,
c_uint8,
c_uint32,
c_float,
c_void_p,
c_size_t,
POINTER,
_Pointer, # type: ignore
Structure,
byref,
)
import pathlib
from typing import (
Union,
NewType,
Optional,
TYPE_CHECKING,
)
import llama_cpp.llama_cpp as llama_cpp
from llama_cpp._ctypes_extensions import (
load_shared_library,
ctypes_function_for_shared_library,
)
if TYPE_CHECKING:
from llama_cpp._ctypes_extensions import (
CtypesArray,
)
# Specify the base name of the shared library to load
_libmtmd_base_name = "mtmd"
_libmtmd_override_path = os.environ.get("MTMD_CPP_LIB")
_libmtmd_base_path = (
    pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
    if _libmtmd_override_path is None
    else pathlib.Path()
)
# Load the library
_libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path)
ctypes_function = ctypes_function_for_shared_library(_libmtmd)
################################################
# mtmd.h types
################################################
# Opaque types
mtmd_context_p = NewType("mtmd_context_p", int)
mtmd_context_p_ctypes = c_void_p
mtmd_bitmap_p = NewType("mtmd_bitmap_p", int)
mtmd_bitmap_p_ctypes = c_void_p
mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int)
mtmd_image_tokens_p_ctypes = c_void_p
mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int)
mtmd_input_chunk_p_ctypes = c_void_p
mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int)
mtmd_input_chunks_p_ctypes = c_void_p
# Enums
MTMD_INPUT_CHUNK_TYPE_TEXT = 0
MTMD_INPUT_CHUNK_TYPE_IMAGE = 1
MTMD_INPUT_CHUNK_TYPE_AUDIO = 2
# Structures
class mtmd_context_params(Structure):
_fields_ = [
("use_gpu", c_bool),
("print_timings", c_bool),
("n_threads", c_int),
("verbosity", c_int), # ggml_log_level
("image_marker", c_char_p),
("media_marker", c_char_p),
]
class mtmd_input_text(Structure):
_fields_ = [
("text", c_char_p),
("add_special", c_bool),
("parse_special", c_bool),
]
################################################
# mtmd.h functions
################################################
# MTMD_API const char * mtmd_default_marker(void);
@ctypes_function("mtmd_default_marker", [], c_char_p)
def mtmd_default_marker() -> bytes:
...
# MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
@ctypes_function("mtmd_context_params_default", [], mtmd_context_params)
def mtmd_context_params_default() -> mtmd_context_params:
...
# MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
# const struct llama_model * text_model,
# const struct mtmd_context_params ctx_params);
@ctypes_function(
"mtmd_init_from_file",
[c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params],
mtmd_context_p_ctypes
)
def mtmd_init_from_file(
mmproj_fname: bytes,
text_model: llama_cpp.llama_model_p,
ctx_params: mtmd_context_params,
/,
) -> Optional[mtmd_context_p]:
...
# MTMD_API void mtmd_free(mtmd_context * ctx);
@ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None)
def mtmd_free(ctx: mtmd_context_p, /):
...
# MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
@ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool)
def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool:
...
# MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data);
@ctypes_function(
"mtmd_bitmap_init",
[c_uint32, c_uint32, POINTER(c_uint8)],
mtmd_bitmap_p_ctypes
)
def mtmd_bitmap_init(
nx: Union[c_uint32, int],
ny: Union[c_uint32, int],
data: CtypesArray[c_uint8],
/,
) -> Optional[mtmd_bitmap_p]:
...
# MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap);
@ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None)
def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /):
...
# MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void);
@ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes)
def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]:
...
# MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
@ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None)
def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /):
...
# MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
@ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t)
def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int:
...
# MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx);
@ctypes_function(
"mtmd_input_chunks_get",
[mtmd_input_chunks_p_ctypes, c_size_t],
mtmd_input_chunk_p_ctypes
)
def mtmd_input_chunks_get(
chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], /
) -> Optional[mtmd_input_chunk_p]:
...
# MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
# mtmd_input_chunks * output,
# const mtmd_input_text * text,
# const mtmd_bitmap ** bitmaps,
# size_t n_bitmaps);
@ctypes_function(
"mtmd_tokenize",
[
mtmd_context_p_ctypes,
mtmd_input_chunks_p_ctypes,
POINTER(mtmd_input_text),
POINTER(mtmd_bitmap_p_ctypes),
c_size_t,
],
c_int,
)
def mtmd_tokenize(
ctx: mtmd_context_p,
output: mtmd_input_chunks_p,
text: "_Pointer[mtmd_input_text]",
bitmaps: CtypesArray[mtmd_bitmap_p_ctypes],
n_bitmaps: Union[c_size_t, int],
/,
) -> int:
...
# MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk);
@ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t)
def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int:
...
# MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk);
@ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int)
def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int:
...
# MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output);
@ctypes_function(
"mtmd_input_chunk_get_tokens_text",
[mtmd_input_chunk_p_ctypes, POINTER(c_size_t)],
POINTER(llama_cpp.llama_token)
)
def mtmd_input_chunk_get_tokens_text(
chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", /
) -> Optional["_Pointer[llama_cpp.llama_token]"]:
...
################################################
# mtmd-helper.h functions
################################################
# MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
@ctypes_function(
"mtmd_helper_bitmap_init_from_buf",
[mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t],
mtmd_bitmap_p_ctypes
)
def mtmd_helper_bitmap_init_from_buf(
ctx: mtmd_context_p,
buf: CtypesArray[c_uint8],
length: Union[c_size_t, int],
/,
) -> Optional[mtmd_bitmap_p]:
...
# MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
@ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t)
def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int:
...
# MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
# struct llama_context * lctx,
# const mtmd_input_chunk * chunk,
# llama_pos n_past,
# llama_seq_id seq_id,
# int32_t n_batch,
# bool logits_last,
# llama_pos * new_n_past);
@ctypes_function(
"mtmd_helper_eval_chunk_single",
[
mtmd_context_p_ctypes,
llama_cpp.llama_context_p_ctypes,
mtmd_input_chunk_p_ctypes,
llama_cpp.llama_pos,
llama_cpp.llama_seq_id,
c_int,
c_bool,
POINTER(llama_cpp.llama_pos),
],
c_int,
)
def mtmd_helper_eval_chunk_single(
ctx: mtmd_context_p,
lctx: llama_cpp.llama_context_p,
chunk: mtmd_input_chunk_p,
n_past: llama_cpp.llama_pos,
seq_id: llama_cpp.llama_seq_id,
n_batch: Union[c_int, int],
logits_last: Union[c_bool, bool],
new_n_past: "_Pointer[llama_cpp.llama_pos]",
/,
) -> int:
...
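if __name__ == "__main__":
    # Hedged usage sketch: the calls below need no model on disk and show how
    # default context parameters are obtained and tweaked before a later call
    # to mtmd_init_from_file (omitted here because it requires an mmproj file
    # and a llama text model).
    print("default media marker:", mtmd_default_marker())
    params = mtmd_context_params_default()
    params.use_gpu = False
    params.n_threads = os.cpu_count() or 1
    print("n_threads:", params.n_threads)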

View file

@ -0,0 +1,100 @@
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette pydantic-settings
export MODEL=../models/7B/...
```
Then run:
```
uvicorn llama_cpp.server.app:create_app --reload
```
or
```
python3 -m llama_cpp.server
```
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
from __future__ import annotations
import os
import sys
import argparse
import uvicorn
from llama_cpp.server.app import create_app
from llama_cpp.server.settings import (
Settings,
ServerSettings,
ModelSettings,
ConfigFileSettings,
)
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
def main():
description = "🦙 Llama.cpp python server. Host your own LLMs! 🚀"
parser = argparse.ArgumentParser(description=description)
add_args_from_model(parser, Settings)
parser.add_argument(
"--config_file",
type=str,
help="Path to a config file to load.",
)
server_settings: ServerSettings | None = None
model_settings: list[ModelSettings] = []
args = parser.parse_args()
try:
# Load server settings from config_file if provided
config_file = os.environ.get("CONFIG_FILE", args.config_file)
if config_file:
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
# Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
import json
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(
f.read()
)
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
else:
server_settings = parse_model_from_args(ServerSettings, args)
model_settings = [parse_model_from_args(ModelSettings, args)]
except Exception as e:
print(e, file=sys.stderr)
parser.print_help()
sys.exit(1)
assert server_settings is not None
assert model_settings is not None
app = create_app(
server_settings=server_settings,
model_settings=model_settings,
)
uvicorn.run(
app,
host=os.getenv("HOST", server_settings.host),
port=int(os.getenv("PORT", server_settings.port)),
ssl_keyfile=server_settings.ssl_keyfile,
ssl_certfile=server_settings.ssl_certfile,
)
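def _example_config_json() -> str:
    # Hedged sketch of a config file that could be passed via --config_file.
    # Field names mirror ServerSettings / ModelSettings; the model path and
    # values shown are placeholders, not recommendations.
    import json
    return json.dumps(
        {
            "host": "localhost",
            "port": 8000,
            "models": [
                {
                    "model": "models/7B/llama-model.gguf",
                    "model_alias": "llama-7b",
                    "n_ctx": 4096,
                }
            ],
        },
        indent=2,
    )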
if __name__ == "__main__":
main()

View file

@ -0,0 +1,597 @@
from __future__ import annotations
import os
import json
import typing
import contextlib
from anyio import Lock
from functools import partial
from typing import List, Optional, Union, Dict
import llama_cpp
import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin # type: ignore
from starlette_context.middleware import RawContextMiddleware
from llama_cpp.server.model import (
LlamaProxy,
)
from llama_cpp.server.settings import (
ConfigFileSettings,
Settings,
ModelSettings,
ServerSettings,
)
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
CreateChatCompletionRequest,
ModelList,
TokenizeInputRequest,
TokenizeInputResponse,
TokenizeInputCountResponse,
DetokenizeInputRequest,
DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler
router = APIRouter(route_class=RouteErrorHandler)
_server_settings: Optional[ServerSettings] = None
def set_server_settings(server_settings: ServerSettings):
global _server_settings
_server_settings = server_settings
def get_server_settings():
yield _server_settings
_llama_proxy: Optional[LlamaProxy] = None
llama_outer_lock = Lock()
llama_inner_lock = Lock()
def set_llama_proxy(model_settings: List[ModelSettings]):
global _llama_proxy
_llama_proxy = LlamaProxy(models=model_settings)
async def get_llama_proxy():
# NOTE: This double lock allows the currently streaming llama model to
# check if any other requests are pending in the same thread and cancel
# the stream if so.
await llama_outer_lock.acquire()
release_outer_lock = True
try:
await llama_inner_lock.acquire()
try:
llama_outer_lock.release()
release_outer_lock = False
yield _llama_proxy
finally:
llama_inner_lock.release()
finally:
if release_outer_lock:
llama_outer_lock.release()
_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None
def set_ping_message_factory(factory: typing.Callable[[], bytes]):
global _ping_message_factory
_ping_message_factory = factory
def create_app(
settings: Settings | None = None,
server_settings: ServerSettings | None = None,
model_settings: List[ModelSettings] | None = None,
):
config_file = os.environ.get("CONFIG_FILE", None)
if config_file is not None:
if not os.path.exists(config_file):
raise ValueError(f"Config file {config_file} not found!")
with open(config_file, "rb") as f:
# Check if yaml file
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
import yaml
config_file_settings = ConfigFileSettings.model_validate_json(
json.dumps(yaml.safe_load(f))
)
else:
config_file_settings = ConfigFileSettings.model_validate_json(f.read())
server_settings = ServerSettings.model_validate(config_file_settings)
model_settings = config_file_settings.models
if server_settings is None and model_settings is None:
if settings is None:
settings = Settings()
server_settings = ServerSettings.model_validate(settings)
model_settings = [ModelSettings.model_validate(settings)]
assert (
server_settings is not None and model_settings is not None
), "server_settings and model_settings must be provided together"
set_server_settings(server_settings)
middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
app = FastAPI(
middleware=middleware,
title="🦙 llama.cpp Python API",
version=llama_cpp.__version__,
root_path=server_settings.root_path,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(router)
assert model_settings is not None
set_llama_proxy(model_settings=model_settings)
if server_settings.disable_ping_events:
set_ping_message_factory(lambda: bytes())
return app
def prepare_request_resources(
body: CreateCompletionRequest | CreateChatCompletionRequest,
llama_proxy: LlamaProxy,
body_model: str | None,
kwargs,
) -> llama_cpp.Llama:
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Service is not available",
)
llama = llama_proxy(body_model)
if body.logit_bias is not None:
kwargs["logit_bias"] = (
_logit_bias_tokens_to_input_ids(llama, body.logit_bias)
if body.logit_bias_type == "tokens"
else body.logit_bias
)
if body.grammar is not None:
kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
if body.min_tokens > 0:
_min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
[llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
)
if "logits_processor" not in kwargs:
kwargs["logits_processor"] = _min_tokens_logits_processor
else:
kwargs["logits_processor"].extend(_min_tokens_logits_processor)
return llama
async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream[typing.Any],
body: CreateCompletionRequest | CreateChatCompletionRequest,
body_model: str | None,
llama_call,
kwargs,
):
server_settings = next(get_server_settings())
interrupt_requests = (
server_settings.interrupt_requests if server_settings else False
)
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
async with inner_send_chan:
try:
iterator = await run_in_threadpool(llama_call, llama, **kwargs)
async for chunk in iterate_in_threadpool(iterator):
await inner_send_chan.send(dict(data=json.dumps(chunk)))
if await request.is_disconnected():
raise anyio.get_cancelled_exc_class()()
if interrupt_requests and llama_outer_lock.locked():
await inner_send_chan.send(dict(data="[DONE]"))
raise anyio.get_cancelled_exc_class()()
await inner_send_chan.send(dict(data="[DONE]"))
except anyio.get_cancelled_exc_class() as e:
print("disconnected")
with anyio.move_on_after(1, shield=True):
print(
f"Disconnected from client (via refresh/close) {request.client}"
)
raise e
def _logit_bias_tokens_to_input_ids(
llama: llama_cpp.Llama,
logit_bias: Dict[str, float],
) -> Dict[str, float]:
to_bias: Dict[str, float] = {}
for token, score in logit_bias.items():
token = token.encode("utf-8")
for input_id in llama.tokenize(token, add_bos=False, special=True):
to_bias[str(input_id)] = score
return to_bias
# Setup Bearer authentication scheme
bearer_scheme = HTTPBearer(auto_error=False)
async def authenticate(
settings: Settings = Depends(get_server_settings),
authorization: Optional[str] = Depends(bearer_scheme),
):
# Skip API key check if it's not set in settings
if settings.api_key is None:
return True
# check bearer credentials against the api_key
if authorization and authorization.credentials == settings.api_key:
# api key is valid
return authorization.credentials
# raise http error 401
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
openai_v1_tag = "OpenAI V1"
@router.post(
"/v1/completions",
summary="Completion",
dependencies=[Depends(authenticate)],
response_model=Union[
llama_cpp.CreateCompletionResponse,
str,
],
responses={
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"anyOf": [
{"$ref": "#/components/schemas/CreateCompletionResponse"}
],
"title": "Completion response, when stream=False",
}
},
"text/event-stream": {
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True. "
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
}
},
},
}
},
tags=[openai_v1_tag],
)
@router.post(
"/v1/engines/copilot-codex/completions",
include_in_schema=False,
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def create_completion(
request: Request,
body: CreateCompletionRequest,
) -> llama_cpp.Completion:
if isinstance(body.prompt, list):
assert len(body.prompt) <= 1
body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
body_model = (
body.model
if request.url.path != "/v1/engines/copilot-codex/completions"
else "copilot-codex"
)
exclude = {
"n",
"best_of",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)
# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.__call__,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
# handle regular request
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)
return await run_in_threadpool(llama, **kwargs)
@router.post(
"/v1/embeddings",
summary="Embedding",
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def create_embedding(
request: CreateEmbeddingRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
return await run_in_threadpool(
llama_proxy(request.model).create_embedding,
**request.model_dump(exclude={"user"}),
)
@router.post(
"/v1/chat/completions",
summary="Chat",
dependencies=[Depends(authenticate)],
response_model=Union[llama_cpp.ChatCompletion, str],
responses={
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"anyOf": [
{
"$ref": "#/components/schemas/CreateChatCompletionResponse"
}
],
"title": "Completion response, when stream=False",
}
},
"text/event-stream": {
"schema": {
"type": "string",
"title": "Server Side Streaming response, when stream=True"
+ "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501
"example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
}
},
},
}
},
tags=[openai_v1_tag],
)
async def create_chat_completion(
request: Request,
body: CreateChatCompletionRequest = Body(
openapi_examples={
"normal": {
"summary": "Chat Completion",
"value": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
],
},
},
"json_mode": {
"summary": "JSON Mode",
"value": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020"},
],
"response_format": {"type": "json_object"},
},
},
"tool_calling": {
"summary": "Tool Calling",
"value": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Extract Jason is 30 years old."},
],
"tools": [
{
"type": "function",
"function": {
"name": "User",
"description": "User record",
"parameters": {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "number"},
},
"required": ["name", "age"],
},
},
}
],
"tool_choice": {
"type": "function",
"function": {
"name": "User",
},
},
},
},
"logprobs": {
"summary": "Logprobs",
"value": {
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
],
"logprobs": True,
"top_logprobs": 10,
},
},
}
),
) -> llama_cpp.ChatCompletion:
# This is a workaround for an issue in FastAPI dependencies
# where the dependency is cleaned up before a StreamingResponse
# is complete.
# https://github.com/tiangolo/fastapi/issues/11143
body_model = body.model
exclude = {
"n",
"logit_bias_type",
"user",
"min_tokens",
}
kwargs = body.model_dump(exclude=exclude)
# handle streaming request
if kwargs.get("stream", False):
send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
recv_chan,
data_sender_callable=partial( # type: ignore
get_event_publisher,
request=request,
inner_send_chan=send_chan,
body=body,
body_model=body_model,
llama_call=llama_cpp.Llama.create_chat_completion,
kwargs=kwargs,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
# handle regular request
async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
if await request.is_disconnected():
print(
f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Client closed request",
)
return await run_in_threadpool(llama.create_chat_completion, **kwargs)
@router.get(
"/v1/models",
summary="Models",
dependencies=[Depends(authenticate)],
tags=[openai_v1_tag],
)
async def get_models(
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
return {
"object": "list",
"data": [
{
"id": model_alias,
"object": "model",
"owned_by": "me",
"permissions": [],
}
for model_alias in llama_proxy
],
}
extras_tag = "Extras"
@router.post(
"/extras/tokenize",
summary="Tokenize",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def tokenize(
body: TokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
return TokenizeInputResponse(tokens=tokens)
@router.post(
"/extras/tokenize/count",
summary="Tokenize Count",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def count_query_tokens(
body: TokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)
return TokenizeInputCountResponse(count=len(tokens))
@router.post(
"/extras/detokenize",
summary="Detokenize",
dependencies=[Depends(authenticate)],
tags=[extras_tag],
)
async def detokenize(
body: DetokenizeInputRequest,
llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")
return DetokenizeInputResponse(text=text)
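def _example_list_models() -> None:
    # Hedged client-side sketch: query the /v1/models route defined above,
    # assuming a server created from this app is reachable on localhost:8000
    # with no api_key configured and that the `requests` package is installed.
    import requests
    response = requests.get("http://localhost:8000/v1/models", timeout=30)
    print(response.json())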

View file

@ -0,0 +1,97 @@
from __future__ import annotations
import argparse
from typing import List, Literal, Union, Any, Type, TypeVar
from pydantic import BaseModel
def _get_base_type(annotation: Type[Any]) -> Type[Any]:
if getattr(annotation, "__origin__", None) is Literal:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return type(annotation.__args__[0]) # type: ignore
elif getattr(annotation, "__origin__", None) is Union:
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
non_optional_args: List[Type[Any]] = [
arg for arg in annotation.__args__ if arg is not type(None) # type: ignore
]
if non_optional_args:
return _get_base_type(non_optional_args[0])
elif (
getattr(annotation, "__origin__", None) is list
or getattr(annotation, "__origin__", None) is List
):
assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1 # type: ignore
return _get_base_type(annotation.__args__[0]) # type: ignore
return annotation
def _contains_list_type(annotation: Type[Any] | None) -> bool:
origin = getattr(annotation, "__origin__", None)
if origin is list or origin is List:
return True
elif origin in (Literal, Union):
return any(_contains_list_type(arg) for arg in annotation.__args__) # type: ignore
else:
return False
def _parse_bool_arg(arg: str | bytes | bool) -> bool:
if isinstance(arg, bytes):
arg = arg.decode("utf-8")
true_values = {"1", "on", "t", "true", "y", "yes"}
false_values = {"0", "off", "f", "false", "n", "no"}
arg_str = str(arg).lower().strip()
if arg_str in true_values:
return True
elif arg_str in false_values:
return False
else:
raise ValueError(f"Invalid boolean argument: {arg}")
def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
"""Add arguments from a pydantic model to an argparse parser."""
for name, field in model.model_fields.items():
description = field.description
if field.default and description and not field.is_required():
description += f" (default: {field.default})"
base_type = (
_get_base_type(field.annotation) if field.annotation is not None else str
)
list_type = _contains_list_type(field.annotation)
if base_type is not bool:
parser.add_argument(
f"--{name}",
dest=name,
nargs="*" if list_type else None,
type=base_type,
help=description,
)
if base_type is bool:
parser.add_argument(
f"--{name}",
dest=name,
type=_parse_bool_arg,
help=f"{description}",
)
T = TypeVar("T", bound=Type[BaseModel])
def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
"""Parse a pydantic model from an argparse namespace."""
return model(
**{
k: v
for k, v in vars(args).items()
if v is not None and k in model.model_fields
}
)
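if __name__ == "__main__":
    # Hedged sketch: wire a small pydantic model into argparse and parse it
    # back. _DemoSettings is illustrative only and not part of this module.
    from pydantic import Field
    class _DemoSettings(BaseModel):
        model: str = Field(default="models/7B/llama.gguf", description="Model path")
        n_ctx: int = Field(default=2048, description="Context size")
        verbose: bool = Field(default=True, description="Verbose logging")
    demo_parser = argparse.ArgumentParser()
    add_args_from_model(demo_parser, _DemoSettings)
    demo_args = demo_parser.parse_args(["--n_ctx", "4096", "--verbose", "false"])
    print(parse_model_from_args(_DemoSettings, demo_args))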

View file

@ -0,0 +1,212 @@
from __future__ import annotations
import sys
import traceback
import time
from re import compile, Match, Pattern
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict
from typing_extensions import TypedDict
from fastapi import (
Request,
Response,
HTTPException,
)
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from llama_cpp.server.types import (
CreateCompletionRequest,
CreateEmbeddingRequest,
CreateChatCompletionRequest,
)
class ErrorResponse(TypedDict):
"""OpenAI style error response"""
message: str
type: str
param: Optional[str]
code: Optional[str]
class ErrorResponseFormatters:
"""Collection of formatters for error responses.
Args:
request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
Request body
match (Match[str]): Match object from regex pattern
Returns:
Tuple[int, ErrorResponse]: Status code and error response
"""
@staticmethod
def context_length_exceeded(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for context length exceeded error"""
context_window = int(match.group(2))
prompt_tokens = int(match.group(1))
completion_tokens = request.max_tokens
if hasattr(request, "messages"):
# Chat completion
message = (
"This model's maximum context length is {} tokens. "
"However, you requested {} tokens "
"({} in the messages, {} in the completion). "
"Please reduce the length of the messages or completion."
)
else:
# Text completion
message = (
"This model's maximum context length is {} tokens, "
"however you requested {} tokens "
"({} in your prompt; {} for the completion). "
"Please reduce your prompt; or completion length."
)
return 400, ErrorResponse(
message=message.format(
context_window,
(completion_tokens or 0) + prompt_tokens,
prompt_tokens,
completion_tokens,
), # type: ignore
type="invalid_request_error",
param="messages",
code="context_length_exceeded",
)
@staticmethod
def model_not_found(
request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
match, # type: Match[str] # type: ignore
) -> Tuple[int, ErrorResponse]:
"""Formatter for model_not_found error"""
model_path = str(match.group(1))
message = f"The model `{model_path}` does not exist"
return 400, ErrorResponse(
message=message,
type="invalid_request_error",
param=None,
code="model_not_found",
)
class RouteErrorHandler(APIRoute):
"""Custom APIRoute that handles application errors and exceptions"""
# key: regex pattern for original error message from llama_cpp
# value: formatter function
pattern_and_formatters: Dict[
"Pattern[str]",
Callable[
[
Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
"Match[str]",
],
Tuple[int, ErrorResponse],
],
] = {
compile(
r"Requested tokens \((\d+)\) exceed context window of (\d+)"
): ErrorResponseFormatters.context_length_exceeded,
compile(
r"Model path does not exist: (.+)"
): ErrorResponseFormatters.model_not_found,
}
def error_message_wrapper(
self,
error: Exception,
body: Optional[
Union[
"CreateChatCompletionRequest",
"CreateCompletionRequest",
"CreateEmbeddingRequest",
]
] = None,
) -> Tuple[int, ErrorResponse]:
"""Wraps error message in OpenAI style error response"""
if body is not None and isinstance(
body,
(
CreateCompletionRequest,
CreateChatCompletionRequest,
),
):
# When text completion or chat completion
for pattern, callback in self.pattern_and_formatters.items():
match = pattern.search(str(error))
if match is not None:
return callback(body, match)
# Only print the trace on unexpected exceptions
print(f"Exception: {str(error)}", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
# Wrap other errors as internal server error
return 500, ErrorResponse(
message=str(error),
type="internal_server_error",
param=None,
code=None,
)
def get_route_handler(
self,
) -> Callable[[Request], Coroutine[None, None, Response]]:
"""Defines custom route handler that catches exceptions and formats
in OpenAI style error response"""
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
try:
start_sec = time.perf_counter()
response = await original_route_handler(request)
elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
return response
except HTTPException as unauthorized:
# api key check failed
raise unauthorized
except Exception as exc:
json_body = await request.json()
try:
if "messages" in json_body:
# Chat completion
body: Optional[
Union[
CreateChatCompletionRequest,
CreateCompletionRequest,
CreateEmbeddingRequest,
]
] = CreateChatCompletionRequest(**json_body)
elif "prompt" in json_body:
# Text completion
body = CreateCompletionRequest(**json_body)
else:
# Embedding
body = CreateEmbeddingRequest(**json_body)
except Exception:
# Invalid request body
body = None
# Get proper error message from the exception
(
status_code,
error_message,
) = self.error_message_wrapper(error=exc, body=body)
return JSONResponse(
{"error": error_message},
status_code=status_code,
)
return custom_route_handler
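if __name__ == "__main__":
    # Hedged sketch: exercise the context-length regex used by
    # RouteErrorHandler against a representative llama_cpp error string.
    _pattern = compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
    _match = _pattern.search("Requested tokens (5000) exceed context window of 4096")
    assert _match is not None
    print("prompt tokens:", _match.group(1), "context window:", _match.group(2))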

View file

@ -0,0 +1,312 @@
from __future__ import annotations
import json
from typing import Dict, Optional, Union, List
import llama_cpp
import llama_cpp.llama_speculative as llama_speculative
import llama_cpp.llama_tokenizer as llama_tokenizer
from llama_cpp.server.settings import ModelSettings
class LlamaProxy:
def __init__(self, models: List[ModelSettings]) -> None:
assert len(models) > 0, "No models provided!"
self._model_settings_dict: dict[str, ModelSettings] = {}
for model in models:
if not model.model_alias:
model.model_alias = model.model
self._model_settings_dict[model.model_alias] = model
self._current_model: Optional[llama_cpp.Llama] = None
self._current_model_alias: Optional[str] = None
self._default_model_settings: ModelSettings = models[0]
self._default_model_alias: str = self._default_model_settings.model_alias # type: ignore
# Load default model
self._current_model = self.load_llama_from_model_settings(
self._default_model_settings
)
self._current_model_alias = self._default_model_alias
def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
if model is None:
model = self._default_model_alias
if model not in self._model_settings_dict:
model = self._default_model_alias
if model == self._current_model_alias:
if self._current_model is not None:
return self._current_model
if self._current_model:
self._current_model.close()
self._current_model = None
settings = self._model_settings_dict[model]
self._current_model = self.load_llama_from_model_settings(settings)
self._current_model_alias = model
return self._current_model
def __getitem__(self, model: str):
return self._model_settings_dict[model].model_dump()
def __setitem__(self, model: str, settings: Union[ModelSettings, str, bytes]):
if isinstance(settings, (bytes, str)):
settings = ModelSettings.model_validate_json(settings)
self._model_settings_dict[model] = settings
def __iter__(self):
for model in self._model_settings_dict:
yield model
def free(self):
if self._current_model:
self._current_model.close()
del self._current_model
@staticmethod
def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
chat_handler = None
if settings.chat_format == "llava-1-5":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llava15ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "obsidian":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.ObsidianChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.ObsidianChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "llava-1-6":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llava16ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llava16ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "moondream":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.MoondreamChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.MoondreamChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "nanollava":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.NanoLlavaChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "llama-3-vision-alpha":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Llama3VisionAlpha.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Llama3VisionAlpha(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "minicpm-v-2.6":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.MiniCPMv26ChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.MiniCPMv26ChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "qwen2.5-vl":
assert settings.clip_model_path is not None, "clip model not found"
if settings.hf_model_repo_id is not None:
chat_handler = (
llama_cpp.llama_chat_format.Qwen25VLChatHandler.from_pretrained(
repo_id=settings.hf_model_repo_id,
filename=settings.clip_model_path,
verbose=settings.verbose,
)
)
else:
chat_handler = llama_cpp.llama_chat_format.Qwen25VLChatHandler(
clip_model_path=settings.clip_model_path, verbose=settings.verbose
)
elif settings.chat_format == "hf-autotokenizer":
assert (
settings.hf_pretrained_model_name_or_path is not None
), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
chat_handler = (
llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
settings.hf_pretrained_model_name_or_path
)
)
elif settings.chat_format == "hf-tokenizer-config":
assert (
settings.hf_tokenizer_config_path is not None
), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
            with open(settings.hf_tokenizer_config_path) as f:
                chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
                    json.load(f)
                )
tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
if settings.hf_pretrained_model_name_or_path is not None:
tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
settings.hf_pretrained_model_name_or_path
)
draft_model = None
if settings.draft_model is not None:
draft_model = llama_speculative.LlamaPromptLookupDecoding(
num_pred_tokens=settings.draft_model_num_pred_tokens
)
kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
kv_overrides = {}
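            # Illustrative example (assumed GGUF metadata key): the override string
            # "tokenizer.ggml.add_bos_token=bool:false" parses to
            # {"tokenizer.ggml.add_bos_token": False}; entries whose value carries no
            # "type:" prefix are silently skipped by the loop below.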
for kv in settings.kv_overrides:
key, value = kv.split("=")
if ":" in value:
value_type, value = value.split(":")
if value_type == "bool":
kv_overrides[key] = value.lower() in ["true", "1"]
elif value_type == "int":
kv_overrides[key] = int(value)
elif value_type == "float":
kv_overrides[key] = float(value)
elif value_type == "str":
kv_overrides[key] = value
else:
raise ValueError(f"Unknown value type {value_type}")
import functools
kwargs = {}
if settings.hf_model_repo_id is not None:
create_fn = functools.partial(
llama_cpp.Llama.from_pretrained,
repo_id=settings.hf_model_repo_id,
filename=settings.model,
)
else:
create_fn = llama_cpp.Llama
kwargs["model_path"] = settings.model
_model = create_fn(
**kwargs,
# Model Params
n_gpu_layers=settings.n_gpu_layers,
split_mode=settings.split_mode,
main_gpu=settings.main_gpu,
tensor_split=settings.tensor_split,
vocab_only=settings.vocab_only,
use_mmap=settings.use_mmap,
use_mlock=settings.use_mlock,
kv_overrides=kv_overrides,
rpc_servers=settings.rpc_servers,
# Context Params
seed=settings.seed,
n_ctx=settings.n_ctx,
n_batch=settings.n_batch,
n_ubatch=settings.n_ubatch,
n_threads=settings.n_threads,
n_threads_batch=settings.n_threads_batch,
rope_scaling_type=settings.rope_scaling_type,
rope_freq_base=settings.rope_freq_base,
rope_freq_scale=settings.rope_freq_scale,
yarn_ext_factor=settings.yarn_ext_factor,
yarn_attn_factor=settings.yarn_attn_factor,
yarn_beta_fast=settings.yarn_beta_fast,
yarn_beta_slow=settings.yarn_beta_slow,
yarn_orig_ctx=settings.yarn_orig_ctx,
mul_mat_q=settings.mul_mat_q,
logits_all=settings.logits_all,
embedding=settings.embedding,
offload_kqv=settings.offload_kqv,
flash_attn=settings.flash_attn,
# Sampling Params
last_n_tokens_size=settings.last_n_tokens_size,
# LoRA Params
lora_base=settings.lora_base,
lora_path=settings.lora_path,
# Backend Params
numa=settings.numa,
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Speculative Decoding
draft_model=draft_model,
# KV Cache Quantization
type_k=settings.type_k,
type_v=settings.type_v,
# Tokenizer
tokenizer=tokenizer,
# Misc
verbose=settings.verbose,
)
if settings.cache:
if settings.cache_type == "disk":
if settings.verbose:
print(f"Using disk cache with size {settings.cache_size}")
cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
else:
if settings.verbose:
print(f"Using ram cache with size {settings.cache_size}")
cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
_model.set_cache(cache)
return _model
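For orientation, a minimal usage sketch of the manager above follows. The class name (LlamaProxy), its constructor signature, and the file paths are assumptions for illustration; they are not shown in this excerpt.

# Sketch only: LlamaProxy and the paths below are assumed, not taken from this file.
from llama_cpp.server.model import LlamaProxy        # assumption: the class defining the methods above
from llama_cpp.server.settings import ModelSettings

settings = ModelSettings(
    model="./models/llama-3-8b.Q4_K_M.gguf",          # hypothetical local GGUF file
    model_alias="llama-3",
    n_gpu_layers=-1,
)
proxy = LlamaProxy(models=[settings])                 # assumed constructor signature

llama = proxy("llama-3")                              # loads, or reuses, the model for this alias
out = llama.create_completion("Q: Name the planets of the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])
proxy.free()                                          # release the currently loaded model

Calling the proxy with an unknown alias falls back to the default model, mirroring the __call__ logic above.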

View file

@@ -0,0 +1,240 @@
from __future__ import annotations
import multiprocessing
from typing import Optional, List, Literal, Union, Dict, cast
from typing_extensions import Self
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings
import llama_cpp
# Disable warning for model and model_alias settings
BaseSettings.model_config["protected_namespaces"] = ()
class ModelSettings(BaseSettings):
"""Model settings used to load a Llama model."""
model: str = Field(
description="The path to the model to use for generating completions."
)
model_alias: Optional[str] = Field(
default=None,
description="The alias of the model to use for generating completions.",
)
# Model Params
n_gpu_layers: int = Field(
default=0,
ge=-1,
description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
)
split_mode: int = Field(
default=llama_cpp.LLAMA_SPLIT_MODE_LAYER,
description="The split mode to use.",
)
main_gpu: int = Field(
default=0,
ge=0,
description="Main GPU to use.",
)
tensor_split: Optional[List[float]] = Field(
default=None,
description="Split layers across multiple GPUs in proportion.",
)
vocab_only: bool = Field(
default=False, description="Whether to only return the vocabulary."
)
use_mmap: bool = Field(
default=llama_cpp.llama_supports_mmap(),
description="Use mmap.",
)
use_mlock: bool = Field(
default=llama_cpp.llama_supports_mlock(),
description="Use mlock.",
)
kv_overrides: Optional[List[str]] = Field(
default=None,
        description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float, str). For bool, valid true values are (true, TRUE, 1); anything else is false.",
)
rpc_servers: Optional[str] = Field(
default=None,
        description="Comma-separated list of RPC servers for offloading.",
)
# Context Params
seed: int = Field(
default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
)
n_ctx: int = Field(default=2048, ge=0, description="The context size.")
n_batch: int = Field(
default=512, ge=1, description="The batch size to use per eval."
)
n_ubatch: int = Field(
default=512, ge=1, description="The physical batch size used by llama.cpp"
)
n_threads: int = Field(
default=max(multiprocessing.cpu_count() // 2, 1),
ge=1,
        description="The number of threads to use. Use -1 to use all available CPU threads.",
)
n_threads_batch: int = Field(
default=max(multiprocessing.cpu_count(), 1),
ge=0,
        description="The number of threads to use when batch processing. Use -1 to use all available CPU threads.",
)
rope_scaling_type: int = Field(
default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
rope_freq_scale: float = Field(
default=0.0, description="RoPE frequency scaling factor"
)
yarn_ext_factor: float = Field(default=-1.0)
yarn_attn_factor: float = Field(default=1.0)
yarn_beta_fast: float = Field(default=32.0)
yarn_beta_slow: float = Field(default=1.0)
yarn_orig_ctx: int = Field(default=0)
mul_mat_q: bool = Field(
default=True, description="if true, use experimental mul_mat_q kernels"
)
logits_all: bool = Field(default=True, description="Whether to return logits.")
embedding: bool = Field(default=False, description="Whether to use embeddings.")
offload_kqv: bool = Field(
default=True, description="Whether to offload kqv to the GPU."
)
flash_attn: bool = Field(
default=False, description="Whether to use flash attention."
)
# Sampling Params
last_n_tokens_size: int = Field(
default=64,
ge=0,
description="Last n tokens to keep for repeat penalty calculation.",
)
# LoRA Params
lora_base: Optional[str] = Field(
default=None,
description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
)
lora_path: Optional[str] = Field(
default=None,
description="Path to a LoRA file to apply to the model.",
)
# Backend Params
numa: Union[bool, int] = Field(
default=False,
description="Enable NUMA support.",
)
# Chat Format Params
chat_format: Optional[str] = Field(
default=None,
description="Chat format to use.",
)
clip_model_path: Optional[str] = Field(
default=None,
description="Path to a CLIP model to use for multi-modal chat completion.",
)
# Cache Params
cache: bool = Field(
default=False,
description="Use a cache to reduce processing times for evaluated prompts.",
)
cache_type: Literal["ram", "disk"] = Field(
default="ram",
description="The type of cache to use. Only used if cache is True.",
)
cache_size: int = Field(
default=2 << 30,
description="The size of the cache in bytes. Only used if cache is True.",
)
# Tokenizer Options
hf_tokenizer_config_path: Optional[str] = Field(
default=None,
description="The path to a HuggingFace tokenizer_config.json file.",
)
hf_pretrained_model_name_or_path: Optional[str] = Field(
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Loading from HuggingFace Model Hub
hf_model_repo_id: Optional[str] = Field(
default=None,
description="The model repo id to use for the HuggingFace tokenizer model.",
)
# Speculative Decoding
draft_model: Optional[str] = Field(
default=None,
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
)
draft_model_num_pred_tokens: int = Field(
default=10,
description="Number of tokens to predict using the draft model.",
)
# KV Cache Quantization
type_k: Optional[int] = Field(
default=None,
description="Type of the key cache quantization.",
)
type_v: Optional[int] = Field(
default=None,
description="Type of the value cache quantization.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."
)
@model_validator(
mode="before"
    )  # mode="before" ensures this runs before any other field validation
def set_dynamic_defaults(self) -> Self:
# If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
cpu_count = multiprocessing.cpu_count()
values = cast(Dict[str, int], self)
if values.get("n_threads", 0) == -1:
values["n_threads"] = cpu_count
if values.get("n_threads_batch", 0) == -1:
values["n_threads_batch"] = cpu_count
return self
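# Note (sketch, assumption): because these settings classes extend pydantic-settings'
# BaseSettings, their fields can also be populated from environment variables matched
# case-insensitively by field name (e.g. MODEL=/path/to/model.gguf, N_GPU_LAYERS=-1),
# in addition to keyword arguments or a JSON config file.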
class ServerSettings(BaseSettings):
"""Server settings used to configure the FastAPI and Uvicorn server."""
# Uvicorn Settings
host: str = Field(default="localhost", description="Listen address")
port: int = Field(default=8000, description="Listen port")
ssl_keyfile: Optional[str] = Field(
default=None, description="SSL key file for HTTPS"
)
ssl_certfile: Optional[str] = Field(
default=None, description="SSL certificate file for HTTPS"
)
# FastAPI Settings
api_key: Optional[str] = Field(
default=None,
        description="API key for authentication. If set, all requests must be authenticated.",
)
interrupt_requests: bool = Field(
default=True,
description="Whether to interrupt requests when a new request is received.",
)
disable_ping_events: bool = Field(
default=False,
description="Disable EventSource pings (may be needed for some clients).",
)
root_path: str = Field(
default="",
description="The root path for the server. Useful when running behind a reverse proxy.",
)
class Settings(ServerSettings, ModelSettings):
pass
class ConfigFileSettings(ServerSettings):
"""Configuration file format settings."""
models: List[ModelSettings] = Field(default=[], description="Model configs")
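To make the expected config file shape concrete, here is a small sketch that validates a JSON document against ConfigFileSettings; the host, port, and model path are made up for illustration.

# Sketch: parsing a JSON config into ConfigFileSettings (values are hypothetical).
import json

config_json = json.dumps(
    {
        "host": "0.0.0.0",
        "port": 8080,
        "models": [
            {
                "model": "./models/llama-3-8b.Q4_K_M.gguf",
                "model_alias": "llama-3",
                "n_gpu_layers": -1,
                "n_ctx": 4096,
            }
        ],
    }
)

config = ConfigFileSettings.model_validate_json(config_json)
print(config.host, config.port)          # server-level fields from ServerSettings
print(config.models[0].model_alias)      # each entry is validated as a ModelSettings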

View file

@@ -0,0 +1,316 @@
from __future__ import annotations
from typing import List, Optional, Union, Dict
from typing_extensions import TypedDict, Literal
from pydantic import BaseModel, Field
import llama_cpp
model_field = Field(
description="The model to use for generating completions.", default=None
)
max_tokens_field = Field(
default=16, ge=1, description="The maximum number of tokens to generate."
)
min_tokens_field = Field(
default=0,
ge=0,
description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
)
temperature_field = Field(
default=0.8,
description="Adjust the randomness of the generated text.\n\n"
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
)
top_p_field = Field(
default=0.95,
ge=0.0,
le=1.0,
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
)
min_p_field = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Sets a minimum base probability threshold for token selection.\n\n"
    + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, tokens with a probability below 0.045 are filtered out.",
)
stop_field = Field(
default=None,
    description="One or more sequences at which to stop generation. If None, no stop sequences are used.",
)
stream_field = Field(
default=False,
description="Whether to stream the results as they are generated. Useful for chatbots.",
)
top_k_field = Field(
default=40,
ge=0,
description="Limit the next token selection to the K most probable tokens.\n\n"
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
)
repeat_penalty_field = Field(
default=1.1,
ge=0.0,
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
)
presence_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
)
frequency_penalty_field = Field(
default=0.0,
ge=-2.0,
le=2.0,
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
)
mirostat_mode_field = Field(
default=0,
ge=0,
le=2,
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
)
mirostat_tau_field = Field(
default=5.0,
ge=0.0,
le=10.0,
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
)
mirostat_eta_field = Field(
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)
grammar = Field(
default=None,
    description="A GBNF grammar (as a string) to be used to constrain the model's output.",
)
class CreateCompletionRequest(BaseModel):
prompt: Union[str, List[str]] = Field(
default="", description="The prompt to generate completions for."
)
suffix: Optional[str] = Field(
default=None,
description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
)
max_tokens: Optional[int] = Field(
default=16, ge=0, description="The maximum number of tokens to generate."
)
min_tokens: int = min_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
echo: bool = Field(
default=False,
description="Whether to echo the prompt in the generated text. Useful for chatbots.",
)
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
logprobs: Optional[int] = Field(
default=None,
ge=0,
description="The number of logprobs to generate. If None, no logprobs are generated.",
)
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
best_of: Optional[int] = 1
user: Optional[str] = Field(default=None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
"stop": ["\n", "###"],
}
]
}
}
class CreateEmbeddingRequest(BaseModel):
model: Optional[str] = model_field
input: Union[str, List[str]] = Field(description="The input to embed.")
user: Optional[str] = Field(default=None)
model_config = {
"json_schema_extra": {
"examples": [
{
"input": "The food was delicious and the waiter...",
}
]
}
}
class ChatCompletionRequestMessage(BaseModel):
role: Literal["system", "user", "assistant", "function"] = Field(
default="user", description="The role of the message."
)
content: Optional[str] = Field(
default="", description="The content of the message."
)
class CreateChatCompletionRequest(BaseModel):
messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
default=[], description="A list of messages to generate completions for."
)
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
default=None,
description="A list of functions to apply to the generated completions.",
)
function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
default=None,
description="A function to apply to the generated completions.",
)
tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
default=None,
description="A list of tools to apply to the generated completions.",
)
tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
default=None,
description="A tool to apply to the generated completions.",
) # TODO: verify
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate. Defaults to inf",
)
min_tokens: int = min_tokens_field
logprobs: Optional[bool] = Field(
default=False,
        description="Whether to output the logprobs or not. Defaults to False.",
)
top_logprobs: Optional[int] = Field(
default=None,
ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated. Requires logprobs to be set to true.",
)
temperature: float = temperature_field
top_p: float = top_p_field
min_p: float = min_p_field
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
# ignored or currently unsupported
model: Optional[str] = model_field
n: Optional[int] = 1
user: Optional[str] = Field(None)
# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
grammar: Optional[str] = None
model_config = {
"json_schema_extra": {
"examples": [
{
"messages": [
ChatCompletionRequestMessage(
role="system", content="You are a helpful assistant."
).model_dump(),
ChatCompletionRequestMessage(
role="user", content="What is the capital of France?"
).model_dump(),
]
}
]
}
}
class ModelData(TypedDict):
id: str
object: Literal["model"]
owned_by: str
permissions: List[str]
class ModelList(TypedDict):
object: Literal["list"]
data: List[ModelData]
class TokenizeInputRequest(BaseModel):
model: Optional[str] = model_field
input: str = Field(description="The input to tokenize.")
model_config = {
"json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]}
}
class TokenizeInputResponse(BaseModel):
tokens: List[int] = Field(description="A list of tokens.")
model_config = {"json_schema_extra": {"example": {"tokens": [123, 321, 222]}}}
class TokenizeInputCountResponse(BaseModel):
count: int = Field(description="The number of tokens in the input.")
model_config = {"json_schema_extra": {"example": {"count": 5}}}
class DetokenizeInputRequest(BaseModel):
model: Optional[str] = model_field
    tokens: List[int] = Field(description="A list of tokens to detokenize.")
model_config = {"json_schema_extra": {"example": [{"tokens": [123, 321, 222]}]}}
class DetokenizeInputResponse(BaseModel):
text: str = Field(description="The detokenized text.")
model_config = {
"json_schema_extra": {"example": {"text": "How many tokens in this query?"}}
}
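Finally, a short sketch of how the request models above serialize, which can help when hand-crafting payloads for the server endpoints; the field values are arbitrary.

# Sketch: constructing and serializing request bodies defined above.
completion_request = CreateCompletionRequest(
    prompt="\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
    max_tokens=32,
    temperature=0.2,
    stop=["\n", "###"],
)
print(completion_request.model_dump_json(exclude_none=True))

tokenize_request = TokenizeInputRequest(input="How many tokens in this query?")
print(tokenize_request.model_dump(exclude_none=True))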