Adding all project files

parent 6c9e127bdc
commit cd4316ad0f
42289 changed files with 8009643 additions and 0 deletions
4  venv/Lib/site-packages/llama_cpp/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .llama_cpp import *
from .llama import *

__version__ = "0.3.14"
Binary files not shown (16 files).
131  venv/Lib/site-packages/llama_cpp/_ctypes_extensions.py  Normal file
@@ -0,0 +1,131 @@
from __future__ import annotations

import sys
import os
import ctypes
import functools
import pathlib

from typing import (
    Any,
    Callable,
    List,
    Union,
    Optional,
    TYPE_CHECKING,
    TypeVar,
    Generic,
)
from typing_extensions import TypeAlias


# Load the library
def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
    """Platform independent shared library loader"""
    # Searching for the library in the current directory under the name "libllama" (default name
    # for llamacpp) and "llama" (default name for this repo)
    lib_paths: List[pathlib.Path] = []
    # Determine the file extension based on the platform
    if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
        lib_paths += [
            base_path / f"lib{lib_base_name}.so",
        ]
    elif sys.platform == "darwin":
        lib_paths += [
            base_path / f"lib{lib_base_name}.so",
            base_path / f"lib{lib_base_name}.dylib",
        ]
    elif sys.platform == "win32":
        lib_paths += [
            base_path / f"{lib_base_name}.dll",
            base_path / f"lib{lib_base_name}.dll",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    cdll_args = dict()  # type: ignore

    # Add the library directory to the DLL search path on Windows (if needed)
    if sys.platform == "win32":
        os.add_dll_directory(str(base_path))
        os.environ["PATH"] = str(base_path) + os.pathsep + os.environ["PATH"]

    if sys.platform == "win32" and sys.version_info >= (3, 8):
        os.add_dll_directory(str(base_path))
        if "CUDA_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
        if "HIP_PATH" in os.environ:
            os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
            os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib"))
        cdll_args["winmode"] = ctypes.RTLD_GLOBAL

    # Try to load the shared library, handling potential errors
    for lib_path in lib_paths:
        if lib_path.exists():
            try:
                return ctypes.CDLL(str(lib_path), **cdll_args)  # type: ignore
            except Exception as e:
                raise RuntimeError(f"Failed to load shared library '{lib_path}': {e}")

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )


# ctypes sane type hint helpers
#
# - Generic Pointer and Array types
# - PointerOrRef type with a type hinted byref function
#
# NOTE: Only use these for static type checking not for runtime checks
# no good will come of that

if TYPE_CHECKING:
    CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData)  # type: ignore

    CtypesArray: TypeAlias = ctypes.Array[CtypesCData]  # type: ignore

    CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData]  # type: ignore

    CtypesVoidPointer: TypeAlias = ctypes.c_void_p

    class CtypesRef(Generic[CtypesCData]):
        pass

    CtypesPointerOrRef: TypeAlias = Union[
        CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
    ]

    CtypesFuncPointer: TypeAlias = ctypes._FuncPointer  # type: ignore

F = TypeVar("F", bound=Callable[..., Any])


def ctypes_function_for_shared_library(lib: ctypes.CDLL):
    """Decorator for defining ctypes functions with type hints"""

    def ctypes_function(
        name: str, argtypes: List[Any], restype: Any, enabled: bool = True
    ):
        def decorator(f: F) -> F:
            if enabled:
                func = getattr(lib, name)
                func.argtypes = argtypes
                func.restype = restype
                functools.wraps(f)(func)
                return func
            else:
                return f

        return decorator

    return ctypes_function


def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
    """Type-annotated version of ctypes.byref"""
    ...


byref = _byref if TYPE_CHECKING else ctypes.byref
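A minimal usage sketch of the two helpers above, assuming the packaged "lib" directory added later in this commit; llama_max_devices is a real llama.cpp symbol chosen only for illustration, and the actual bindings in llama_cpp.py declare their symbols the same way:

import ctypes
import pathlib

import llama_cpp
from llama_cpp._ctypes_extensions import (
    ctypes_function_for_shared_library,
    load_shared_library,
)

# The shared libraries ship in the package's "lib" directory (see the lib/ files below).
_base_path = pathlib.Path(llama_cpp.__file__).parent / "lib"
_lib = load_shared_library("llama", _base_path)

ctypes_function = ctypes_function_for_shared_library(_lib)


# The decorated stub only carries the type hints; the returned object is the
# ctypes foreign function with argtypes/restype attached.
@ctypes_function("llama_max_devices", [], ctypes.c_size_t)
def llama_max_devices() -> int:
    ...


print(llama_max_devices())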
12  venv/Lib/site-packages/llama_cpp/_ggml.py  Normal file
@@ -0,0 +1,12 @@
"""Internal module use at your own risk

This module provides a minimal interface for working with ggml tensors from llama-cpp-python
"""
import os
import pathlib

import llama_cpp._ctypes_extensions as ctypes_ext

libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
850  venv/Lib/site-packages/llama_cpp/_internals.py  Normal file
@@ -0,0 +1,850 @@
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Tuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Callable,
|
||||
Union,
|
||||
)
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import ExitStack
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from .llama_types import *
|
||||
from .llama_grammar import LlamaGrammar
|
||||
from ._utils import suppress_stdout_stderr
|
||||
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
|
||||
|
||||
# Python wrappers over llama.h structs
|
||||
|
||||
|
||||
class LlamaModel:
|
||||
"""Intermediate Python wrapper for a llama.cpp llama_model.
|
||||
NOTE: For stability it's recommended you use the Llama class instead."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
path_model: str,
|
||||
params: llama_cpp.llama_model_params,
|
||||
verbose: bool = True,
|
||||
):
|
||||
self.path_model = path_model
|
||||
self.params = params
|
||||
self.verbose = verbose
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
model = None
|
||||
|
||||
if not os.path.exists(path_model):
|
||||
raise ValueError(f"Model path does not exist: {path_model}")
|
||||
|
||||
with suppress_stdout_stderr(disable=verbose):
|
||||
model = llama_cpp.llama_model_load_from_file(
|
||||
self.path_model.encode("utf-8"), self.params
|
||||
)
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Failed to load model from file: {path_model}")
|
||||
|
||||
vocab = llama_cpp.llama_model_get_vocab(model)
|
||||
|
||||
if vocab is None:
|
||||
raise ValueError(f"Failed to get vocab from model: {path_model}")
|
||||
|
||||
self.model = model
|
||||
self.vocab = vocab
|
||||
self.sampler = None # LlamaModel doesn't use samplers, but some cleanup code expects this attribute
|
||||
|
||||
def free_model():
|
||||
if self.model is None:
|
||||
return
|
||||
llama_cpp.llama_model_free(self.model)
|
||||
self.model = None
|
||||
|
||||
self._exit_stack.callback(free_model)
|
||||
|
||||
def close(self):
|
||||
if self.sampler is not None:
|
||||
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
|
||||
for i, _ in reversed(self.custom_samplers):
|
||||
llama_cpp.llama_sampler_chain_remove(self.sampler, i)
|
||||
self.custom_samplers.clear()
|
||||
self._exit_stack.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def vocab_type(self) -> int:
|
||||
return llama_cpp.llama_vocab_type(self.vocab)
|
||||
|
||||
def n_vocab(self) -> int:
|
||||
return llama_cpp.llama_vocab_n_tokens(self.vocab)
|
||||
|
||||
def n_ctx_train(self) -> int:
|
||||
return llama_cpp.llama_model_n_ctx_train(self.model)
|
||||
|
||||
def n_embd(self) -> int:
|
||||
return llama_cpp.llama_model_n_embd(self.model)
|
||||
|
||||
def rope_freq_scale_train(self) -> float:
|
||||
return llama_cpp.llama_model_rope_freq_scale_train(self.model)
|
||||
|
||||
def desc(self) -> str:
|
||||
buf = ctypes.create_string_buffer(1024)
|
||||
llama_cpp.llama_model_desc(self.model, buf, 1024)
|
||||
return buf.value.decode("utf-8")
|
||||
|
||||
def size(self) -> int:
|
||||
return llama_cpp.llama_model_size(self.model)
|
||||
|
||||
def n_params(self) -> int:
|
||||
return llama_cpp.llama_model_n_params(self.model)
|
||||
|
||||
def get_tensor(self, name: str) -> ctypes.c_void_p:
|
||||
raise NotImplementedError("get_tensor is not implemented in llama.cpp")
|
||||
|
||||
# Vocab
|
||||
|
||||
def token_get_text(self, token: int) -> str:
|
||||
return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")
|
||||
|
||||
def token_get_score(self, token: int) -> float:
|
||||
return llama_cpp.llama_vocab_get_score(self.vocab, token)
|
||||
|
||||
def token_get_attr(self, token: int) -> int:
|
||||
return llama_cpp.llama_vocab_get_attr(self.vocab, token)
|
||||
|
||||
# Special tokens
|
||||
|
||||
def token_bos(self) -> int:
|
||||
return llama_cpp.llama_vocab_bos(self.vocab)
|
||||
|
||||
def token_eos(self) -> int:
|
||||
return llama_cpp.llama_vocab_eos(self.vocab)
|
||||
|
||||
def token_cls(self) -> int:
|
||||
return llama_cpp.llama_vocab_cls(self.vocab)
|
||||
|
||||
def token_sep(self) -> int:
|
||||
return llama_cpp.llama_vocab_sep(self.vocab)
|
||||
|
||||
def token_nl(self) -> int:
|
||||
return llama_cpp.llama_vocab_nl(self.vocab)
|
||||
|
||||
def token_prefix(self) -> int:
|
||||
return llama_cpp.llama_vocab_fim_pre(self.vocab)
|
||||
|
||||
def token_middle(self) -> int:
|
||||
return llama_cpp.llama_vocab_fim_mid(self.vocab)
|
||||
|
||||
def token_suffix(self) -> int:
|
||||
return llama_cpp.llama_vocab_fim_suf(self.vocab)
|
||||
|
||||
def token_eot(self) -> int:
|
||||
return llama_cpp.llama_vocab_eot(self.vocab)
|
||||
|
||||
def add_bos_token(self) -> bool:
|
||||
return llama_cpp.llama_vocab_get_add_bos(self.vocab)
|
||||
|
||||
def add_eos_token(self) -> bool:
|
||||
return llama_cpp.llama_vocab_get_add_eos(self.vocab)
|
||||
|
||||
# Tokenization
|
||||
|
||||
def tokenize(self, text: bytes, add_bos: bool, special: bool):
|
||||
n_ctx = self.n_ctx_train()
|
||||
tokens = (llama_cpp.llama_token * n_ctx)()
|
||||
n_tokens = llama_cpp.llama_tokenize(
|
||||
self.vocab, text, len(text), tokens, n_ctx, add_bos, special
|
||||
)
|
||||
if n_tokens < 0:
|
||||
n_tokens = abs(n_tokens)
|
||||
tokens = (llama_cpp.llama_token * n_tokens)()
|
||||
n_tokens = llama_cpp.llama_tokenize(
|
||||
self.vocab, text, len(text), tokens, n_tokens, add_bos, special
|
||||
)
|
||||
if n_tokens < 0:
|
||||
raise RuntimeError(
|
||||
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
|
||||
)
|
||||
return list(tokens[:n_tokens])
|
||||
|
||||
def token_to_piece(self, token: int, special: bool = False) -> bytes:
|
||||
buf = ctypes.create_string_buffer(32)
|
||||
llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
|
||||
return bytes(buf)
|
||||
|
||||
def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
|
||||
output = b""
|
||||
size = 32
|
||||
buffer = (ctypes.c_char * size)()
|
||||
for token in tokens:
|
||||
n = llama_cpp.llama_token_to_piece(
|
||||
self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
|
||||
)
|
||||
assert n <= size
|
||||
output += bytes(buffer[:n])
|
||||
# NOTE: Llama1 models automatically added a space at the start of the prompt
|
||||
# this line removes a leading space if the first token is a beginning of sentence token
|
||||
return (
|
||||
output[1:]
|
||||
if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
|
||||
else output
|
||||
)
|
||||
|
||||
# Extra
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
metadata: Dict[str, str] = {}
|
||||
buffer_size = 1024
|
||||
buffer = ctypes.create_string_buffer(buffer_size)
|
||||
# zero the buffer
|
||||
buffer.value = b"\0" * buffer_size
|
||||
# iterate over model keys
|
||||
for i in range(llama_cpp.llama_model_meta_count(self.model)):
|
||||
nbytes = llama_cpp.llama_model_meta_key_by_index(
|
||||
self.model, i, buffer, buffer_size
|
||||
)
|
||||
if nbytes > buffer_size:
|
||||
buffer_size = nbytes + 1
|
||||
buffer = ctypes.create_string_buffer(buffer_size)
|
||||
nbytes = llama_cpp.llama_model_meta_key_by_index(
|
||||
self.model, i, buffer, buffer_size
|
||||
)
|
||||
key = buffer.value.decode("utf-8")
|
||||
nbytes = llama_cpp.llama_model_meta_val_str_by_index(
|
||||
self.model, i, buffer, buffer_size
|
||||
)
|
||||
if nbytes > buffer_size:
|
||||
buffer_size = nbytes + 1
|
||||
buffer = ctypes.create_string_buffer(buffer_size)
|
||||
nbytes = llama_cpp.llama_model_meta_val_str_by_index(
|
||||
self.model, i, buffer, buffer_size
|
||||
)
|
||||
value = buffer.value.decode("utf-8")
|
||||
metadata[key] = value
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def default_params():
|
||||
"""Get the default llama_model_params."""
|
||||
return llama_cpp.llama_model_default_params()
|
||||
|
||||
|
||||
class LlamaContext:
|
||||
"""Intermediate Python wrapper for a llama.cpp llama_context.
|
||||
NOTE: For stability it's recommended you use the Llama class instead."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: LlamaModel,
|
||||
params: llama_cpp.llama_context_params,
|
||||
verbose: bool = True,
|
||||
):
|
||||
self.model = model
|
||||
self.params = params
|
||||
self.verbose = verbose
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
ctx = llama_cpp.llama_init_from_model(self.model.model, self.params)
|
||||
|
||||
if ctx is None:
|
||||
raise ValueError("Failed to create llama_context")
|
||||
|
||||
self.ctx = ctx
|
||||
self.memory = llama_cpp.llama_get_memory(self.ctx)
|
||||
self.sampler = None # LlamaContext doesn't manage samplers directly, but some cleanup code expects this attribute
|
||||
|
||||
def free_ctx():
|
||||
if self.ctx is None:
|
||||
return
|
||||
llama_cpp.llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
|
||||
self._exit_stack.callback(free_ctx)
|
||||
|
||||
def close(self):
|
||||
self._exit_stack.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def n_ctx(self) -> int:
|
||||
return llama_cpp.llama_n_ctx(self.ctx)
|
||||
|
||||
def pooling_type(self) -> int:
|
||||
return llama_cpp.llama_pooling_type(self.ctx)
|
||||
|
||||
def kv_cache_clear(self):
|
||||
llama_cpp.llama_memory_clear(self.memory, True)
|
||||
|
||||
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
|
||||
llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1)
|
||||
|
||||
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
|
||||
llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1)
|
||||
|
||||
def kv_cache_seq_keep(self, seq_id: int):
|
||||
llama_cpp.llama_memory_seq_keep(self.memory, seq_id)
|
||||
|
||||
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
|
||||
llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift)
|
||||
|
||||
def get_state_size(self) -> int:
|
||||
return llama_cpp.llama_state_get_size(self.ctx)
|
||||
|
||||
# TODO: copy_state_data
|
||||
|
||||
# TODO: set_state_data
|
||||
|
||||
# TODO: llama_load_session_file
|
||||
|
||||
# TODO: llama_save_session_file
|
||||
|
||||
def decode(self, batch: LlamaBatch):
|
||||
return_code = llama_cpp.llama_decode(
|
||||
self.ctx,
|
||||
batch.batch,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"llama_decode returned {return_code}")
|
||||
|
||||
def encode(self, batch: LlamaBatch):
|
||||
return_code = llama_cpp.llama_encode(
|
||||
self.ctx,
|
||||
batch.batch,
|
||||
)
|
||||
if return_code != 0:
|
||||
raise RuntimeError(f"llama_encode returned {return_code}")
|
||||
|
||||
def set_n_threads(self, n_threads: int, n_threads_batch: int):
|
||||
llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
|
||||
|
||||
def get_logits(self):
|
||||
return llama_cpp.llama_get_logits(self.ctx)
|
||||
|
||||
def get_logits_ith(self, i: int):
|
||||
return llama_cpp.llama_get_logits_ith(self.ctx, i)
|
||||
|
||||
def get_embeddings(self):
|
||||
return llama_cpp.llama_get_embeddings(self.ctx)
|
||||
|
||||
def get_embeddings_ith(self, i: int):
|
||||
return llama_cpp.llama_get_embeddings_ith(self.ctx, i)
|
||||
|
||||
def get_embeddings_seq(self, seq_id: int):
|
||||
return llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id)
|
||||
|
||||
# Sampling functions - deprecated, use LlamaSampler instead
|
||||
|
||||
def set_rng_seed(self, seed: int):
|
||||
raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_repetition_penalties(
|
||||
self,
|
||||
candidates: "_LlamaTokenDataArray",
|
||||
last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
|
||||
penalty_last_n: int,
|
||||
penalty_repeat: float,
|
||||
penalty_freq: float,
|
||||
penalty_present: float,
|
||||
):
|
||||
raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
|
||||
raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
|
||||
raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||
raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
|
||||
raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_typical(
|
||||
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
|
||||
):
|
||||
raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
|
||||
raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
|
||||
raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_token_mirostat(
|
||||
self,
|
||||
candidates: "_LlamaTokenDataArray",
|
||||
tau: float,
|
||||
eta: float,
|
||||
m: int,
|
||||
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
|
||||
) -> int:
|
||||
raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_token_mirostat_v2(
|
||||
self,
|
||||
candidates: "_LlamaTokenDataArray",
|
||||
tau: float,
|
||||
eta: float,
|
||||
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
|
||||
) -> int:
|
||||
raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||
raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
|
||||
|
||||
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
|
||||
raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
|
||||
|
||||
# Grammar
|
||||
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
|
||||
raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
|
||||
|
||||
def reset_timings(self):
|
||||
llama_cpp.llama_perf_context_reset(self.ctx)
|
||||
|
||||
def print_timings(self):
|
||||
llama_cpp.llama_perf_context_print(self.ctx)
|
||||
|
||||
# Utility functions
|
||||
@staticmethod
|
||||
def default_params():
|
||||
"""Get the default llama_context_params."""
|
||||
return llama_cpp.llama_context_default_params()
|
||||
|
||||
|
||||
class LlamaBatch:
|
||||
def __init__(
|
||||
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
||||
):
|
||||
self._n_tokens = n_tokens
|
||||
self.embd = embd
|
||||
self.n_seq_max = n_seq_max
|
||||
self.verbose = verbose
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)
|
||||
|
||||
if batch is None:
|
||||
raise ValueError("Failed to create llama_batch")
|
||||
|
||||
self.batch = batch
|
||||
self.sampler = None # LlamaBatch doesn't use samplers, but some cleanup code expects this attribute
|
||||
|
||||
def free_batch():
|
||||
if self.batch is None:
|
||||
return
|
||||
llama_cpp.llama_batch_free(self.batch)
|
||||
self.batch = None
|
||||
|
||||
self._exit_stack.callback(free_batch)
|
||||
|
||||
def close(self):
|
||||
self._exit_stack.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def n_tokens(self) -> int:
|
||||
return self.batch.n_tokens
|
||||
|
||||
def reset(self):
|
||||
self.batch.n_tokens = 0
|
||||
|
||||
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
|
||||
n_tokens = len(batch)
|
||||
self.batch.n_tokens = n_tokens
|
||||
for i in range(n_tokens):
|
||||
self.batch.token[i] = batch[i]
|
||||
self.batch.pos[i] = n_past + i
|
||||
self.batch.seq_id[i][0] = 0
|
||||
self.batch.n_seq_id[i] = 1
|
||||
self.batch.logits[i] = logits_all
|
||||
self.batch.logits[n_tokens - 1] = True
|
||||
|
||||
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
|
||||
n_tokens = len(batch)
|
||||
n_tokens0 = self.batch.n_tokens
|
||||
self.batch.n_tokens += n_tokens
|
||||
for i in range(n_tokens):
|
||||
j = n_tokens0 + i
|
||||
self.batch.token[j] = batch[i]
|
||||
self.batch.pos[j] = i
|
||||
self.batch.seq_id[j][0] = seq_id
|
||||
self.batch.n_seq_id[j] = 1
|
||||
self.batch.logits[j] = logits_all
|
||||
self.batch.logits[n_tokens - 1] = True
|
||||
|
||||
|
||||
class LlamaTokenDataArray:
|
||||
def __init__(self, *, n_vocab: int):
|
||||
self.n_vocab = n_vocab
|
||||
self.candidates_data = np.recarray(
|
||||
(self.n_vocab,),
|
||||
dtype=np.dtype(
|
||||
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
|
||||
),
|
||||
)
|
||||
self.candidates = llama_cpp.llama_token_data_array(
|
||||
data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
|
||||
size=self.n_vocab,
|
||||
sorted=False,
|
||||
)
|
||||
self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc) # type: ignore
|
||||
self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
|
||||
self.sampler = None # LlamaTokenDataArray doesn't use samplers, but some cleanup code expects this attribute
|
||||
|
||||
def copy_logits(self, logits: npt.NDArray[np.single]):
|
||||
self.candidates_data.id[:] = self.default_candidates_data_id
|
||||
self.candidates_data.logit[:] = logits
|
||||
self.candidates_data.p[:] = self.default_candidates_data_p
|
||||
self.candidates.sorted = False
|
||||
self.candidates.size = self.n_vocab
|
||||
|
||||
|
||||
# Embedding functions
|
||||
|
||||
|
||||
def normalize_embedding(embedding):
|
||||
norm = float(np.linalg.norm(embedding))
|
||||
if norm == 0.0:
|
||||
return embedding
|
||||
return [v / norm for v in embedding]
|
||||
|
||||
|
||||
# Python wrappers over common/sampling structs
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlamaSamplingParams:
|
||||
n_prev: int = 64
|
||||
n_probs: int = 0
|
||||
top_k: int = 40
|
||||
top_p: float = 0.95
|
||||
min_p: float = 0.05
|
||||
tfs_z: float = 1.00
|
||||
typical_p: float = 1.00
|
||||
temp: float = 0.80
|
||||
penalty_last_n: int = 64
|
||||
penalty_repeat: float = 1.0
|
||||
penalty_freq: float = 0.00
|
||||
penalty_present: float = 0.00
|
||||
mirostat: int = 0
|
||||
mirostat_tau: float = 5.00
|
||||
mirostat_eta: float = 0.10
|
||||
penalize_nl: bool = True
|
||||
|
||||
grammar: str = ""
|
||||
|
||||
cfg_negative_prompt: str = ""
|
||||
cfg_scale: float = 1.00
|
||||
|
||||
logit_bias: dict[int, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlamaSamplingContext:
|
||||
params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
|
||||
mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
|
||||
grammar: Optional[LlamaGrammar] = None
|
||||
# NOTE: Missing parsed_grammar
|
||||
prev: list[int] = field(default_factory=list)
|
||||
cur: list[llama_cpp.llama_token_data] = field(default_factory=list)
|
||||
|
||||
def reset(self):
|
||||
self.prev = []
|
||||
self.cur = []
|
||||
if self.grammar is not None:
|
||||
self.grammar.reset()
|
||||
|
||||
def cp(self):
|
||||
return LlamaSamplingContext(
|
||||
params=self.params,
|
||||
mirostat_mu=self.mirostat_mu,
|
||||
grammar=self.grammar,
|
||||
prev=self.prev.copy(),
|
||||
cur=self.cur.copy(),
|
||||
)
|
||||
|
||||
def last(self) -> Optional[int]:
|
||||
if len(self.prev) > 0:
|
||||
return self.prev[-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
|
||||
return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")
|
||||
|
||||
def sample(
|
||||
self,
|
||||
ctx_main: LlamaContext,
|
||||
idx: int = 0,
|
||||
logits_array: Optional[npt.NDArray[np.single]] = None,
|
||||
):
|
||||
# This method is deprecated in favor of using LlamaSampler directly
|
||||
raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
|
||||
|
||||
def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
|
||||
self.prev.append(id)
|
||||
|
||||
|
||||
class CustomSampler:
|
||||
def __init__(
|
||||
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
|
||||
):
|
||||
self.apply_func = apply_func
|
||||
|
||||
def apply_wrapper(
|
||||
sampler: llama_cpp.llama_sampler_p,
|
||||
cur_p: llama_cpp.llama_token_data_array_p,
|
||||
):
|
||||
self.apply_func(cur_p)
|
||||
|
||||
def free_wrapper(sampler: llama_cpp.llama_sampler_p):
|
||||
pass
|
||||
|
||||
sampler_i = llama_cpp.llama_sampler_i()
|
||||
sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
|
||||
self._apply_wrapper_ref = apply_wrapper
|
||||
|
||||
sampler_i.name = llama_cpp.llama_sampler_i_name(0)
|
||||
sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
|
||||
sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
|
||||
sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
|
||||
sampler_i.free = llama_cpp.llama_sampler_i_free(0)
|
||||
|
||||
self.sampler = llama_cpp.llama_sampler()
|
||||
self.sampler.iface = ctypes.pointer(sampler_i)
|
||||
self.sampler.ctx = None
|
||||
|
||||
def get_sampler(self) -> llama_cpp.llama_sampler_p:
|
||||
return ctypes.pointer(self.sampler)
|
||||
|
||||
|
||||
class LlamaSampler:
|
||||
def __init__(self):
|
||||
params = llama_cpp.llama_sampler_chain_default_params()
|
||||
self.sampler = llama_cpp.llama_sampler_chain_init(params)
|
||||
self.custom_samplers: List[Tuple[int, CustomSampler]] = []
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
def free_sampler():
|
||||
if self.sampler is not None:
|
||||
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
|
||||
for i, _ in reversed(self.custom_samplers):
|
||||
llama_cpp.llama_sampler_chain_remove(self.sampler, i)
|
||||
llama_cpp.llama_sampler_free(self.sampler)
|
||||
self.sampler = None
|
||||
|
||||
self._exit_stack.callback(free_sampler)
|
||||
|
||||
def close(self):
|
||||
self._exit_stack.close()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
||||
def add_greedy(self):
|
||||
sampler = llama_cpp.llama_sampler_init_greedy()
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_dist(self, seed: int):
|
||||
sampler = llama_cpp.llama_sampler_init_dist(seed)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_softmax(self):
|
||||
sampler = llama_cpp.llama_sampler_init_softmax()
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_top_k(self, k: int):
|
||||
sampler = llama_cpp.llama_sampler_init_top_k(k)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_top_p(self, p: float, min_keep: int = 1):
|
||||
sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_min_p(self, p: float, min_keep: int = 1):
|
||||
sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_typical(self, p: float, min_keep: int = 1):
|
||||
sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_temp(self, temp: float):
|
||||
sampler = llama_cpp.llama_sampler_init_temp(temp)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_temp_ext(self, t: float, delta: float, exponent: float):
|
||||
sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_xtc(self, p: float, t: float, min_keep: int, seed: int):
|
||||
sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_top_n_sigma(self, n: float):
|
||||
sampler = llama_cpp.llama_sampler_init_top_n_sigma(n)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
|
||||
sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_mirostat_v2(self, seed: int, tau: float, eta: float):
|
||||
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
|
||||
sampler = llama_cpp.llama_sampler_init_grammar(
|
||||
model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
|
||||
)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_grammar_lazy_patterns(
|
||||
self,
|
||||
model: LlamaModel,
|
||||
grammar: LlamaGrammar,
|
||||
trigger_patterns: List[str],
|
||||
trigger_tokens: List[int]
|
||||
):
|
||||
# Convert patterns to C array
|
||||
pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
|
||||
for i, pattern in enumerate(trigger_patterns):
|
||||
pattern_ptrs[i] = pattern.encode("utf-8")
|
||||
|
||||
# Convert tokens to C array
|
||||
token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)
|
||||
|
||||
sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
|
||||
model.vocab,
|
||||
grammar._grammar.encode("utf-8"),
|
||||
grammar._root.encode("utf-8"),
|
||||
pattern_ptrs,
|
||||
len(trigger_patterns),
|
||||
token_array,
|
||||
len(trigger_tokens)
|
||||
)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_penalties(
|
||||
self,
|
||||
penalty_last_n: int,
|
||||
penalty_repeat: float,
|
||||
penalty_freq: float,
|
||||
penalty_present: float,
|
||||
):
|
||||
sampler = llama_cpp.llama_sampler_init_penalties(
|
||||
penalty_last_n,
|
||||
penalty_repeat,
|
||||
penalty_freq,
|
||||
penalty_present,
|
||||
)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_dry(
|
||||
self,
|
||||
model: LlamaModel,
|
||||
n_ctx_train: int,
|
||||
dry_multiplier: float,
|
||||
dry_base: float,
|
||||
dry_allowed_length: int,
|
||||
dry_penalty_last_n: int,
|
||||
seq_breakers: List[str]
|
||||
):
|
||||
# Convert seq_breakers to C array
|
||||
breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
|
||||
for i, breaker in enumerate(seq_breakers):
|
||||
breaker_ptrs[i] = breaker.encode("utf-8")
|
||||
|
||||
sampler = llama_cpp.llama_sampler_init_dry(
|
||||
model.vocab,
|
||||
n_ctx_train,
|
||||
dry_multiplier,
|
||||
dry_base,
|
||||
dry_allowed_length,
|
||||
dry_penalty_last_n,
|
||||
breaker_ptrs,
|
||||
len(seq_breakers)
|
||||
)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_logit_bias(
|
||||
self,
|
||||
n_vocab: int,
|
||||
logit_bias: Dict[int, float]
|
||||
):
|
||||
# Convert logit_bias dict to C array
|
||||
bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
|
||||
for i, (token, bias) in enumerate(logit_bias.items()):
|
||||
bias_array[i].token = token
|
||||
bias_array[i].bias = bias
|
||||
|
||||
sampler = llama_cpp.llama_sampler_init_logit_bias(
|
||||
n_vocab,
|
||||
len(logit_bias),
|
||||
bias_array
|
||||
)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_infill(self, model: LlamaModel):
|
||||
sampler = llama_cpp.llama_sampler_init_infill(model.vocab)
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
|
||||
def add_custom(
|
||||
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
|
||||
):
|
||||
custom_sampler = CustomSampler(apply_func)
|
||||
sampler = custom_sampler.get_sampler()
|
||||
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
|
||||
# NOTE: Must remove custom samplers before free or llama.cpp will try to free them
|
||||
self.custom_samplers.append(
|
||||
(llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
|
||||
)
|
||||
|
||||
def get_seed(self) -> int:
|
||||
return llama_cpp.llama_sampler_get_seed(self.sampler)
|
||||
|
||||
def sample(self, ctx: LlamaContext, idx: int = -1) -> int:
|
||||
return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)
|
||||
|
||||
def accept(self, token: int):
|
||||
llama_cpp.llama_sampler_accept(self.sampler, token)
|
||||
|
||||
def reset(self):
|
||||
llama_cpp.llama_sampler_reset(self.sampler)
|
||||
|
||||
def clone(self):
|
||||
# NOTE: Custom samplers cannot be cloned due to Python callback limitations
|
||||
if self.custom_samplers:
|
||||
raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")
|
||||
|
||||
cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
|
||||
# Create a new wrapper around the cloned sampler
|
||||
new_sampler = LlamaSampler.__new__(LlamaSampler)
|
||||
new_sampler.sampler = cloned_sampler
|
||||
new_sampler.custom_samplers = []
|
||||
new_sampler._exit_stack = ExitStack()
|
||||
|
||||
def free_sampler():
|
||||
if new_sampler.sampler is not None:
|
||||
llama_cpp.llama_sampler_free(new_sampler.sampler)
|
||||
new_sampler.sampler = None
|
||||
|
||||
new_sampler._exit_stack.callback(free_sampler)
|
||||
return new_sampler
|
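A minimal end-to-end sketch of the low-level wrappers above (the model path is an assumption; the NOTE in the file recommends the high-level Llama class for anything beyond experiments):

from llama_cpp._internals import LlamaBatch, LlamaContext, LlamaModel, LlamaSampler

model = LlamaModel(
    path_model="./models/model.gguf",  # assumed path to a local GGUF model
    params=LlamaModel.default_params(),
    verbose=False,
)
ctx = LlamaContext(model=model, params=LlamaContext.default_params(), verbose=False)
batch = LlamaBatch(n_tokens=ctx.n_ctx(), embd=0, n_seq_max=1, verbose=False)

sampler = LlamaSampler()
sampler.add_top_k(40)
sampler.add_top_p(0.95, 1)
sampler.add_temp(0.8)
sampler.add_dist(seed=1234)

tokens = model.tokenize(b"Hello", add_bos=True, special=True)
batch.set_batch(tokens, n_past=0, logits_all=False)  # logits kept for the last position only
ctx.decode(batch)

token = sampler.sample(ctx, -1)   # sample from the last position's logits
print(model.detokenize([token]))  # raw bytes of the sampled piece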
47  venv/Lib/site-packages/llama_cpp/_logger.py  Normal file
@@ -0,0 +1,47 @@
import sys
import ctypes
import logging

import llama_cpp

# enum ggml_log_level {
#     GGML_LOG_LEVEL_NONE  = 0,
#     GGML_LOG_LEVEL_INFO  = 1,
#     GGML_LOG_LEVEL_WARN  = 2,
#     GGML_LOG_LEVEL_ERROR = 3,
#     GGML_LOG_LEVEL_DEBUG = 4,
#     GGML_LOG_LEVEL_CONT  = 5, // continue previous log
# };
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
    0: logging.CRITICAL,
    1: logging.INFO,
    2: logging.WARNING,
    3: logging.ERROR,
    4: logging.DEBUG,
    5: logging.DEBUG,
}

logger = logging.getLogger("llama-cpp-python")

_last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]

# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
@llama_cpp.llama_log_callback
def llama_log_callback(
    level: int,
    text: bytes,
    user_data: ctypes.c_void_p,
):
    # TODO: Correctly implement continue previous log
    global _last_log_level
    log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
    if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
        print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
    _last_log_level = log_level


llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))


def set_verbose(verbose: bool):
    logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
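A short sketch of driving the hook above (importing llama_cpp._logger registers the callback with llama_log_set as a side effect):

import logging

from llama_cpp import _logger

_logger.set_verbose(True)   # logger at DEBUG: every ggml/llama.cpp message is echoed to stderr
_logger.set_verbose(False)  # logger at ERROR: only ERROR/CRITICAL-level native messages pass
logging.getLogger("llama-cpp-python").setLevel(logging.WARNING)  # equivalent manual control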
78  venv/Lib/site-packages/llama_cpp/_utils.py  Normal file
@@ -0,0 +1,78 @@
import os
import sys

from typing import Any, Dict

# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
outnull_file = open(os.devnull, "w")
errnull_file = open(os.devnull, "w")

STDOUT_FILENO = 1
STDERR_FILENO = 2


class suppress_stdout_stderr(object):
    # NOTE: these must be "saved" here to avoid exceptions when using
    # this context manager inside of a __del__ method
    sys = sys
    os = os

    def __init__(self, disable: bool = True):
        self.disable = disable

    # Oddly enough this works better than the contextlib version
    def __enter__(self):
        if self.disable:
            return self

        self.old_stdout_fileno_undup = STDOUT_FILENO
        self.old_stderr_fileno_undup = STDERR_FILENO

        self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
        self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)

        self.old_stdout = self.sys.stdout
        self.old_stderr = self.sys.stderr

        self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
        self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)

        self.sys.stdout = outnull_file
        self.sys.stderr = errnull_file
        return self

    def __exit__(self, *_):
        if self.disable:
            return

        # Check if sys.stdout and sys.stderr have fileno method
        self.sys.stdout = self.old_stdout
        self.sys.stderr = self.old_stderr

        self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        self.os.close(self.old_stdout_fileno)
        self.os.close(self.old_stderr_fileno)


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()
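A small sketch of the utilities above; the AppState class name is a hypothetical example:

from llama_cpp._utils import Singleton, suppress_stdout_stderr

with suppress_stdout_stderr(disable=False):
    # fds 1 and 2 are dup2'd onto os.devnull here, so output from native
    # libraries (not just Python prints) is silenced inside the block
    print("swallowed")
print("visible again")


class AppState(Singleton):  # hypothetical subclass for illustration
    pass


assert AppState() is AppState()  # MetaSingleton caches one instance per class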
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml-base.dll  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml-base.lib  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml-cpu.dll  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml-cpu.lib  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml.dll  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/ggml.lib  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/llama.dll  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/llama.lib  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/mtmd.dll  Normal file (binary file not shown)
BIN  venv/Lib/site-packages/llama_cpp/lib/mtmd.lib  Normal file (binary file not shown)
2422  venv/Lib/site-packages/llama_cpp/llama.py  Normal file
File diff suppressed because it is too large
155  venv/Lib/site-packages/llama_cpp/llama_cache.py  Normal file
@@ -0,0 +1,155 @@
import sys
from abc import ABC, abstractmethod
from typing import (
    Optional,
    Sequence,
    Tuple,
)
from collections import OrderedDict

import diskcache

import llama_cpp.llama

from .llama_types import *


class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        self.capacity_bytes = capacity_bytes

    @property
    @abstractmethod
    def cache_size(self) -> int:
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        pass

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __setitem__(
        self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
    ) -> None:
        raise NotImplementedError


class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM."""

    def __init__(self, capacity_bytes: int = (2 << 30)):
        super().__init__(capacity_bytes)
        self.capacity_bytes = capacity_bytes
        self.cache_state: OrderedDict[
            Tuple[int, ...], "llama_cpp.llama.LlamaState"
        ] = OrderedDict()

    @property
    def cache_size(self):
        return sum([state.llama_state_size for state in self.cache_state.values()])

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        min_len = 0
        min_key = None
        keys = (
            (k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
            for k in self.cache_state.keys()
        )
        for k, prefix_len in keys:
            if prefix_len > min_len:
                min_len = prefix_len
                min_key = k
        return min_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value = self.cache_state[_key]
        self.cache_state.move_to_end(_key)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        key = tuple(key)
        if key in self.cache_state:
            del self.cache_state[key]
        self.cache_state[key] = value
        while self.cache_size > self.capacity_bytes and len(self.cache_state) > 0:
            self.cache_state.popitem(last=False)


# Alias for backwards compatibility
LlamaCache = LlamaRAMCache


class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk."""

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    @property
    def cache_size(self):
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        min_len = 0
        min_key: Optional[Tuple[int, ...]] = None
        for k in self.cache.iterkeys():  # type: ignore
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
            if prefix_len > min_len:
                min_len = prefix_len
                min_key = k  # type: ignore
        return min_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key)  # type: ignore
        # NOTE: This puts an integer as key in cache, which breaks,
        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
        # self.cache.push(_key, side="front")  # type: ignore
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        key = tuple(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]
        self.cache[key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
            print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)
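A sketch of wiring these caches into the high-level API, assuming the Llama class in llama.py keeps its set_cache method as in upstream llama-cpp-python; the model path is an assumption:

from llama_cpp import Llama
from llama_cpp.llama_cache import LlamaDiskCache, LlamaRAMCache

llm = Llama(model_path="./models/model.gguf", verbose=False)  # assumed path

# Keep up to ~2 GiB of evaluated prompt states in RAM, keyed by longest token prefix.
llm.set_cache(LlamaRAMCache(capacity_bytes=2 << 30))

# Or persist states on disk between runs instead:
# llm.set_cache(LlamaDiskCache(cache_dir=".cache/llama_cache"))

llm("Q: Name the planets in the solar system. A:", max_tokens=32)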
3956  venv/Lib/site-packages/llama_cpp/llama_chat_format.py  Normal file
File diff suppressed because it is too large
4343  venv/Lib/site-packages/llama_cpp/llama_cpp.py  Normal file
File diff suppressed because it is too large
953  venv/Lib/site-packages/llama_cpp/llama_grammar.py  Normal file
@@ -0,0 +1,953 @@
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
|
||||
|
||||
# flake8: noqa
|
||||
from pathlib import Path
|
||||
|
||||
from itertools import groupby
|
||||
from typing import (
|
||||
Any,
|
||||
Set,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
LLAMA_GRAMMAR_DEFAULT_ROOT = "root"
|
||||
|
||||
|
||||
class LlamaGrammar:
|
||||
def __init__(self, *args, _grammar: str, **kwargs):
|
||||
self._grammar = _grammar
|
||||
self._root = LLAMA_GRAMMAR_DEFAULT_ROOT
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
||||
return cls(_grammar=grammar)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
|
||||
try:
|
||||
with open(file) as f:
|
||||
grammar = f.read()
|
||||
except Exception as err:
|
||||
raise Exception(
|
||||
f"{cls.from_file.__name__}: error reading grammar file: {err}"
|
||||
)
|
||||
|
||||
if grammar:
|
||||
return cls.from_string(grammar, verbose=verbose)
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
|
||||
return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
|
||||
|
||||
|
||||
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
|
||||
|
||||
ARITHMETIC_GBNF = r"""
|
||||
root ::= (expr "=" ws term "\n")+
|
||||
expr ::= term ([-+*/] term)*
|
||||
term ::= ident | num | "(" ws expr ")" ws
|
||||
ident ::= [a-z] [a-z0-9_]* ws
|
||||
num ::= [0-9]+ ws
|
||||
ws ::= [ \t\n]*
|
||||
"""
|
||||
|
||||
C_GBNF = r"""
|
||||
root ::= (declaration)*
|
||||
|
||||
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
|
||||
|
||||
dataType ::= "int" ws | "float" ws | "char" ws
|
||||
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||
|
||||
parameter ::= dataType identifier
|
||||
|
||||
statement ::=
|
||||
( dataType identifier ws "=" ws expression ";" ) |
|
||||
( identifier ws "=" ws expression ";" ) |
|
||||
( identifier ws "(" argList? ")" ";" ) |
|
||||
( "return" ws expression ";" ) |
|
||||
( "while" "(" condition ")" "{" statement* "}" ) |
|
||||
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
|
||||
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
|
||||
( singleLineComment ) |
|
||||
( multiLineComment )
|
||||
|
||||
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
|
||||
forUpdate ::= identifier ws "=" ws expression
|
||||
|
||||
condition ::= expression relationOperator expression
|
||||
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
|
||||
|
||||
expression ::= term (("+" | "-") term)*
|
||||
term ::= factor(("*" | "/") factor)*
|
||||
|
||||
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
|
||||
unaryTerm ::= "-" factor
|
||||
funcCall ::= identifier "(" argList? ")"
|
||||
parenExpression ::= "(" ws expression ws ")"
|
||||
|
||||
argList ::= expression ("," ws expression)*
|
||||
|
||||
number ::= [0-9]+
|
||||
|
||||
singleLineComment ::= "//" [^\n]* "\n"
|
||||
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
|
||||
|
||||
ws ::= ([ \t\n]+)
|
||||
"""
|
||||
|
||||
CHESS_GBNF = r"""
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
"""
|
||||
|
||||
JAPANESE_GBNF = r"""
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
"""
|
||||
|
||||
JSON_ARR_GBNF = r"""
|
||||
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
|
||||
# Useful for generating JSON arrays
|
||||
|
||||
root ::= arr
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
arr ::=
|
||||
"[\n" ws (
|
||||
value
|
||||
(",\n" ws value)*
|
||||
)? "]"
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
"""
|
||||
|
||||
|
||||
JSON_GBNF = r"""
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
"\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= | " " | "\n" [ \t]{0,20}
|
||||
"""
|
||||
|
||||
LIST_GBNF = r"""
|
||||
root ::= item+
|
||||
|
||||
# Excludes various line break characters
|
||||
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
||||
"""
|
||||
|
||||
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||
# whitespace. Also maybe improves generation quality?
|
||||
SPACE_RULE = '" "?'
|
||||
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
||||
|
||||
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||
# whitespace. Also maybe improves generation quality?
|
||||
SPACE_RULE = '" "?'
|
||||
|
||||
|
||||
def _build_repetition(
|
||||
item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False
|
||||
):
|
||||
if not separator_rule:
|
||||
if min_items == 0 and max_items == 1:
|
||||
return f"{item_rule}?"
|
||||
elif min_items == 1 and max_items is None:
|
||||
return f"{item_rule}+"
|
||||
|
||||
result = ""
|
||||
|
||||
if min_items > 0:
|
||||
if item_rule_is_literal and separator_rule is None:
|
||||
result = '"' + (item_rule[1:-1] * min_items) + '"'
|
||||
else:
|
||||
result = (f" {separator_rule} " if separator_rule else " ").join(
|
||||
[item_rule] * min_items
|
||||
)
|
||||
|
||||
def opt_repetitions(up_to_n, prefix_with_sep=False):
|
||||
"""
|
||||
- n=4, no sep: '(a (a (a (a)?)?)?)?'
|
||||
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
|
||||
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
|
||||
"""
|
||||
|
||||
content = (
|
||||
f"{separator_rule} {item_rule}"
|
||||
if prefix_with_sep and separator_rule
|
||||
else item_rule
|
||||
)
|
||||
if up_to_n == 0:
|
||||
return ""
|
||||
elif up_to_n == 1:
|
||||
return f"({content})?"
|
||||
elif separator_rule and not prefix_with_sep:
|
||||
return f"({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?"
|
||||
else:
|
||||
return (f"({content} " * up_to_n).rstrip() + (")?" * up_to_n)
|
||||
|
||||
if min_items > 0 and max_items != min_items:
|
||||
result += " "
|
||||
|
||||
if max_items is not None:
|
||||
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
|
||||
else:
|
||||
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
|
||||
|
||||
if min_items == 0 and separator_rule:
|
||||
result = f"({item_rule} {item_operator}*)?"
|
||||
else:
|
||||
result += f"{item_operator}*"
|
||||
|
||||
return result
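
A few illustrative inputs and the strings `_build_repetition` produces for them, derived by tracing the code above rather than taken from upstream documentation:

```python
# min 1, no max, no separator -> simple one-or-more
_build_repetition("x", 1, None)   # 'x+'

# bounded repetition without a separator
_build_repetition("item", 2, 3)   # 'item item (item)?'

# unbounded, comma-separated, possibly empty (the shape used for JSON arrays)
_build_repetition("value", 0, None, separator_rule='"," space')
# -> '(value ("," space value)*)?'
```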
|
||||
|
||||
|
||||
class BuiltinRule:
|
||||
def __init__(self, content: str, deps: list = None):
|
||||
self.content = content
|
||||
self.deps = deps or []
|
||||
|
||||
|
||||
_up_to_15_digits = _build_repetition("[0-9]", 0, 15)
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
"boolean": BuiltinRule('("true" | "false") space', []),
|
||||
"decimal-part": BuiltinRule("[0-9] " + _up_to_15_digits, []),
|
||||
"integral-part": BuiltinRule("[0-9] | [1-9] " + _up_to_15_digits, []),
|
||||
"number": BuiltinRule(
|
||||
'("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
|
||||
["integral-part", "decimal-part"],
|
||||
),
|
||||
"integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]),
|
||||
"value": BuiltinRule(
|
||||
"object | array | string | number | boolean | null",
|
||||
["object", "array", "string", "number", "boolean", "null"],
|
||||
),
|
||||
"object": BuiltinRule(
|
||||
'"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
|
||||
["string", "value"],
|
||||
),
|
||||
"array": BuiltinRule(
|
||||
'"[" space ( value ("," space value)* )? "]" space', ["value"]
|
||||
),
|
||||
"uuid": BuiltinRule(
|
||||
r'"\"" '
|
||||
+ ' "-" '.join("[0-9a-fA-F]" * n for n in [8, 4, 4, 4, 12])
|
||||
+ r' "\"" space',
|
||||
[],
|
||||
),
|
||||
"char": BuiltinRule(
|
||||
r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])',
|
||||
[],
|
||||
),
|
||||
"string": BuiltinRule(r'"\"" char* "\"" space', ["char"]),
|
||||
"null": BuiltinRule('"null" space', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
STRING_FORMAT_RULES = {
|
||||
"date": BuiltinRule(
|
||||
'[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )',
|
||||
[],
|
||||
),
|
||||
"time": BuiltinRule(
|
||||
'([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
|
||||
[],
|
||||
),
|
||||
"date-time": BuiltinRule('date "T" time', ["date", "time"]),
|
||||
"date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]),
|
||||
"time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]),
|
||||
"date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]),
|
||||
}
|
||||
|
||||
DOTALL = "[\\U00000000-\\U0010FFFF]"
|
||||
DOT = "[^\\x0A\\x0D]"
|
||||
|
||||
RESERVED_NAMES = set(
|
||||
["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]
|
||||
)
|
||||
|
||||
|
||||
NON_LITERAL_SET = set("|.()[]{}*+?")
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("[]()|{}*+?")
|
||||
|
||||
|
||||
class SchemaConverter:
|
||||
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||
self._prop_order = prop_order
|
||||
self._allow_fetch = allow_fetch
|
||||
self._dotall = dotall
|
||||
self._raw_pattern = raw_pattern
|
||||
self._rules = {
|
||||
"space": SPACE_RULE,
|
||||
}
|
||||
self._refs = {}
|
||||
self._refs_being_resolved = set()
|
||||
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
def not_literal(
|
||||
self, literal: str, dotall: bool = True, maybe_escaped_underscores=False
|
||||
) -> str:
|
||||
"""
|
||||
not_literal('a') -> '[^a]'
|
||||
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
|
||||
"""
|
||||
assert len(literal) > 0, "Empty literal not supported"
|
||||
|
||||
def recurse(i: int):
|
||||
c = literal[i]
|
||||
if maybe_escaped_underscores and c == "_":
|
||||
yield f"[^{c}\\\\]"
|
||||
yield " | "
|
||||
yield f'"\\\\"? "{c}"'
|
||||
else:
|
||||
yield f"[^{c}]"
|
||||
if i < len(literal) - 1:
|
||||
yield " | "
|
||||
yield self._format_literal(c)
|
||||
yield " ("
|
||||
yield from recurse(i + 1)
|
||||
yield ")?"
|
||||
|
||||
return "".join(("(", *recurse(0), ")"))
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
|
||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||
key = esc_name
|
||||
else:
|
||||
i = 0
|
||||
while (
|
||||
f"{esc_name}{i}" in self._rules
|
||||
and self._rules[f"{esc_name}{i}"] != rule
|
||||
):
|
||||
i += 1
|
||||
key = f"{esc_name}{i}"
|
||||
self._rules[key] = rule
|
||||
return key
|
||||
|
||||
def resolve_refs(self, schema: dict, url: str):
|
||||
"""
|
||||
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
replacing $ref with absolute reference URL and populating self._refs with the
|
||||
respective referenced (sub)schema dictionaries.
|
||||
"""
|
||||
|
||||
def visit(n: dict):
|
||||
if isinstance(n, list):
|
||||
return [visit(x) for x in n]
|
||||
elif isinstance(n, dict):
|
||||
ref = n.get("$ref")
|
||||
if ref is not None and ref not in self._refs:
|
||||
if ref.startswith("https://"):
|
||||
assert (
|
||||
self._allow_fetch
|
||||
), "Fetching remote schemas is not allowed (use --allow-fetch for force)"
|
||||
import requests
|
||||
|
||||
frag_split = ref.split("#")
|
||||
base_url = frag_split[0]
|
||||
|
||||
target = self._refs.get(base_url)
|
||||
if target is None:
|
||||
target = self.resolve_refs(
|
||||
requests.get(ref).json(), base_url
|
||||
)
|
||||
self._refs[base_url] = target
|
||||
|
||||
if len(frag_split) == 1 or frag_split[-1] == "":
|
||||
return target
|
||||
elif ref.startswith("#/"):
|
||||
target = schema
|
||||
ref = f"{url}{ref}"
|
||||
n["$ref"] = ref
|
||||
else:
|
||||
raise ValueError(f"Unsupported ref {ref}")
|
||||
|
||||
for sel in ref.split("#")[-1].split("/")[1:]:
|
||||
assert (
|
||||
target is not None and sel in target
|
||||
), f"Error resolving ref {ref}: {sel} not in {target}"
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
for v in n.values():
|
||||
visit(v)
|
||||
|
||||
return n
|
||||
|
||||
return visit(schema)
|
||||
|
||||
def _generate_union_rule(self, name, alt_schemas):
|
||||
return " | ".join(
|
||||
(
|
||||
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||
for i, alt_schema in enumerate(alt_schemas)
|
||||
)
|
||||
)
|
||||
|
||||
def _visit_pattern(self, pattern, name):
|
||||
"""
|
||||
Transforms a regular expression pattern into a GBNF rule.
|
||||
|
||||
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
|
||||
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||
|
||||
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||
we define sub-rules to keep the output lean.
|
||||
"""
|
||||
|
||||
assert pattern.startswith("^") and pattern.endswith(
|
||||
"$"
|
||||
), 'Pattern must start with "^" and end with "$"'
|
||||
pattern = pattern[1:-1]
|
||||
sub_rule_ids = {}
|
||||
|
||||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: Tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return '"' + txt + '"' if is_literal else txt
|
||||
|
||||
def transform() -> Tuple[str, bool]:
|
||||
"""
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
"""
|
||||
nonlocal i
|
||||
nonlocal pattern
|
||||
nonlocal sub_rule_ids
|
||||
|
||||
start = i
|
||||
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[Tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
rule = DOTALL
|
||||
else:
|
||||
# Accept any character... except \n and \r line break chars (\x0A and \x0D)
|
||||
rule = DOT
|
||||
return self._add_rule(f"dot", rule)
|
||||
|
||||
def join_seq():
|
||||
nonlocal seq
|
||||
ret = []
|
||||
for is_literal, g in groupby(seq, lambda x: x[1]):
|
||||
if is_literal:
|
||||
ret.append(("".join(x[0] for x in g), True))
|
||||
else:
|
||||
ret.extend(g)
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
return (" ".join(to_rule(x) for x in seq), False)
|
||||
|
||||
while i < length:
|
||||
c = pattern[i]
|
||||
if c == ".":
|
||||
seq.append((get_dot(), False))
|
||||
i += 1
|
||||
elif c == "(":
|
||||
i += 1
|
||||
if i < length:
|
||||
assert (
|
||||
pattern[i] != "?"
|
||||
), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||
seq.append((f"({to_rule(transform())})", False))
|
||||
elif c == ")":
|
||||
i += 1
|
||||
assert (
|
||||
start > 0 and pattern[start - 1] == "("
|
||||
), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}"
|
||||
return join_seq()
|
||||
elif c == "[":
|
||||
square_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != "]":
|
||||
if pattern[i] == "\\":
|
||||
square_brackets += pattern[i : i + 2]
|
||||
i += 2
|
||||
else:
|
||||
square_brackets += pattern[i]
|
||||
i += 1
|
||||
assert (
|
||||
i < length
|
||||
), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}"
|
||||
square_brackets += "]"
|
||||
i += 1
|
||||
seq.append((square_brackets, False))
|
||||
elif c == "|":
|
||||
seq.append(("|", False))
|
||||
i += 1
|
||||
elif c in ("*", "+", "?"):
|
||||
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||
i += 1
|
||||
elif c == "{":
|
||||
curly_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != "}":
|
||||
curly_brackets += pattern[i]
|
||||
i += 1
|
||||
assert (
|
||||
i < length
|
||||
), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}"
|
||||
curly_brackets += "}"
|
||||
i += 1
|
||||
nums = [s.strip() for s in curly_brackets[1:-1].split(",")]
|
||||
min_times = 0
|
||||
max_times = None
|
||||
try:
|
||||
if len(nums) == 1:
|
||||
min_times = int(nums[0])
|
||||
max_times = min_times
|
||||
else:
|
||||
assert len(nums) == 2
|
||||
min_times = int(nums[0]) if nums[0] else 0
|
||||
max_times = int(nums[1]) if nums[1] else None
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid quantifier {curly_brackets} in /{pattern}/"
|
||||
)
|
||||
|
||||
(sub, sub_is_literal) = seq[-1]
|
||||
|
||||
if not sub_is_literal:
|
||||
id = sub_rule_ids.get(sub)
|
||||
if id is None:
|
||||
id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub)
|
||||
sub_rule_ids[sub] = id
|
||||
sub = id
|
||||
|
||||
seq[-1] = (
|
||||
_build_repetition(
|
||||
f'"{sub}"' if sub_is_literal else sub,
|
||||
min_times,
|
||||
max_times,
|
||||
item_rule_is_literal=sub_is_literal,
|
||||
),
|
||||
False,
|
||||
)
|
||||
else:
|
||||
literal = ""
|
||||
while i < length:
|
||||
if pattern[i] == "\\" and i < length - 1:
|
||||
next = pattern[i + 1]
|
||||
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||
i += 1
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
literal += pattern[i : i + 2]
|
||||
i += 2
|
||||
elif pattern[i] == '"' and not self._raw_pattern:
|
||||
literal += '\\"'
|
||||
i += 1
|
||||
elif pattern[i] not in NON_LITERAL_SET and (
|
||||
i == length - 1
|
||||
or literal == ""
|
||||
or pattern[i + 1] == "."
|
||||
or pattern[i + 1] not in NON_LITERAL_SET
|
||||
):
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
if literal:
|
||||
seq.append((literal, True))
|
||||
|
||||
return join_seq()
|
||||
|
||||
return self._add_rule(
|
||||
name,
|
||||
(
|
||||
to_rule(transform())
|
||||
if self._raw_pattern
|
||||
else '"\\"" ' + to_rule(transform()) + ' "\\"" space'
|
||||
),
|
||||
)
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_name = ref.split("/")[-1]
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
ref_name = self.visit(resolved, ref_name)
|
||||
self._refs_being_resolved.remove(ref)
|
||||
return ref_name
|
||||
|
||||
def _generate_constant_rule(self, value):
|
||||
return self._format_literal(json.dumps(value))
|
||||
|
||||
def visit(self, schema, name):
|
||||
schema_type = schema.get("type")
|
||||
schema_format = schema.get("format")
|
||||
rule_name = name + "-" if name in RESERVED_NAMES else name or "root"
|
||||
|
||||
if (ref := schema.get("$ref")) is not None:
|
||||
return self._add_rule(rule_name, self._resolve_ref(ref))
|
||||
|
||||
elif "oneOf" in schema or "anyOf" in schema:
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]),
|
||||
)
|
||||
|
||||
elif isinstance(schema_type, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
self._generate_union_rule(name, [{"type": t} for t in schema_type]),
|
||||
)
|
||||
|
||||
elif "const" in schema:
|
||||
return self._add_rule(
|
||||
rule_name, self._generate_constant_rule(schema["const"])
|
||||
)
|
||||
|
||||
elif "enum" in schema:
|
||||
rule = " | ".join((self._generate_constant_rule(v) for v in schema["enum"]))
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, "object") and (
|
||||
"properties" in schema
|
||||
or (
|
||||
"additionalProperties" in schema
|
||||
and schema["additionalProperties"] is not True
|
||||
)
|
||||
):
|
||||
required = set(schema.get("required", []))
|
||||
properties = list(schema.get("properties", {}).items())
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
self._build_object_rule(
|
||||
properties, required, name, schema.get("additionalProperties")
|
||||
),
|
||||
)
|
||||
|
||||
elif schema_type in (None, "object") and "allOf" in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
hybrid_name = name
|
||||
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get("$ref")) is not None:
|
||||
comp_schema = self._refs[ref]
|
||||
|
||||
if "properties" in comp_schema:
|
||||
for prop_name, prop_schema in comp_schema["properties"].items():
|
||||
properties.append((prop_name, prop_schema))
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
for t in schema["allOf"]:
|
||||
if "anyOf" in t:
|
||||
for tt in t["anyOf"]:
|
||||
add_component(tt, is_required=False)
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
self._build_object_rule(
|
||||
properties, required, hybrid_name, additional_properties=[]
|
||||
),
|
||||
)
|
||||
|
||||
elif schema_type in (None, "array") and (
|
||||
"items" in schema or "prefixItems" in schema
|
||||
):
|
||||
items = schema.get("items") or schema["prefixItems"]
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
'"[" space '
|
||||
+ ' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)
|
||||
)
|
||||
+ ' "]" space',
|
||||
)
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
'"[" space '
|
||||
+ _build_repetition(
|
||||
item_rule_name, min_items, max_items, separator_rule='"," space'
|
||||
)
|
||||
+ ' "]" space',
|
||||
)
|
||||
|
||||
elif schema_type in (None, "string") and "pattern" in schema:
|
||||
return self._visit_pattern(schema["pattern"], rule_name)
|
||||
|
||||
elif schema_type in (None, "string") and re.match(
|
||||
r"^uuid[1-5]?$", schema_format or ""
|
||||
):
|
||||
return self._add_primitive(
|
||||
"root" if rule_name == "root" else schema_format,
|
||||
PRIMITIVE_RULES["uuid"],
|
||||
)
|
||||
|
||||
elif (
|
||||
schema_type in (None, "string")
|
||||
and f"{schema_format}-string" in STRING_FORMAT_RULES
|
||||
):
|
||||
prim_name = f"{schema_format}-string"
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]),
|
||||
)
|
||||
|
||||
elif schema_type == "string" and (
|
||||
"minLength" in schema or "maxLength" in schema
|
||||
):
|
||||
char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"])
|
||||
min_len = schema.get("minLength", 0)
|
||||
max_len = schema.get("maxLength")
|
||||
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
r'"\"" '
|
||||
+ _build_repetition(char_rule, min_len, max_len)
|
||||
+ r' "\"" space',
|
||||
)
|
||||
|
||||
elif (schema_type == "object") or (len(schema) == 0):
|
||||
return self._add_rule(
|
||||
rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"])
|
||||
)
|
||||
|
||||
else:
|
||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
||||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return self._add_primitive(
|
||||
"root" if rule_name == "root" else schema_type,
|
||||
PRIMITIVE_RULES[schema_type],
|
||||
)
|
||||
|
||||
def _add_primitive(self, name: str, rule: BuiltinRule):
|
||||
n = self._add_rule(name, rule.content)
|
||||
|
||||
for dep in rule.deps:
|
||||
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
|
||||
assert dep_rule, f"Rule {dep} not known"
|
||||
if dep not in self._rules:
|
||||
self._add_primitive(dep, dep_rule)
|
||||
return n
|
||||
|
||||
def _build_object_rule(
|
||||
self,
|
||||
properties: List[Tuple[str, Any]],
|
||||
required: Set[str],
|
||||
name: str,
|
||||
additional_properties: Union[bool, Any],
|
||||
):
|
||||
prop_order = self._prop_order
|
||||
# sort by position in prop_order (if specified) then by original order
|
||||
sorted_props = [
|
||||
kv[0]
|
||||
for _, kv in sorted(
|
||||
enumerate(properties),
|
||||
key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]),
|
||||
)
|
||||
]
|
||||
|
||||
prop_kv_rule_names = {}
|
||||
for prop_name, prop_schema in properties:
|
||||
prop_rule_name = self.visit(
|
||||
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
|
||||
)
|
||||
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||
rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}',
|
||||
)
|
||||
required_props = [k for k in sorted_props if k in required]
|
||||
optional_props = [k for k in sorted_props if k not in required]
|
||||
|
||||
if additional_properties == True or isinstance(additional_properties, dict):
|
||||
sub_name = f'{name}{"-" if name else ""}additional'
|
||||
value_rule = self.visit(
|
||||
{} if additional_properties == True else additional_properties,
|
||||
f"{sub_name}-value",
|
||||
)
|
||||
prop_kv_rule_names["*"] = self._add_rule(
|
||||
f"{sub_name}-kv",
|
||||
self._add_primitive("string", PRIMITIVE_RULES["string"])
|
||||
+ f' ":" space {value_rule}',
|
||||
)
|
||||
optional_props.append("*")
|
||||
|
||||
rule = '"{" space '
|
||||
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
|
||||
|
||||
if optional_props:
|
||||
rule += " ("
|
||||
if required_props:
|
||||
rule += ' "," space ( '
|
||||
|
||||
def get_recursive_refs(ks, first_is_optional):
|
||||
[k, *rest] = ks
|
||||
kv_rule_name = prop_kv_rule_names[k]
|
||||
if k == "*":
|
||||
res = self._add_rule(
|
||||
f'{name}{"-" if name else ""}additional-kvs',
|
||||
f'{kv_rule_name} ( "," space ' + kv_rule_name + " )*",
|
||||
)
|
||||
elif first_is_optional:
|
||||
res = f'( "," space {kv_rule_name} )?'
|
||||
else:
|
||||
res = kv_rule_name
|
||||
if len(rest) > 0:
|
||||
res += " " + self._add_rule(
|
||||
f'{name}{"-" if name else ""}{k}-rest',
|
||||
get_recursive_refs(rest, first_is_optional=True),
|
||||
)
|
||||
return res
|
||||
|
||||
rule += " | ".join(
|
||||
get_recursive_refs(optional_props[i:], first_is_optional=False)
|
||||
for i in range(len(optional_props))
|
||||
)
|
||||
if required_props:
|
||||
rule += " )"
|
||||
rule += " )?"
|
||||
|
||||
rule += ' "}" space'
|
||||
|
||||
return rule
|
||||
|
||||
def format_grammar(self):
|
||||
return "\n".join(
|
||||
f"{name} ::= {rule}"
|
||||
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
|
||||
)
|
||||
|
||||
|
||||
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
|
||||
prop_order = prop_order or []
|
||||
schema = json.loads(schema)
|
||||
prop_order = {name: idx for idx, name in enumerate(prop_order)}
|
||||
converter = SchemaConverter(
|
||||
prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False
|
||||
)
|
||||
schema = converter.resolve_refs(schema, "stdin")
|
||||
converter.visit(schema, "")
|
||||
return converter.format_grammar()
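
A short usage sketch for the converter, assuming this module is importable as `llama_cpp.llama_grammar`; the schema itself is only an illustrative example:

```python
import json
from llama_cpp.llama_grammar import json_schema_to_gbnf  # assumed module path

schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name"],
    }
)
# prop_order pins "name" before "age" in the generated object rule
print(json_schema_to_gbnf(schema, prop_order=["name", "age"]))
```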
|
64
venv/Lib/site-packages/llama_cpp/llama_speculative.py
Normal file
@@ -0,0 +1,64 @@
import abc

from typing import Any

import numpy as np
import numpy.typing as npt


class LlamaDraftModel(abc.ABC):
    @abc.abstractmethod
    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        raise NotImplementedError()


class LlamaPromptLookupDecoding(LlamaDraftModel):
    """Based on https://github.com/apoorvumang/prompt-lookup-decoding"""

    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
        self.max_ngram_size = max_ngram_size
        self.num_pred_tokens = num_pred_tokens

    @staticmethod
    def find_candidate_pred_tokens(
        input_ids: npt.NDArray[np.intc],
        max_ngram_size: int,
        num_pred_tokens: int,
    ):
        input_length = input_ids.shape[0]

        for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))

            # Convert ngram to an array for comparison
            ngram_array = input_ids[-ngram_size:]

            # Find where the windows match the ngram
            matches = np.all(windows == ngram_array, axis=1)

            # Get the indices of matches
            match_indices = np.nonzero(matches)[0]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + num_pred_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    return input_ids[start_idx:end_idx]

        # If no match is found, return an empty array
        return np.array([], dtype=np.intc)

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        return self.find_candidate_pred_tokens(
            input_ids=input_ids,
            max_ngram_size=self.max_ngram_size,
            num_pred_tokens=self.num_pred_tokens,
        )
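
A minimal sketch of how this draft model is intended to be wired in, assuming `llama_cpp.Llama` accepts a `draft_model=` argument as in recent releases; the model path is a hypothetical placeholder:

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

llm = Llama(
    model_path="models/model.gguf",  # hypothetical path
    draft_model=LlamaPromptLookupDecoding(max_ngram_size=2, num_pred_tokens=10),
)
# Prompt-lookup drafting helps most when the completion re-uses spans of the prompt,
# e.g. extraction or summarisation with verbatim quotes.
out = llm("Quote the key sentence from the passage above.", max_tokens=64)
```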
|
120
venv/Lib/site-packages/llama_cpp/llama_tokenizer.py
Normal file
@@ -0,0 +1,120 @@
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Any,
|
||||
)
|
||||
|
||||
import llama_cpp
|
||||
from llama_cpp.llama_types import List
|
||||
|
||||
|
||||
class BaseLlamaTokenizer(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def tokenize(
|
||||
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||
) -> List[int]:
|
||||
"""Tokenize the text into tokens.
|
||||
|
||||
Args:
|
||||
text: The utf-8 encoded string to tokenize.
|
||||
add_bos: Whether to add a beginning of sequence token.
|
||||
special: Whether to tokenize special tokens.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def detokenize(
|
||||
self,
|
||||
tokens: List[int],
|
||||
prev_tokens: Optional[List[int]] = None,
|
||||
special: bool = False,
|
||||
) -> bytes:
|
||||
"""Detokenize the tokens into text.
|
||||
|
||||
Args:
|
||||
tokens: The list of tokens to detokenize.
|
||||
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided.
|
||||
special: Whether to detokenize special tokens.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LlamaTokenizer(BaseLlamaTokenizer):
|
||||
def __init__(self, llama: llama_cpp.Llama):
|
||||
self._model = llama._model # type: ignore
|
||||
|
||||
def tokenize(
|
||||
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||
) -> List[int]:
|
||||
return self._model.tokenize(text, add_bos=add_bos, special=special)
|
||||
|
||||
def detokenize(
|
||||
self,
|
||||
tokens: List[int],
|
||||
prev_tokens: Optional[List[int]] = None,
|
||||
special: bool = False,
|
||||
) -> bytes:
|
||||
return self._model.detokenize(tokens, special=special)
|
||||
|
||||
def encode(
|
||||
self, text: str, add_bos: bool = True, special: bool = True
|
||||
) -> List[int]:
|
||||
return self.tokenize(
|
||||
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
|
||||
)
|
||||
|
||||
def decode(self, tokens: List[int]) -> str:
|
||||
return self.detokenize(tokens).decode("utf-8", errors="ignore")
|
||||
|
||||
@classmethod
|
||||
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
||||
return cls(llama_cpp.Llama(model_path=path, vocab_only=True))
|
||||
|
||||
|
||||
class LlamaHFTokenizer(BaseLlamaTokenizer):
|
||||
def __init__(self, hf_tokenizer: Any):
|
||||
self.hf_tokenizer = hf_tokenizer
|
||||
|
||||
def tokenize(
|
||||
self, text: bytes, add_bos: bool = True, special: bool = True
|
||||
) -> List[int]:
|
||||
return self.hf_tokenizer.encode(
|
||||
text.decode("utf-8", errors="ignore"), add_special_tokens=special
|
||||
)
|
||||
|
||||
def detokenize(
|
||||
self,
|
||||
tokens: List[int],
|
||||
prev_tokens: Optional[List[int]] = None,
|
||||
special: bool = False,
|
||||
) -> bytes:
|
||||
skip_special_tokens = not special
|
||||
if prev_tokens is not None:
|
||||
text = self.hf_tokenizer.decode(
|
||||
prev_tokens + tokens, skip_special_tokens=skip_special_tokens
|
||||
).encode("utf-8", errors="ignore")
|
||||
prev_text = self.hf_tokenizer.decode(
|
||||
prev_tokens, skip_special_tokens=skip_special_tokens
|
||||
).encode("utf-8", errors="ignore")
|
||||
return text[len(prev_text) :]
|
||||
else:
|
||||
return self.hf_tokenizer.decode(
|
||||
tokens, skip_special_tokens=skip_special_tokens
|
||||
).encode("utf-8", errors="ignore")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The `transformers` library is required to use the `HFTokenizer`."
|
||||
"You can install it with `pip install transformers`."
|
||||
)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path
|
||||
)
|
||||
return cls(hf_tokenizer)
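
A usage sketch, assuming the `transformers` package is installed and that `llama_cpp.Llama` accepts a `tokenizer=` argument; the tokenizer id and model path below are hypothetical placeholders:

```python
from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

tokenizer = LlamaHFTokenizer.from_pretrained("org/some-hf-tokenizer")  # hypothetical id
llm = Llama(model_path="models/model.gguf", tokenizer=tokenizer)       # hypothetical path

ids = tokenizer.tokenize(b"Hello, world!")
text = tokenizer.detokenize(ids).decode("utf-8", errors="ignore")
```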
|
316
venv/Lib/site-packages/llama_cpp/llama_types.py
Normal file
@@ -0,0 +1,316 @@
"""Types and request signatures for OpenAI compatibility
|
||||
|
||||
NOTE: These types may change to match the OpenAI OpenAPI specification.
|
||||
|
||||
Based on the OpenAI OpenAPI specification:
|
||||
https://github.com/openai/openai-openapi/blob/master/openapi.yaml
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Dict, Union
|
||||
from typing_extensions import TypedDict, NotRequired, Literal
|
||||
|
||||
|
||||
# NOTE: Defining this correctly using annotations seems to break pydantic validation.
|
||||
# This is a workaround until we can figure out how to do this correctly
|
||||
# JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
|
||||
JsonType = Union[None, int, str, bool, List[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
class EmbeddingUsage(TypedDict):
|
||||
prompt_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class Embedding(TypedDict):
|
||||
index: int
|
||||
object: str
|
||||
embedding: Union[List[float], List[List[float]]]
|
||||
|
||||
|
||||
class CreateEmbeddingResponse(TypedDict):
|
||||
object: Literal["list"]
|
||||
model: str
|
||||
data: List[Embedding]
|
||||
usage: EmbeddingUsage
|
||||
|
||||
|
||||
class CompletionLogprobs(TypedDict):
|
||||
text_offset: List[int]
|
||||
token_logprobs: List[Optional[float]]
|
||||
tokens: List[str]
|
||||
top_logprobs: List[Optional[Dict[str, float]]]
|
||||
|
||||
|
||||
class CompletionChoice(TypedDict):
|
||||
text: str
|
||||
index: int
|
||||
logprobs: Optional[CompletionLogprobs]
|
||||
finish_reason: Optional[Literal["stop", "length"]]
|
||||
|
||||
|
||||
class CompletionUsage(TypedDict):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class CreateCompletionResponse(TypedDict):
|
||||
id: str
|
||||
object: Literal["text_completion"]
|
||||
created: int
|
||||
model: str
|
||||
choices: List[CompletionChoice]
|
||||
usage: NotRequired[CompletionUsage]
|
||||
|
||||
|
||||
class ChatCompletionResponseFunctionCall(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionResponseMessage(TypedDict):
|
||||
content: Optional[str]
|
||||
tool_calls: NotRequired["ChatCompletionMessageToolCalls"]
|
||||
role: Literal["assistant", "function"] # NOTE: "function" may be incorrect here
|
||||
function_call: NotRequired[ChatCompletionResponseFunctionCall] # DEPRECATED
|
||||
|
||||
|
||||
class ChatCompletionFunction(TypedDict):
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
parameters: Dict[str, JsonType] # TODO: make this more specific
|
||||
|
||||
|
||||
class ChatCompletionTopLogprobToken(TypedDict):
|
||||
token: str
|
||||
logprob: float
|
||||
bytes: Optional[List[int]]
|
||||
|
||||
|
||||
class ChatCompletionLogprobToken(ChatCompletionTopLogprobToken):
|
||||
token: str
|
||||
logprob: float
|
||||
bytes: Optional[List[int]]
|
||||
top_logprobs: List[ChatCompletionTopLogprobToken]
|
||||
|
||||
|
||||
class ChatCompletionLogprobs(TypedDict):
|
||||
content: Optional[List[ChatCompletionLogprobToken]]
|
||||
refusal: Optional[List[ChatCompletionLogprobToken]]
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(TypedDict):
|
||||
index: int
|
||||
message: "ChatCompletionResponseMessage"
|
||||
logprobs: Optional[ChatCompletionLogprobs]
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
class CreateChatCompletionResponse(TypedDict):
|
||||
id: str
|
||||
object: Literal["chat.completion"]
|
||||
created: int
|
||||
model: str
|
||||
choices: List["ChatCompletionResponseChoice"]
|
||||
usage: CompletionUsage
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCallChunkFunction(TypedDict):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCallChunk(TypedDict):
|
||||
index: int
|
||||
id: NotRequired[str]
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionMessageToolCallChunkFunction
|
||||
|
||||
|
||||
class ChatCompletionStreamResponseDeltaEmpty(TypedDict):
|
||||
pass
|
||||
|
||||
|
||||
class ChatCompletionStreamResponseDeltaFunctionCall(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionStreamResponseDelta(TypedDict):
|
||||
content: NotRequired[Optional[str]]
|
||||
function_call: NotRequired[
|
||||
Optional[ChatCompletionStreamResponseDeltaFunctionCall]
|
||||
] # DEPRECATED
|
||||
tool_calls: NotRequired[Optional[List[ChatCompletionMessageToolCallChunk]]]
|
||||
role: NotRequired[Optional[Literal["system", "user", "assistant", "tool"]]]
|
||||
|
||||
|
||||
class ChatCompletionStreamResponseChoice(TypedDict):
|
||||
index: int
|
||||
delta: Union[
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionStreamResponseDeltaEmpty
|
||||
]
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls", "function_call"]]
|
||||
logprobs: NotRequired[Optional[ChatCompletionLogprobs]]
|
||||
|
||||
|
||||
class CreateChatCompletionStreamResponse(TypedDict):
|
||||
id: str
|
||||
model: str
|
||||
object: Literal["chat.completion.chunk"]
|
||||
created: int
|
||||
choices: List[ChatCompletionStreamResponseChoice]
|
||||
|
||||
|
||||
class ChatCompletionFunctions(TypedDict):
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
parameters: Dict[str, JsonType] # TODO: make this more specific
|
||||
|
||||
|
||||
class ChatCompletionFunctionCallOption(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
class ChatCompletionRequestResponseFormat(TypedDict):
|
||||
type: Literal["text", "json_object"]
|
||||
schema: NotRequired[
|
||||
JsonType
|
||||
] # https://docs.endpoints.anyscale.com/guides/json_mode/
|
||||
|
||||
|
||||
class ChatCompletionRequestMessageContentPartText(TypedDict):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ChatCompletionRequestMessageContentPartImageImageUrl(TypedDict):
|
||||
url: str
|
||||
detail: NotRequired[Literal["auto", "low", "high"]]
|
||||
|
||||
|
||||
class ChatCompletionRequestMessageContentPartImage(TypedDict):
|
||||
type: Literal["image_url"]
|
||||
image_url: Union[str, ChatCompletionRequestMessageContentPartImageImageUrl]
|
||||
|
||||
|
||||
ChatCompletionRequestMessageContentPart = Union[
|
||||
ChatCompletionRequestMessageContentPartText,
|
||||
ChatCompletionRequestMessageContentPartImage,
|
||||
]
|
||||
|
||||
|
||||
class ChatCompletionRequestSystemMessage(TypedDict):
|
||||
role: Literal["system"]
|
||||
content: Optional[str]
|
||||
|
||||
|
||||
class ChatCompletionRequestUserMessage(TypedDict):
|
||||
role: Literal["user"]
|
||||
content: Optional[Union[str, List[ChatCompletionRequestMessageContentPart]]]
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCallFunction(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionMessageToolCall(TypedDict):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionMessageToolCallFunction
|
||||
|
||||
|
||||
ChatCompletionMessageToolCalls = List[ChatCompletionMessageToolCall]
|
||||
|
||||
|
||||
class ChatCompletionRequestAssistantMessageFunctionCall(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionRequestAssistantMessage(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
content: NotRequired[str]
|
||||
tool_calls: NotRequired[ChatCompletionMessageToolCalls]
|
||||
function_call: NotRequired[
|
||||
ChatCompletionRequestAssistantMessageFunctionCall
|
||||
] # DEPRECATED
|
||||
|
||||
|
||||
class ChatCompletionRequestToolMessage(TypedDict):
|
||||
role: Literal["tool"]
|
||||
content: Optional[str]
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class ChatCompletionRequestFunctionMessage(TypedDict):
|
||||
role: Literal["function"]
|
||||
content: Optional[str]
|
||||
name: str
|
||||
|
||||
|
||||
ChatCompletionRequestMessage = Union[
|
||||
ChatCompletionRequestSystemMessage,
|
||||
ChatCompletionRequestUserMessage,
|
||||
ChatCompletionRequestAssistantMessage,
|
||||
ChatCompletionRequestUserMessage,
|
||||
ChatCompletionRequestToolMessage,
|
||||
ChatCompletionRequestFunctionMessage,
|
||||
]
|
||||
|
||||
|
||||
class ChatCompletionRequestFunctionCallOption(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
ChatCompletionRequestFunctionCall = Union[
|
||||
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
|
||||
]
|
||||
|
||||
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
|
||||
|
||||
|
||||
class ChatCompletionToolFunction(TypedDict):
|
||||
name: str
|
||||
description: NotRequired[str]
|
||||
parameters: ChatCompletionFunctionParameters
|
||||
|
||||
|
||||
class ChatCompletionTool(TypedDict):
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionToolFunction
|
||||
|
||||
|
||||
class ChatCompletionNamedToolChoiceFunction(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
class ChatCompletionNamedToolChoice(TypedDict):
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionNamedToolChoiceFunction
|
||||
|
||||
|
||||
ChatCompletionToolChoiceOption = Union[
|
||||
Literal["none", "auto", "required"], ChatCompletionNamedToolChoice
|
||||
]
|
||||
|
||||
|
||||
# NOTE: The following type names are not part of the OpenAI OpenAPI specification
|
||||
# and will be removed in a future major release.
|
||||
|
||||
EmbeddingData = Embedding
|
||||
CompletionChunk = CreateCompletionResponse
|
||||
Completion = CreateCompletionResponse
|
||||
CreateCompletionStreamResponse = CreateCompletionResponse
|
||||
ChatCompletionMessage = ChatCompletionResponseMessage
|
||||
ChatCompletionChoice = ChatCompletionResponseChoice
|
||||
ChatCompletion = CreateChatCompletionResponse
|
||||
ChatCompletionChunkDeltaEmpty = ChatCompletionStreamResponseDeltaEmpty
|
||||
ChatCompletionChunkChoice = ChatCompletionStreamResponseChoice
|
||||
ChatCompletionChunkDelta = ChatCompletionStreamResponseDelta
|
||||
ChatCompletionChunk = CreateChatCompletionStreamResponse
|
||||
ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
|
||||
ChatCompletionResponseFunction = ChatCompletionFunction
|
||||
ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
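
Since these are plain `TypedDict`s, they can be constructed directly as dictionaries while still benefiting from static type checking; a small illustrative example:

```python
from typing import List

from llama_cpp.llama_types import (
    ChatCompletionRequestMessage,
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestUserMessage,
)

messages: List[ChatCompletionRequestMessage] = [
    ChatCompletionRequestSystemMessage(role="system", content="You are a helpful assistant."),
    ChatCompletionRequestUserMessage(role="user", content="What is GBNF?"),
]
```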
|
158
venv/Lib/site-packages/llama_cpp/llava_cpp.py
Normal file
@@ -0,0 +1,158 @@
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from ctypes import (
|
||||
c_bool,
|
||||
c_char_p,
|
||||
c_int,
|
||||
c_uint8,
|
||||
c_float,
|
||||
c_void_p,
|
||||
POINTER,
|
||||
_Pointer, # type: ignore
|
||||
Structure,
|
||||
)
|
||||
import pathlib
|
||||
from typing import (
|
||||
Union,
|
||||
NewType,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
|
||||
from llama_cpp._ctypes_extensions import (
|
||||
load_shared_library,
|
||||
ctypes_function_for_shared_library,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_cpp._ctypes_extensions import (
|
||||
CtypesArray,
|
||||
)
|
||||
|
||||
|
||||
# Specify the base name of the shared library to load
|
||||
_libllava_base_name = "llava"
|
||||
_libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
|
||||
_libllava_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libllava_override_path is None else pathlib.Path()
|
||||
|
||||
# Load the library
|
||||
_libllava = load_shared_library(_libllava_base_name, _libllava_base_path)
|
||||
|
||||
ctypes_function = ctypes_function_for_shared_library(_libllava)
|
||||
|
||||
|
||||
################################################
|
||||
# llava.h
|
||||
################################################
|
||||
|
||||
# struct clip_ctx;
|
||||
clip_ctx_p = NewType("clip_ctx_p", int)
|
||||
clip_ctx_p_ctypes = c_void_p
|
||||
|
||||
|
||||
# struct llava_image_embed {
|
||||
# float * embed;
|
||||
# int n_image_pos;
|
||||
# };
|
||||
class llava_image_embed(Structure):
|
||||
_fields_ = [
|
||||
("embed", POINTER(c_float)),
|
||||
("n_image_pos", c_int),
|
||||
]
|
||||
|
||||
|
||||
# /** sanity check for clip <-> llava embed size match */
|
||||
# LLAVA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
||||
@ctypes_function(
|
||||
"llava_validate_embed_size",
|
||||
[llama_cpp.llama_context_p_ctypes, clip_ctx_p_ctypes],
|
||||
c_bool,
|
||||
)
|
||||
def llava_validate_embed_size(
|
||||
ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
# /** build an image embed from image file bytes */
|
||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
||||
@ctypes_function(
|
||||
"llava_image_embed_make_with_bytes",
|
||||
[clip_ctx_p_ctypes, c_int, POINTER(c_uint8), c_int],
|
||||
POINTER(llava_image_embed),
|
||||
)
|
||||
def llava_image_embed_make_with_bytes(
|
||||
ctx_clip: clip_ctx_p,
|
||||
n_threads: Union[c_int, int],
|
||||
image_bytes: CtypesArray[c_uint8],
|
||||
image_bytes_length: Union[c_int, int],
|
||||
/,
|
||||
) -> "_Pointer[llava_image_embed]":
|
||||
...
|
||||
|
||||
|
||||
# /** build an image embed from a path to an image filename */
|
||||
# LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
||||
@ctypes_function(
|
||||
"llava_image_embed_make_with_filename",
|
||||
[clip_ctx_p_ctypes, c_int, c_char_p],
|
||||
POINTER(llava_image_embed),
|
||||
)
|
||||
def llava_image_embed_make_with_filename(
|
||||
ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
|
||||
) -> "_Pointer[llava_image_embed]":
|
||||
...
|
||||
|
||||
|
||||
# LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
||||
# /** free an embedding made with llava_image_embed_make_* */
|
||||
@ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
|
||||
def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
|
||||
...
|
||||
|
||||
|
||||
# /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
||||
# LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
||||
@ctypes_function(
|
||||
"llava_eval_image_embed",
|
||||
[
|
||||
llama_cpp.llama_context_p_ctypes,
|
||||
POINTER(llava_image_embed),
|
||||
c_int,
|
||||
POINTER(c_int),
|
||||
],
|
||||
c_bool,
|
||||
)
|
||||
def llava_eval_image_embed(
|
||||
ctx_llama: llama_cpp.llama_context_p,
|
||||
embed: "_Pointer[llava_image_embed]",
|
||||
n_batch: Union[c_int, int],
|
||||
n_past: "_Pointer[c_int]",
|
||||
/,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
################################################
|
||||
# clip.h
|
||||
################################################
|
||||
|
||||
|
||||
# /** load mmproj model */
|
||||
# CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
|
||||
@ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
|
||||
def clip_model_load(
|
||||
fname: bytes, verbosity: Union[c_int, int], /
|
||||
) -> Optional[clip_ctx_p]:
|
||||
...
|
||||
|
||||
|
||||
# /** free mmproj model */
|
||||
# CLIP_API void clip_free(struct clip_ctx * ctx);
|
||||
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
|
||||
def clip_free(ctx: clip_ctx_p, /):
|
||||
...
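
A hedged sketch of the raw ctypes workflow these bindings expose; it assumes an mmproj model file and an image file exist, and both paths are hypothetical placeholders:

```python
import llama_cpp.llava_cpp as llava_cpp

ctx_clip = llava_cpp.clip_model_load(b"models/mmproj.gguf", 1)  # hypothetical path
assert ctx_clip is not None, "failed to load mmproj model"

embed = llava_cpp.llava_image_embed_make_with_filename(ctx_clip, 4, b"image.jpg")
try:
    # number of positions the image occupies in the llama context
    print(embed.contents.n_image_pos)
finally:
    llava_cpp.llava_image_embed_free(embed)
    llava_cpp.clip_free(ctx_clip)
```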
|
||||
|
280
venv/Lib/site-packages/llama_cpp/mtmd_cpp.py
Normal file
@@ -0,0 +1,280 @@
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from ctypes import (
|
||||
c_bool,
|
||||
c_char_p,
|
||||
c_int,
|
||||
c_uint8,
|
||||
c_uint32,
|
||||
c_float,
|
||||
c_void_p,
|
||||
c_size_t,
|
||||
POINTER,
|
||||
_Pointer, # type: ignore
|
||||
Structure,
|
||||
byref,
|
||||
)
|
||||
import pathlib
|
||||
from typing import (
|
||||
Union,
|
||||
NewType,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
|
||||
from llama_cpp._ctypes_extensions import (
|
||||
load_shared_library,
|
||||
ctypes_function_for_shared_library,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_cpp._ctypes_extensions import (
|
||||
CtypesArray,
|
||||
)
|
||||
|
||||
|
||||
# Specify the base name of the shared library to load
|
||||
_libmtmd_base_name = "mtmd"
|
||||
_libmtmd_override_path = os.environ.get("MTMD_CPP_LIB")
|
||||
_libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path()
|
||||
|
||||
# Load the library
|
||||
_libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path)
|
||||
|
||||
ctypes_function = ctypes_function_for_shared_library(_libmtmd)
|
||||
|
||||
################################################
|
||||
# mtmd.h types
|
||||
################################################
|
||||
|
||||
# Opaque types
|
||||
mtmd_context_p = NewType("mtmd_context_p", int)
|
||||
mtmd_context_p_ctypes = c_void_p
|
||||
|
||||
mtmd_bitmap_p = NewType("mtmd_bitmap_p", int)
|
||||
mtmd_bitmap_p_ctypes = c_void_p
|
||||
|
||||
mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int)
|
||||
mtmd_image_tokens_p_ctypes = c_void_p
|
||||
|
||||
mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int)
|
||||
mtmd_input_chunk_p_ctypes = c_void_p
|
||||
|
||||
mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int)
|
||||
mtmd_input_chunks_p_ctypes = c_void_p
|
||||
|
||||
# Enums
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT = 0
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE = 1
|
||||
MTMD_INPUT_CHUNK_TYPE_AUDIO = 2
|
||||
|
||||
# Structures
|
||||
class mtmd_context_params(Structure):
|
||||
_fields_ = [
|
||||
("use_gpu", c_bool),
|
||||
("print_timings", c_bool),
|
||||
("n_threads", c_int),
|
||||
("verbosity", c_int), # ggml_log_level
|
||||
("image_marker", c_char_p),
|
||||
("media_marker", c_char_p),
|
||||
]
|
||||
|
||||
class mtmd_input_text(Structure):
|
||||
_fields_ = [
|
||||
("text", c_char_p),
|
||||
("add_special", c_bool),
|
||||
("parse_special", c_bool),
|
||||
]
|
||||
|
||||
################################################
|
||||
# mtmd.h functions
|
||||
################################################
|
||||
|
||||
# MTMD_API const char * mtmd_default_marker(void);
|
||||
@ctypes_function("mtmd_default_marker", [], c_char_p)
|
||||
def mtmd_default_marker() -> bytes:
|
||||
...
|
||||
|
||||
# MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
|
||||
@ctypes_function("mtmd_context_params_default", [], mtmd_context_params)
|
||||
def mtmd_context_params_default() -> mtmd_context_params:
|
||||
...
|
||||
|
||||
# MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
# const struct llama_model * text_model,
|
||||
# const struct mtmd_context_params ctx_params);
|
||||
@ctypes_function(
|
||||
"mtmd_init_from_file",
|
||||
[c_char_p, llama_cpp.llama_model_p_ctypes, mtmd_context_params],
|
||||
mtmd_context_p_ctypes
|
||||
)
|
||||
def mtmd_init_from_file(
|
||||
mmproj_fname: bytes,
|
||||
text_model: llama_cpp.llama_model_p,
|
||||
ctx_params: mtmd_context_params,
|
||||
/,
|
||||
) -> Optional[mtmd_context_p]:
|
||||
...
|
||||
|
||||
# MTMD_API void mtmd_free(mtmd_context * ctx);
|
||||
@ctypes_function("mtmd_free", [mtmd_context_p_ctypes], None)
|
||||
def mtmd_free(ctx: mtmd_context_p, /):
|
||||
...
|
||||
|
||||
# MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
|
||||
@ctypes_function("mtmd_support_vision", [mtmd_context_p_ctypes], c_bool)
|
||||
def mtmd_support_vision(ctx: mtmd_context_p, /) -> bool:
|
||||
...
|
||||
|
||||
# MTMD_API mtmd_bitmap * mtmd_bitmap_init(uint32_t nx, uint32_t ny, const unsigned char * data);
|
||||
@ctypes_function(
|
||||
"mtmd_bitmap_init",
|
||||
[c_uint32, c_uint32, POINTER(c_uint8)],
|
||||
mtmd_bitmap_p_ctypes
|
||||
)
|
||||
def mtmd_bitmap_init(
|
||||
nx: Union[c_uint32, int],
|
||||
ny: Union[c_uint32, int],
|
||||
data: CtypesArray[c_uint8],
|
||||
/,
|
||||
) -> Optional[mtmd_bitmap_p]:
|
||||
...
|
||||
|
||||
# MTMD_API void mtmd_bitmap_free(mtmd_bitmap * bitmap);
|
||||
@ctypes_function("mtmd_bitmap_free", [mtmd_bitmap_p_ctypes], None)
|
||||
def mtmd_bitmap_free(bitmap: mtmd_bitmap_p, /):
|
||||
...
|
||||
|
||||
# MTMD_API mtmd_input_chunks * mtmd_input_chunks_init(void);
|
||||
@ctypes_function("mtmd_input_chunks_init", [], mtmd_input_chunks_p_ctypes)
|
||||
def mtmd_input_chunks_init() -> Optional[mtmd_input_chunks_p]:
|
||||
...
|
||||
|
||||
# MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
|
||||
@ctypes_function("mtmd_input_chunks_free", [mtmd_input_chunks_p_ctypes], None)
|
||||
def mtmd_input_chunks_free(chunks: mtmd_input_chunks_p, /):
|
||||
...
|
||||
|
||||
# MTMD_API size_t mtmd_input_chunks_size(const mtmd_input_chunks * chunks);
|
||||
@ctypes_function("mtmd_input_chunks_size", [mtmd_input_chunks_p_ctypes], c_size_t)
|
||||
def mtmd_input_chunks_size(chunks: mtmd_input_chunks_p, /) -> int:
|
||||
...
|
||||
|
||||
# MTMD_API const mtmd_input_chunk * mtmd_input_chunks_get(const mtmd_input_chunks * chunks, size_t idx);
|
||||
@ctypes_function(
|
||||
"mtmd_input_chunks_get",
|
||||
[mtmd_input_chunks_p_ctypes, c_size_t],
|
||||
mtmd_input_chunk_p_ctypes
|
||||
)
|
||||
def mtmd_input_chunks_get(
|
||||
chunks: mtmd_input_chunks_p, idx: Union[c_size_t, int], /
|
||||
) -> Optional[mtmd_input_chunk_p]:
|
||||
...
|
||||
|
||||
# MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
# mtmd_input_chunks * output,
|
||||
# const mtmd_input_text * text,
|
||||
# const mtmd_bitmap ** bitmaps,
|
||||
# size_t n_bitmaps);
|
||||
@ctypes_function(
|
||||
"mtmd_tokenize",
|
||||
[
|
||||
mtmd_context_p_ctypes,
|
||||
mtmd_input_chunks_p_ctypes,
|
||||
POINTER(mtmd_input_text),
|
||||
POINTER(mtmd_bitmap_p_ctypes),
|
||||
c_size_t,
|
||||
],
|
||||
c_int,
|
||||
)
|
||||
def mtmd_tokenize(
|
||||
ctx: mtmd_context_p,
|
||||
output: mtmd_input_chunks_p,
|
||||
text: "_Pointer[mtmd_input_text]",
|
||||
bitmaps: CtypesArray[mtmd_bitmap_p_ctypes],
|
||||
n_bitmaps: Union[c_size_t, int],
|
||||
/,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
# MTMD_API size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk);
|
||||
@ctypes_function("mtmd_input_chunk_get_n_tokens", [mtmd_input_chunk_p_ctypes], c_size_t)
|
||||
def mtmd_input_chunk_get_n_tokens(chunk: mtmd_input_chunk_p, /) -> int:
|
||||
...
|
||||
|
||||
# MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type(const mtmd_input_chunk * chunk);
|
||||
@ctypes_function("mtmd_input_chunk_get_type", [mtmd_input_chunk_p_ctypes], c_int)
|
||||
def mtmd_input_chunk_get_type(chunk: mtmd_input_chunk_p, /) -> int:
|
||||
...
|
||||
|
||||
# MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text(const mtmd_input_chunk * chunk, size_t * n_tokens_output);
|
||||
@ctypes_function(
|
||||
"mtmd_input_chunk_get_tokens_text",
|
||||
[mtmd_input_chunk_p_ctypes, POINTER(c_size_t)],
|
||||
POINTER(llama_cpp.llama_token)
|
||||
)
|
||||
def mtmd_input_chunk_get_tokens_text(
|
||||
chunk: mtmd_input_chunk_p, n_tokens_output: "_Pointer[c_size_t]", /
|
||||
) -> Optional["_Pointer[llama_cpp.llama_token]"]:
|
||||
...
|
||||
|
||||
################################################
|
||||
# mtmd-helper.h functions
|
||||
################################################
|
||||
|
||||
# MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
|
||||
@ctypes_function(
|
||||
"mtmd_helper_bitmap_init_from_buf",
|
||||
[mtmd_context_p_ctypes, POINTER(c_uint8), c_size_t],
|
||||
mtmd_bitmap_p_ctypes
|
||||
)
|
||||
def mtmd_helper_bitmap_init_from_buf(
|
||||
ctx: mtmd_context_p,
|
||||
buf: CtypesArray[c_uint8],
|
||||
length: Union[c_size_t, int],
|
||||
/,
|
||||
) -> Optional[mtmd_bitmap_p]:
|
||||
...
|
||||
|
||||
# MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
|
||||
@ctypes_function("mtmd_helper_get_n_tokens", [mtmd_input_chunks_p_ctypes], c_size_t)
|
||||
def mtmd_helper_get_n_tokens(chunks: mtmd_input_chunks_p, /) -> int:
|
||||
...
|
||||
|
||||
# MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
# struct llama_context * lctx,
|
||||
# const mtmd_input_chunk * chunk,
|
||||
# llama_pos n_past,
|
||||
# llama_seq_id seq_id,
|
||||
# int32_t n_batch,
|
||||
# bool logits_last,
|
||||
# llama_pos * new_n_past);
|
||||
@ctypes_function(
|
||||
"mtmd_helper_eval_chunk_single",
|
||||
[
|
||||
mtmd_context_p_ctypes,
|
||||
llama_cpp.llama_context_p_ctypes,
|
||||
mtmd_input_chunk_p_ctypes,
|
||||
llama_cpp.llama_pos,
|
||||
llama_cpp.llama_seq_id,
|
||||
c_int,
|
||||
c_bool,
|
||||
POINTER(llama_cpp.llama_pos),
|
||||
],
|
||||
c_int,
|
||||
)
|
||||
def mtmd_helper_eval_chunk_single(
|
||||
ctx: mtmd_context_p,
|
||||
lctx: llama_cpp.llama_context_p,
|
||||
chunk: mtmd_input_chunk_p,
|
||||
n_past: llama_cpp.llama_pos,
|
||||
seq_id: llama_cpp.llama_seq_id,
|
||||
n_batch: Union[c_int, int],
|
||||
logits_last: Union[c_bool, bool],
|
||||
new_n_past: "_Pointer[llama_cpp.llama_pos]",
|
||||
/,
|
||||
) -> int:
|
||||
...
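
A hedged sketch of tokenizing mixed text and image input through these bindings; the model, mmproj, and image paths are hypothetical placeholders, and the text model is assumed to be loadable with `llama_load_model_from_file` from the low-level API:

```python
import ctypes

import llama_cpp
import llama_cpp.mtmd_cpp as mtmd_cpp

mparams = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(b"models/model.gguf", mparams)  # hypothetical path

params = mtmd_cpp.mtmd_context_params_default()
ctx = mtmd_cpp.mtmd_init_from_file(b"models/mmproj.gguf", model, params)     # hypothetical path

with open("image.jpg", "rb") as f:                                           # hypothetical file
    data = f.read()
bitmap = mtmd_cpp.mtmd_helper_bitmap_init_from_buf(
    ctx, (ctypes.c_uint8 * len(data)).from_buffer_copy(data), len(data)
)

chunks = mtmd_cpp.mtmd_input_chunks_init()
text = mtmd_cpp.mtmd_input_text(
    text=b"Describe this image: " + mtmd_cpp.mtmd_default_marker(),
    add_special=True,
    parse_special=True,
)
bitmaps = (mtmd_cpp.mtmd_bitmap_p_ctypes * 1)(bitmap)
rc = mtmd_cpp.mtmd_tokenize(ctx, chunks, ctypes.byref(text), bitmaps, 1)
print("chunks:", mtmd_cpp.mtmd_input_chunks_size(chunks), "rc:", rc)
```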
|
0
venv/Lib/site-packages/llama_cpp/py.typed
Normal file
0
venv/Lib/site-packages/llama_cpp/server/__init__.py
Normal file
100
venv/Lib/site-packages/llama_cpp/server/__main__.py
Normal file
@@ -0,0 +1,100 @@
"""Example FastAPI server for llama.cpp.
|
||||
|
||||
To run this example:
|
||||
|
||||
```bash
|
||||
pip install fastapi uvicorn sse-starlette pydantic-settings
|
||||
export MODEL=../models/7B/...
|
||||
```
|
||||
|
||||
Then run:
|
||||
```
|
||||
uvicorn llama_cpp.server.app:create_app --reload
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
python3 -m llama_cpp.server
|
||||
```
|
||||
|
||||
Then visit http://localhost:8000/docs to see the interactive API docs.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llama_cpp.server.app import create_app
|
||||
from llama_cpp.server.settings import (
|
||||
Settings,
|
||||
ServerSettings,
|
||||
ModelSettings,
|
||||
ConfigFileSettings,
|
||||
)
|
||||
from llama_cpp.server.cli import add_args_from_model, parse_model_from_args
|
||||
|
||||
|
||||
def main():
|
||||
description = "🦙 Llama.cpp python server. Host your own LLMs!🚀"
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
add_args_from_model(parser, Settings)
|
||||
parser.add_argument(
|
||||
"--config_file",
|
||||
type=str,
|
||||
help="Path to a config file to load.",
|
||||
)
|
||||
server_settings: ServerSettings | None = None
|
||||
model_settings: list[ModelSettings] = []
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
# Load server settings from config_file if provided
|
||||
config_file = os.environ.get("CONFIG_FILE", args.config_file)
|
||||
if config_file:
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Config file {config_file} not found!")
|
||||
with open(config_file, "rb") as f:
|
||||
# Check if yaml file
|
||||
if config_file.endswith(".yaml") or config_file.endswith(".yml"):
|
||||
import yaml
|
||||
import json
|
||||
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||
json.dumps(yaml.safe_load(f))
|
||||
)
|
||||
else:
|
||||
config_file_settings = ConfigFileSettings.model_validate_json(
|
||||
f.read()
|
||||
)
|
||||
server_settings = ServerSettings.model_validate(config_file_settings)
|
||||
model_settings = config_file_settings.models
|
||||
else:
|
||||
server_settings = parse_model_from_args(ServerSettings, args)
|
||||
model_settings = [parse_model_from_args(ModelSettings, args)]
|
||||
except Exception as e:
|
||||
print(e, file=sys.stderr)
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
assert server_settings is not None
|
||||
assert model_settings is not None
|
||||
app = create_app(
|
||||
server_settings=server_settings,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=os.getenv("HOST", server_settings.host),
|
||||
port=int(os.getenv("PORT", server_settings.port)),
|
||||
ssl_keyfile=server_settings.ssl_keyfile,
|
||||
ssl_certfile=server_settings.ssl_certfile,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
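
The same wiring can be done programmatically; a minimal sketch assuming the default `ServerSettings` / `ModelSettings` fields (`host`, `port`, `model`) defined in `llama_cpp.server.settings`, with a hypothetical model path:

```python
import uvicorn

from llama_cpp.server.app import create_app
from llama_cpp.server.settings import ModelSettings, ServerSettings

server_settings = ServerSettings(host="0.0.0.0", port=8000)
model_settings = [ModelSettings(model="models/model.gguf")]  # hypothetical path

app = create_app(server_settings=server_settings, model_settings=model_settings)
uvicorn.run(app, host=server_settings.host, port=int(server_settings.port))
```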
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
597
venv/Lib/site-packages/llama_cpp/server/app.py
Normal file
597
venv/Lib/site-packages/llama_cpp/server/app.py
Normal file
@@ -0,0 +1,597 @@
from __future__ import annotations

import os
import json
import typing
import contextlib

from anyio import Lock
from functools import partial
from typing import List, Optional, Union, Dict

import llama_cpp

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from starlette.concurrency import run_in_threadpool, iterate_in_threadpool
from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer
from sse_starlette.sse import EventSourceResponse
from starlette_context.plugins import RequestIdPlugin  # type: ignore
from starlette_context.middleware import RawContextMiddleware

from llama_cpp.server.model import (
    LlamaProxy,
)
from llama_cpp.server.settings import (
    ConfigFileSettings,
    Settings,
    ModelSettings,
    ServerSettings,
)
from llama_cpp.server.types import (
    CreateCompletionRequest,
    CreateEmbeddingRequest,
    CreateChatCompletionRequest,
    ModelList,
    TokenizeInputRequest,
    TokenizeInputResponse,
    TokenizeInputCountResponse,
    DetokenizeInputRequest,
    DetokenizeInputResponse,
)
from llama_cpp.server.errors import RouteErrorHandler


router = APIRouter(route_class=RouteErrorHandler)

_server_settings: Optional[ServerSettings] = None


def set_server_settings(server_settings: ServerSettings):
    global _server_settings
    _server_settings = server_settings


def get_server_settings():
    yield _server_settings


_llama_proxy: Optional[LlamaProxy] = None

llama_outer_lock = Lock()
llama_inner_lock = Lock()


def set_llama_proxy(model_settings: List[ModelSettings]):
    global _llama_proxy
    _llama_proxy = LlamaProxy(models=model_settings)


async def get_llama_proxy():
    # NOTE: This double lock allows the currently streaming llama model to
    # check if any other requests are pending in the same thread and cancel
    # the stream if so.
    await llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        await llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()


_ping_message_factory: typing.Optional[typing.Callable[[], bytes]] = None


def set_ping_message_factory(factory: typing.Callable[[], bytes]):
    global _ping_message_factory
    _ping_message_factory = factory


def create_app(
    settings: Settings | None = None,
    server_settings: ServerSettings | None = None,
    model_settings: List[ModelSettings] | None = None,
):
    config_file = os.environ.get("CONFIG_FILE", None)
    if config_file is not None:
        if not os.path.exists(config_file):
            raise ValueError(f"Config file {config_file} not found!")
        with open(config_file, "rb") as f:
            # Check if yaml file
            if config_file.endswith(".yaml") or config_file.endswith(".yml"):
                import yaml

                config_file_settings = ConfigFileSettings.model_validate_json(
                    json.dumps(yaml.safe_load(f))
                )
            else:
                config_file_settings = ConfigFileSettings.model_validate_json(f.read())
            server_settings = ServerSettings.model_validate(config_file_settings)
            model_settings = config_file_settings.models

    if server_settings is None and model_settings is None:
        if settings is None:
            settings = Settings()
        server_settings = ServerSettings.model_validate(settings)
        model_settings = [ModelSettings.model_validate(settings)]

    assert (
        server_settings is not None and model_settings is not None
    ), "server_settings and model_settings must be provided together"

    set_server_settings(server_settings)
    middleware = [Middleware(RawContextMiddleware, plugins=(RequestIdPlugin(),))]
    app = FastAPI(
        middleware=middleware,
        title="🦙 llama.cpp Python API",
        version=llama_cpp.__version__,
        root_path=server_settings.root_path,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(router)

    assert model_settings is not None
    set_llama_proxy(model_settings=model_settings)

    if server_settings.disable_ping_events:
        set_ping_message_factory(lambda: bytes())

    return app


def prepare_request_resources(
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    llama_proxy: LlamaProxy,
    body_model: str | None,
    kwargs,
) -> llama_cpp.Llama:
    if llama_proxy is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service is not available",
        )
    llama = llama_proxy(body_model)
    if body.logit_bias is not None:
        kwargs["logit_bias"] = (
            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
            if body.logit_bias_type == "tokens"
            else body.logit_bias
        )

    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
    return llama


async def get_event_publisher(
    request: Request,
    inner_send_chan: MemoryObjectSendStream[typing.Any],
    body: CreateCompletionRequest | CreateChatCompletionRequest,
    body_model: str | None,
    llama_call,
    kwargs,
):
    server_settings = next(get_server_settings())
    interrupt_requests = (
        server_settings.interrupt_requests if server_settings else False
    )
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)
        async with inner_send_chan:
            try:
                iterator = await run_in_threadpool(llama_call, llama, **kwargs)
                async for chunk in iterate_in_threadpool(iterator):
                    await inner_send_chan.send(dict(data=json.dumps(chunk)))
                    if await request.is_disconnected():
                        raise anyio.get_cancelled_exc_class()()
                    if interrupt_requests and llama_outer_lock.locked():
                        await inner_send_chan.send(dict(data="[DONE]"))
                        raise anyio.get_cancelled_exc_class()()
                await inner_send_chan.send(dict(data="[DONE]"))
            except anyio.get_cancelled_exc_class() as e:
                print("disconnected")
                with anyio.move_on_after(1, shield=True):
                    print(
                        f"Disconnected from client (via refresh/close) {request.client}"
                    )
                    raise e


def _logit_bias_tokens_to_input_ids(
    llama: llama_cpp.Llama,
    logit_bias: Dict[str, float],
) -> Dict[str, float]:
    to_bias: Dict[str, float] = {}
    for token, score in logit_bias.items():
        token = token.encode("utf-8")
        for input_id in llama.tokenize(token, add_bos=False, special=True):
            to_bias[str(input_id)] = score
    return to_bias


# Setup Bearer authentication scheme
bearer_scheme = HTTPBearer(auto_error=False)


async def authenticate(
    settings: Settings = Depends(get_server_settings),
    authorization: Optional[str] = Depends(bearer_scheme),
):
    # Skip API key check if it's not set in settings
    if settings.api_key is None:
        return True

    # check bearer credentials against the api_key
    if authorization and authorization.credentials == settings.api_key:
        # api key is valid
        return authorization.credentials

    # raise http error 401
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )


openai_v1_tag = "OpenAI V1"


@router.post(
    "/v1/completions",
    summary="Completion",
    dependencies=[Depends(authenticate)],
    response_model=Union[
        llama_cpp.CreateCompletionResponse,
        str,
    ],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {"$ref": "#/components/schemas/CreateCompletionResponse"}
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True. "
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
@router.post(
    "/v1/engines/copilot-codex/completions",
    include_in_schema=False,
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_completion(
    request: Request,
    body: CreateCompletionRequest,
) -> llama_cpp.Completion:
    if isinstance(body.prompt, list):
        assert len(body.prompt) <= 1
        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""

    body_model = (
        body.model
        if request.url.path != "/v1/engines/copilot-codex/completions"
        else "copilot-codex"
    )

    exclude = {
        "n",
        "best_of",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # handle streaming request
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.__call__,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # handle regular request
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama, **kwargs)


@router.post(
    "/v1/embeddings",
    summary="Embedding",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def create_embedding(
    request: CreateEmbeddingRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
):
    return await run_in_threadpool(
        llama_proxy(request.model).create_embedding,
        **request.model_dump(exclude={"user"}),
    )


@router.post(
    "/v1/chat/completions",
    summary="Chat",
    dependencies=[Depends(authenticate)],
    response_model=Union[llama_cpp.ChatCompletion, str],
    responses={
        "200": {
            "description": "Successful Response",
            "content": {
                "application/json": {
                    "schema": {
                        "anyOf": [
                            {
                                "$ref": "#/components/schemas/CreateChatCompletionResponse"
                            }
                        ],
                        "title": "Completion response, when stream=False",
                    }
                },
                "text/event-stream": {
                    "schema": {
                        "type": "string",
                        "title": "Server Side Streaming response, when stream=True"
                        + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format",  # noqa: E501
                        "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""",
                    }
                },
            },
        }
    },
    tags=[openai_v1_tag],
)
async def create_chat_completion(
    request: Request,
    body: CreateChatCompletionRequest = Body(
        openapi_examples={
            "normal": {
                "summary": "Chat Completion",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                },
            },
            "json_mode": {
                "summary": "JSON Mode",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Who won the world series in 2020"},
                    ],
                    "response_format": {"type": "json_object"},
                },
            },
            "tool_calling": {
                "summary": "Tool Calling",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Extract Jason is 30 years old."},
                    ],
                    "tools": [
                        {
                            "type": "function",
                            "function": {
                                "name": "User",
                                "description": "User record",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "name": {"type": "string"},
                                        "age": {"type": "number"},
                                    },
                                    "required": ["name", "age"],
                                },
                            },
                        }
                    ],
                    "tool_choice": {
                        "type": "function",
                        "function": {
                            "name": "User",
                        },
                    },
                },
            },
            "logprobs": {
                "summary": "Logprobs",
                "value": {
                    "model": "gpt-3.5-turbo",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "What is the capital of France?"},
                    ],
                    "logprobs": True,
                    "top_logprobs": 10,
                },
            },
        }
    ),
) -> llama_cpp.ChatCompletion:
    # This is a workaround for an issue in FastAPI dependencies
    # where the dependency is cleaned up before a StreamingResponse
    # is complete.
    # https://github.com/tiangolo/fastapi/issues/11143

    body_model = body.model
    exclude = {
        "n",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

    # handle streaming request
    if kwargs.get("stream", False):
        send_chan, recv_chan = anyio.create_memory_object_stream(10)
        return EventSourceResponse(
            recv_chan,
            data_sender_callable=partial(  # type: ignore
                get_event_publisher,
                request=request,
                inner_send_chan=send_chan,
                body=body,
                body_model=body_model,
                llama_call=llama_cpp.Llama.create_chat_completion,
                kwargs=kwargs,
            ),
            sep="\n",
            ping_message_factory=_ping_message_factory,
        )

    # handle regular request
    async with contextlib.asynccontextmanager(get_llama_proxy)() as llama_proxy:
        llama = prepare_request_resources(body, llama_proxy, body_model, kwargs)

        if await request.is_disconnected():
            print(
                f"Disconnected from client (via refresh/close) before llm invoked {request.client}"
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client closed request",
            )

        return await run_in_threadpool(llama.create_chat_completion, **kwargs)


@router.get(
    "/v1/models",
    summary="Models",
    dependencies=[Depends(authenticate)],
    tags=[openai_v1_tag],
)
async def get_models(
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> ModelList:
    return {
        "object": "list",
        "data": [
            {
                "id": model_alias,
                "object": "model",
                "owned_by": "me",
                "permissions": [],
            }
            for model_alias in llama_proxy
        ],
    }


extras_tag = "Extras"


@router.post(
    "/extras/tokenize",
    summary="Tokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def tokenize(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputResponse:
    tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)

    return TokenizeInputResponse(tokens=tokens)


@router.post(
    "/extras/tokenize/count",
    summary="Tokenize Count",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def count_query_tokens(
    body: TokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> TokenizeInputCountResponse:
    tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True)

    return TokenizeInputCountResponse(count=len(tokens))


@router.post(
    "/extras/detokenize",
    summary="Detokenize",
    dependencies=[Depends(authenticate)],
    tags=[extras_tag],
)
async def detokenize(
    body: DetokenizeInputRequest,
    llama_proxy: LlamaProxy = Depends(get_llama_proxy),
) -> DetokenizeInputResponse:
    text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8")

    return DetokenizeInputResponse(text=text)
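The routes above follow the OpenAI wire format, so any HTTP client can talk to them. A hedged sketch of calling `/v1/chat/completions` with the standard `requests` package; the URL and API key are placeholders, and the Authorization header is only needed when `api_key` is configured:

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",   # placeholder host/port
    headers={"Authorization": "Bearer my-api-key"},  # only if api_key is set
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        "stream": False,
    },
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```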
97
venv/Lib/site-packages/llama_cpp/server/cli.py
Normal file
97
venv/Lib/site-packages/llama_cpp/server/cli.py
Normal file
@@ -0,0 +1,97 @@
from __future__ import annotations

import argparse

from typing import List, Literal, Union, Any, Type, TypeVar

from pydantic import BaseModel


def _get_base_type(annotation: Type[Any]) -> Type[Any]:
    if getattr(annotation, "__origin__", None) is Literal:
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        return type(annotation.__args__[0])  # type: ignore
    elif getattr(annotation, "__origin__", None) is Union:
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        non_optional_args: List[Type[Any]] = [
            arg for arg in annotation.__args__ if arg is not type(None)  # type: ignore
        ]
        if non_optional_args:
            return _get_base_type(non_optional_args[0])
    elif (
        getattr(annotation, "__origin__", None) is list
        or getattr(annotation, "__origin__", None) is List
    ):
        assert hasattr(annotation, "__args__") and len(annotation.__args__) >= 1  # type: ignore
        return _get_base_type(annotation.__args__[0])  # type: ignore
    return annotation


def _contains_list_type(annotation: Type[Any] | None) -> bool:
    origin = getattr(annotation, "__origin__", None)

    if origin is list or origin is List:
        return True
    elif origin in (Literal, Union):
        return any(_contains_list_type(arg) for arg in annotation.__args__)  # type: ignore
    else:
        return False


def _parse_bool_arg(arg: str | bytes | bool) -> bool:
    if isinstance(arg, bytes):
        arg = arg.decode("utf-8")

    true_values = {"1", "on", "t", "true", "y", "yes"}
    false_values = {"0", "off", "f", "false", "n", "no"}

    arg_str = str(arg).lower().strip()

    if arg_str in true_values:
        return True
    elif arg_str in false_values:
        return False
    else:
        raise ValueError(f"Invalid boolean argument: {arg}")


def add_args_from_model(parser: argparse.ArgumentParser, model: Type[BaseModel]):
    """Add arguments from a pydantic model to an argparse parser."""

    for name, field in model.model_fields.items():
        description = field.description
        if field.default and description and not field.is_required():
            description += f" (default: {field.default})"
        base_type = (
            _get_base_type(field.annotation) if field.annotation is not None else str
        )
        list_type = _contains_list_type(field.annotation)
        if base_type is not bool:
            parser.add_argument(
                f"--{name}",
                dest=name,
                nargs="*" if list_type else None,
                type=base_type,
                help=description,
            )
        if base_type is bool:
            parser.add_argument(
                f"--{name}",
                dest=name,
                type=_parse_bool_arg,
                help=f"{description}",
            )


T = TypeVar("T", bound=Type[BaseModel])


def parse_model_from_args(model: T, args: argparse.Namespace) -> T:
    """Parse a pydantic model from an argparse namespace."""
    return model(
        **{
            k: v
            for k, v in vars(args).items()
            if v is not None and k in model.model_fields
        }
    )
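A small illustration (not part of this commit) of how `add_args_from_model` and `parse_model_from_args` cooperate: every pydantic field becomes a `--flag`, and booleans go through `_parse_bool_arg`. `DemoSettings` is a made-up model used only for the example:

```python
import argparse

from pydantic import BaseModel, Field

from llama_cpp.server.cli import add_args_from_model, parse_model_from_args


class DemoSettings(BaseModel):  # hypothetical settings model for illustration
    model: str = Field(description="Path to the model file.")
    n_ctx: int = Field(default=2048, description="Context size.")
    verbose: bool = Field(default=True, description="Print debug output.")


parser = argparse.ArgumentParser()
add_args_from_model(parser, DemoSettings)  # adds --model, --n_ctx, --verbose
args = parser.parse_args(
    ["--model", "models/7B/model.gguf", "--n_ctx", "4096", "--verbose", "false"]
)
settings = parse_model_from_args(DemoSettings, args)
print(settings.model, settings.n_ctx, settings.verbose)  # models/7B/model.gguf 4096 False
```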
212
venv/Lib/site-packages/llama_cpp/server/errors.py
Normal file
212
venv/Lib/site-packages/llama_cpp/server/errors.py
Normal file
@@ -0,0 +1,212 @@
from __future__ import annotations

import sys
import traceback
import time
from re import compile, Match, Pattern
from typing import Callable, Coroutine, Optional, Tuple, Union, Dict
from typing_extensions import TypedDict


from fastapi import (
    Request,
    Response,
    HTTPException,
)
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute

from llama_cpp.server.types import (
    CreateCompletionRequest,
    CreateEmbeddingRequest,
    CreateChatCompletionRequest,
)


class ErrorResponse(TypedDict):
    """OpenAI style error response"""

    message: str
    type: str
    param: Optional[str]
    code: Optional[str]


class ErrorResponseFormatters:
    """Collection of formatters for error responses.

    Args:
        request (Union[CreateCompletionRequest, CreateChatCompletionRequest]):
            Request body
        match (Match[str]): Match object from regex pattern

    Returns:
        Tuple[int, ErrorResponse]: Status code and error response
    """

    @staticmethod
    def context_length_exceeded(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for context length exceeded error"""

        context_window = int(match.group(2))
        prompt_tokens = int(match.group(1))
        completion_tokens = request.max_tokens
        if hasattr(request, "messages"):
            # Chat completion
            message = (
                "This model's maximum context length is {} tokens. "
                "However, you requested {} tokens "
                "({} in the messages, {} in the completion). "
                "Please reduce the length of the messages or completion."
            )
        else:
            # Text completion
            message = (
                "This model's maximum context length is {} tokens, "
                "however you requested {} tokens "
                "({} in your prompt; {} for the completion). "
                "Please reduce your prompt; or completion length."
            )
        return 400, ErrorResponse(
            message=message.format(
                context_window,
                (completion_tokens or 0) + prompt_tokens,
                prompt_tokens,
                completion_tokens,
            ),  # type: ignore
            type="invalid_request_error",
            param="messages",
            code="context_length_exceeded",
        )

    @staticmethod
    def model_not_found(
        request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
        match,  # type: Match[str] # type: ignore
    ) -> Tuple[int, ErrorResponse]:
        """Formatter for model_not_found error"""

        model_path = str(match.group(1))
        message = f"The model `{model_path}` does not exist"
        return 400, ErrorResponse(
            message=message,
            type="invalid_request_error",
            param=None,
            code="model_not_found",
        )


class RouteErrorHandler(APIRoute):
    """Custom APIRoute that handles application errors and exceptions"""

    # key: regex pattern for original error message from llama_cpp
    # value: formatter function
    pattern_and_formatters: Dict[
        "Pattern[str]",
        Callable[
            [
                Union["CreateCompletionRequest", "CreateChatCompletionRequest"],
                "Match[str]",
            ],
            Tuple[int, ErrorResponse],
        ],
    ] = {
        compile(
            r"Requested tokens \((\d+)\) exceed context window of (\d+)"
        ): ErrorResponseFormatters.context_length_exceeded,
        compile(
            r"Model path does not exist: (.+)"
        ): ErrorResponseFormatters.model_not_found,
    }

    def error_message_wrapper(
        self,
        error: Exception,
        body: Optional[
            Union[
                "CreateChatCompletionRequest",
                "CreateCompletionRequest",
                "CreateEmbeddingRequest",
            ]
        ] = None,
    ) -> Tuple[int, ErrorResponse]:
        """Wraps error message in OpenAI style error response"""
        if body is not None and isinstance(
            body,
            (
                CreateCompletionRequest,
                CreateChatCompletionRequest,
            ),
        ):
            # When text completion or chat completion
            for pattern, callback in self.pattern_and_formatters.items():
                match = pattern.search(str(error))
                if match is not None:
                    return callback(body, match)

        # Only print the trace on unexpected exceptions
        print(f"Exception: {str(error)}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)

        # Wrap other errors as internal server error
        return 500, ErrorResponse(
            message=str(error),
            type="internal_server_error",
            param=None,
            code=None,
        )

    def get_route_handler(
        self,
    ) -> Callable[[Request], Coroutine[None, None, Response]]:
        """Defines custom route handler that catches exceptions and formats
        in OpenAI style error response"""

        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            try:
                start_sec = time.perf_counter()
                response = await original_route_handler(request)
                elapsed_time_ms = int((time.perf_counter() - start_sec) * 1000)
                response.headers["openai-processing-ms"] = f"{elapsed_time_ms}"
                return response
            except HTTPException as unauthorized:
                # api key check failed
                raise unauthorized
            except Exception as exc:
                json_body = await request.json()
                try:
                    if "messages" in json_body:
                        # Chat completion
                        body: Optional[
                            Union[
                                CreateChatCompletionRequest,
                                CreateCompletionRequest,
                                CreateEmbeddingRequest,
                            ]
                        ] = CreateChatCompletionRequest(**json_body)
                    elif "prompt" in json_body:
                        # Text completion
                        body = CreateCompletionRequest(**json_body)
                    else:
                        # Embedding
                        body = CreateEmbeddingRequest(**json_body)
                except Exception:
                    # Invalid request body
                    body = None

                # Get proper error message from the exception
                (
                    status_code,
                    error_message,
                ) = self.error_message_wrapper(error=exc, body=body)
                return JSONResponse(
                    {"error": error_message},
                    status_code=status_code,
                )

        return custom_route_handler
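A sketch (not part of this commit) of the OpenAI-style payload the regex/formatter table above produces for a context-overflow error; the token counts below are illustrative:

```python
import re

from llama_cpp.server.errors import ErrorResponseFormatters
from llama_cpp.server.types import CreateCompletionRequest

# Same pattern RouteErrorHandler uses to recognize the llama_cpp error string.
pattern = re.compile(r"Requested tokens \((\d+)\) exceed context window of (\d+)")
match = pattern.search("Requested tokens (5000) exceed context window of 4096")
body = CreateCompletionRequest(prompt="...", max_tokens=16)

status_code, error = ErrorResponseFormatters.context_length_exceeded(body, match)
print(status_code)       # 400
print(error["code"])     # context_length_exceeded
print(error["message"])  # "This model's maximum context length is 4096 tokens, ..."
```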
312
venv/Lib/site-packages/llama_cpp/server/model.py
Normal file
312
venv/Lib/site-packages/llama_cpp/server/model.py
Normal file
@@ -0,0 +1,312 @@
from __future__ import annotations

import json

from typing import Dict, Optional, Union, List

import llama_cpp
import llama_cpp.llama_speculative as llama_speculative
import llama_cpp.llama_tokenizer as llama_tokenizer

from llama_cpp.server.settings import ModelSettings


class LlamaProxy:
    def __init__(self, models: List[ModelSettings]) -> None:
        assert len(models) > 0, "No models provided!"

        self._model_settings_dict: dict[str, ModelSettings] = {}
        for model in models:
            if not model.model_alias:
                model.model_alias = model.model
            self._model_settings_dict[model.model_alias] = model

        self._current_model: Optional[llama_cpp.Llama] = None
        self._current_model_alias: Optional[str] = None

        self._default_model_settings: ModelSettings = models[0]
        self._default_model_alias: str = self._default_model_settings.model_alias  # type: ignore

        # Load default model
        self._current_model = self.load_llama_from_model_settings(
            self._default_model_settings
        )
        self._current_model_alias = self._default_model_alias

    def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
        if model is None:
            model = self._default_model_alias

        if model not in self._model_settings_dict:
            model = self._default_model_alias

        if model == self._current_model_alias:
            if self._current_model is not None:
                return self._current_model

        if self._current_model:
            self._current_model.close()
            self._current_model = None

        settings = self._model_settings_dict[model]
        self._current_model = self.load_llama_from_model_settings(settings)
        self._current_model_alias = model
        return self._current_model

    def __getitem__(self, model: str):
        return self._model_settings_dict[model].model_dump()

    def __setitem__(self, model: str, settings: Union[ModelSettings, str, bytes]):
        if isinstance(settings, (bytes, str)):
            settings = ModelSettings.model_validate_json(settings)
        self._model_settings_dict[model] = settings

    def __iter__(self):
        for model in self._model_settings_dict:
            yield model

    def free(self):
        if self._current_model:
            self._current_model.close()
            del self._current_model

    @staticmethod
    def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
        chat_handler = None
        if settings.chat_format == "llava-1-5":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.Llava15ChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.Llava15ChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "obsidian":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.ObsidianChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.ObsidianChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "llava-1-6":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.Llava16ChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.Llava16ChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "moondream":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.MoondreamChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.MoondreamChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "nanollava":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.NanoLlavaChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.NanoLlavaChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "llama-3-vision-alpha":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.Llama3VisionAlpha.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.Llama3VisionAlpha(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "minicpm-v-2.6":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.MiniCPMv26ChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.MiniCPMv26ChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "qwen2.5-vl":
            assert settings.clip_model_path is not None, "clip model not found"
            if settings.hf_model_repo_id is not None:
                chat_handler = (
                    llama_cpp.llama_chat_format.Qwen25VLChatHandler.from_pretrained(
                        repo_id=settings.hf_model_repo_id,
                        filename=settings.clip_model_path,
                        verbose=settings.verbose,
                    )
                )
            else:
                chat_handler = llama_cpp.llama_chat_format.Qwen25VLChatHandler(
                    clip_model_path=settings.clip_model_path, verbose=settings.verbose
                )
        elif settings.chat_format == "hf-autotokenizer":
            assert (
                settings.hf_pretrained_model_name_or_path is not None
            ), "hf_pretrained_model_name_or_path must be set for hf-autotokenizer"
            chat_handler = (
                llama_cpp.llama_chat_format.hf_autotokenizer_to_chat_completion_handler(
                    settings.hf_pretrained_model_name_or_path
                )
            )
        elif settings.chat_format == "hf-tokenizer-config":
            assert (
                settings.hf_tokenizer_config_path is not None
            ), "hf_tokenizer_config_path must be set for hf-tokenizer-config"
            chat_handler = llama_cpp.llama_chat_format.hf_tokenizer_config_to_chat_completion_handler(
                json.load(open(settings.hf_tokenizer_config_path))
            )

        tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
        if settings.hf_pretrained_model_name_or_path is not None:
            tokenizer = llama_tokenizer.LlamaHFTokenizer.from_pretrained(
                settings.hf_pretrained_model_name_or_path
            )

        draft_model = None
        if settings.draft_model is not None:
            draft_model = llama_speculative.LlamaPromptLookupDecoding(
                num_pred_tokens=settings.draft_model_num_pred_tokens
            )

        kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None
        if settings.kv_overrides is not None:
            assert isinstance(settings.kv_overrides, list)
            kv_overrides = {}
            for kv in settings.kv_overrides:
                key, value = kv.split("=")
                if ":" in value:
                    value_type, value = value.split(":")
                    if value_type == "bool":
                        kv_overrides[key] = value.lower() in ["true", "1"]
                    elif value_type == "int":
                        kv_overrides[key] = int(value)
                    elif value_type == "float":
                        kv_overrides[key] = float(value)
                    elif value_type == "str":
                        kv_overrides[key] = value
                    else:
                        raise ValueError(f"Unknown value type {value_type}")

        import functools

        kwargs = {}

        if settings.hf_model_repo_id is not None:
            create_fn = functools.partial(
                llama_cpp.Llama.from_pretrained,
                repo_id=settings.hf_model_repo_id,
                filename=settings.model,
            )
        else:
            create_fn = llama_cpp.Llama
            kwargs["model_path"] = settings.model

        _model = create_fn(
            **kwargs,
            # Model Params
            n_gpu_layers=settings.n_gpu_layers,
            split_mode=settings.split_mode,
            main_gpu=settings.main_gpu,
            tensor_split=settings.tensor_split,
            vocab_only=settings.vocab_only,
            use_mmap=settings.use_mmap,
            use_mlock=settings.use_mlock,
            kv_overrides=kv_overrides,
            rpc_servers=settings.rpc_servers,
            # Context Params
            seed=settings.seed,
            n_ctx=settings.n_ctx,
            n_batch=settings.n_batch,
            n_ubatch=settings.n_ubatch,
            n_threads=settings.n_threads,
            n_threads_batch=settings.n_threads_batch,
            rope_scaling_type=settings.rope_scaling_type,
            rope_freq_base=settings.rope_freq_base,
            rope_freq_scale=settings.rope_freq_scale,
            yarn_ext_factor=settings.yarn_ext_factor,
            yarn_attn_factor=settings.yarn_attn_factor,
            yarn_beta_fast=settings.yarn_beta_fast,
            yarn_beta_slow=settings.yarn_beta_slow,
            yarn_orig_ctx=settings.yarn_orig_ctx,
            mul_mat_q=settings.mul_mat_q,
            logits_all=settings.logits_all,
            embedding=settings.embedding,
            offload_kqv=settings.offload_kqv,
            flash_attn=settings.flash_attn,
            # Sampling Params
            last_n_tokens_size=settings.last_n_tokens_size,
            # LoRA Params
            lora_base=settings.lora_base,
            lora_path=settings.lora_path,
            # Backend Params
            numa=settings.numa,
            # Chat Format Params
            chat_format=settings.chat_format,
            chat_handler=chat_handler,
            # Speculative Decoding
            draft_model=draft_model,
            # KV Cache Quantization
            type_k=settings.type_k,
            type_v=settings.type_v,
            # Tokenizer
            tokenizer=tokenizer,
            # Misc
            verbose=settings.verbose,
        )
        if settings.cache:
            if settings.cache_type == "disk":
                if settings.verbose:
                    print(f"Using disk cache with size {settings.cache_size}")
                cache = llama_cpp.LlamaDiskCache(capacity_bytes=settings.cache_size)
            else:
                if settings.verbose:
                    print(f"Using ram cache with size {settings.cache_size}")
                cache = llama_cpp.LlamaRAMCache(capacity_bytes=settings.cache_size)
            _model.set_cache(cache)
        return _model
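A hedged sketch (not part of this commit) of the `key=type:value` kv_overrides syntax parsed above and of switching models through `LlamaProxy`; the GGUF paths and the override key are placeholders, and constructing the proxy loads real model files:

```python
from llama_cpp.server.model import LlamaProxy
from llama_cpp.server.settings import ModelSettings

models = [
    ModelSettings(
        model="models/llama-3-8b.Q4_K_M.gguf",   # placeholder path
        model_alias="llama-3-8b",
        kv_overrides=["tokenizer.ggml.add_bos_token=bool:true"],  # hypothetical key
    ),
    ModelSettings(model="models/mistral-7b.Q4_K_M.gguf", model_alias="mistral-7b"),
]

proxy = LlamaProxy(models=models)  # eagerly loads the first (default) model
llama = proxy("mistral-7b")        # closes the current model and loads the alias
print(llama.create_completion("Hello", max_tokens=8))
```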
240
venv/Lib/site-packages/llama_cpp/server/settings.py
Normal file
240
venv/Lib/site-packages/llama_cpp/server/settings.py
Normal file
@@ -0,0 +1,240 @@
from __future__ import annotations

import multiprocessing

from typing import Optional, List, Literal, Union, Dict, cast
from typing_extensions import Self

from pydantic import Field, model_validator
from pydantic_settings import BaseSettings

import llama_cpp

# Disable warning for model and model_alias settings
BaseSettings.model_config["protected_namespaces"] = ()


class ModelSettings(BaseSettings):
    """Model settings used to load a Llama model."""

    model: str = Field(
        description="The path to the model to use for generating completions."
    )
    model_alias: Optional[str] = Field(
        default=None,
        description="The alias of the model to use for generating completions.",
    )
    # Model Params
    n_gpu_layers: int = Field(
        default=0,
        ge=-1,
        description="The number of layers to put on the GPU. The rest will be on the CPU. Set -1 to move all to GPU.",
    )
    split_mode: int = Field(
        default=llama_cpp.LLAMA_SPLIT_MODE_LAYER,
        description="The split mode to use.",
    )
    main_gpu: int = Field(
        default=0,
        ge=0,
        description="Main GPU to use.",
    )
    tensor_split: Optional[List[float]] = Field(
        default=None,
        description="Split layers across multiple GPUs in proportion.",
    )
    vocab_only: bool = Field(
        default=False, description="Whether to only return the vocabulary."
    )
    use_mmap: bool = Field(
        default=llama_cpp.llama_supports_mmap(),
        description="Use mmap.",
    )
    use_mlock: bool = Field(
        default=llama_cpp.llama_supports_mlock(),
        description="Use mlock.",
    )
    kv_overrides: Optional[List[str]] = Field(
        default=None,
        description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.",
    )
    rpc_servers: Optional[str] = Field(
        default=None,
        description="comma seperated list of rpc servers for offloading",
    )
    # Context Params
    seed: int = Field(
        default=llama_cpp.LLAMA_DEFAULT_SEED, description="Random seed. -1 for random."
    )
    n_ctx: int = Field(default=2048, ge=0, description="The context size.")
    n_batch: int = Field(
        default=512, ge=1, description="The batch size to use per eval."
    )
    n_ubatch: int = Field(
        default=512, ge=1, description="The physical batch size used by llama.cpp"
    )
    n_threads: int = Field(
        default=max(multiprocessing.cpu_count() // 2, 1),
        ge=1,
        description="The number of threads to use. Use -1 for max cpu threads",
    )
    n_threads_batch: int = Field(
        default=max(multiprocessing.cpu_count(), 1),
        ge=0,
        description="The number of threads to use when batch processing. Use -1 for max cpu threads",
    )
    rope_scaling_type: int = Field(
        default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
    )
    rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
    rope_freq_scale: float = Field(
        default=0.0, description="RoPE frequency scaling factor"
    )
    yarn_ext_factor: float = Field(default=-1.0)
    yarn_attn_factor: float = Field(default=1.0)
    yarn_beta_fast: float = Field(default=32.0)
    yarn_beta_slow: float = Field(default=1.0)
    yarn_orig_ctx: int = Field(default=0)
    mul_mat_q: bool = Field(
        default=True, description="if true, use experimental mul_mat_q kernels"
    )
    logits_all: bool = Field(default=True, description="Whether to return logits.")
    embedding: bool = Field(default=False, description="Whether to use embeddings.")
    offload_kqv: bool = Field(
        default=True, description="Whether to offload kqv to the GPU."
    )
    flash_attn: bool = Field(
        default=False, description="Whether to use flash attention."
    )
    # Sampling Params
    last_n_tokens_size: int = Field(
        default=64,
        ge=0,
        description="Last n tokens to keep for repeat penalty calculation.",
    )
    # LoRA Params
    lora_base: Optional[str] = Field(
        default=None,
        description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.",
    )
    lora_path: Optional[str] = Field(
        default=None,
        description="Path to a LoRA file to apply to the model.",
    )
    # Backend Params
    numa: Union[bool, int] = Field(
        default=False,
        description="Enable NUMA support.",
    )
    # Chat Format Params
    chat_format: Optional[str] = Field(
        default=None,
        description="Chat format to use.",
    )
    clip_model_path: Optional[str] = Field(
        default=None,
        description="Path to a CLIP model to use for multi-modal chat completion.",
    )
    # Cache Params
    cache: bool = Field(
        default=False,
        description="Use a cache to reduce processing times for evaluated prompts.",
    )
    cache_type: Literal["ram", "disk"] = Field(
        default="ram",
        description="The type of cache to use. Only used if cache is True.",
    )
    cache_size: int = Field(
        default=2 << 30,
        description="The size of the cache in bytes. Only used if cache is True.",
    )
    # Tokenizer Options
    hf_tokenizer_config_path: Optional[str] = Field(
        default=None,
        description="The path to a HuggingFace tokenizer_config.json file.",
    )
    hf_pretrained_model_name_or_path: Optional[str] = Field(
        default=None,
        description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
    )
    # Loading from HuggingFace Model Hub
    hf_model_repo_id: Optional[str] = Field(
        default=None,
        description="The model repo id to use for the HuggingFace tokenizer model.",
    )
    # Speculative Decoding
    draft_model: Optional[str] = Field(
        default=None,
        description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
    )
    draft_model_num_pred_tokens: int = Field(
        default=10,
        description="Number of tokens to predict using the draft model.",
    )
    # KV Cache Quantization
    type_k: Optional[int] = Field(
        default=None,
        description="Type of the key cache quantization.",
    )
    type_v: Optional[int] = Field(
        default=None,
        description="Type of the value cache quantization.",
    )
    # Misc
    verbose: bool = Field(
        default=True, description="Whether to print debug information."
    )

    @model_validator(
        mode="before"
    )  # pre=True to ensure this runs before any other validation
    def set_dynamic_defaults(self) -> Self:
        # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
        cpu_count = multiprocessing.cpu_count()
        values = cast(Dict[str, int], self)
        if values.get("n_threads", 0) == -1:
            values["n_threads"] = cpu_count
        if values.get("n_threads_batch", 0) == -1:
            values["n_threads_batch"] = cpu_count
        return self


class ServerSettings(BaseSettings):
    """Server settings used to configure the FastAPI and Uvicorn server."""

    # Uvicorn Settings
    host: str = Field(default="localhost", description="Listen address")
    port: int = Field(default=8000, description="Listen port")
    ssl_keyfile: Optional[str] = Field(
        default=None, description="SSL key file for HTTPS"
    )
    ssl_certfile: Optional[str] = Field(
        default=None, description="SSL certificate file for HTTPS"
    )
    # FastAPI Settings
    api_key: Optional[str] = Field(
        default=None,
        description="API key for authentication. If set all requests need to be authenticated.",
    )
    interrupt_requests: bool = Field(
        default=True,
        description="Whether to interrupt requests when a new request is received.",
    )
    disable_ping_events: bool = Field(
        default=False,
        description="Disable EventSource pings (may be needed for some clients).",
    )
    root_path: str = Field(
        default="",
        description="The root path for the server. Useful when running behind a reverse proxy.",
    )


class Settings(ServerSettings, ModelSettings):
    pass


class ConfigFileSettings(ServerSettings):
    """Configuration file format settings."""

    models: List[ModelSettings] = Field(default=[], description="Model configs")
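Because these classes derive from pydantic-settings `BaseSettings`, their fields can also be populated from environment variables, which is how `export MODEL=...` in the `__main__.py` docstring works. A small sketch; the model path is a placeholder:

```python
import os

from llama_cpp.server.settings import Settings

os.environ["MODEL"] = "models/7B/model.gguf"  # placeholder path
os.environ["N_GPU_LAYERS"] = "-1"
os.environ["PORT"] = "8080"

settings = Settings()  # env vars are matched to field names case-insensitively
print(settings.model, settings.n_gpu_layers, settings.port)
```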
316
venv/Lib/site-packages/llama_cpp/server/types.py
Normal file
316
venv/Lib/site-packages/llama_cpp/server/types.py
Normal file
|
@ -0,0 +1,316 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Union, Dict
|
||||
from typing_extensions import TypedDict, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import llama_cpp
|
||||
|
||||
|
||||
model_field = Field(
|
||||
description="The model to use for generating completions.", default=None
|
||||
)
|
||||
|
||||
max_tokens_field = Field(
|
||||
default=16, ge=1, description="The maximum number of tokens to generate."
|
||||
)
|
||||
|
||||
min_tokens_field = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).",
|
||||
)
|
||||
|
||||
temperature_field = Field(
|
||||
default=0.8,
|
||||
description="Adjust the randomness of the generated text.\n\n"
|
||||
+ "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.",
|
||||
)
|
||||
|
||||
top_p_field = Field(
|
||||
default=0.95,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n"
|
||||
+ "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.",
|
||||
)
|
||||
|
||||
min_p_field = Field(
|
||||
default=0.05,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Sets a minimum base probability threshold for token selection.\n\n"
|
||||
+ "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.",
|
||||
)
|
||||
|
||||
stop_field = Field(
|
||||
default=None,
|
||||
description="A list of tokens at which to stop generation. If None, no stop tokens are used.",
|
||||
)
|
||||
|
||||
stream_field = Field(
|
||||
default=False,
|
||||
description="Whether to stream the results as they are generated. Useful for chatbots.",
|
||||
)
|
||||
|
||||
top_k_field = Field(
|
||||
default=40,
|
||||
ge=0,
|
||||
description="Limit the next token selection to the K most probable tokens.\n\n"
|
||||
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
|
||||
)
|
||||
|
||||
repeat_penalty_field = Field(
|
||||
default=1.1,
|
||||
ge=0.0,
|
||||
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
|
||||
+ "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.",
|
||||
)
|
||||
|
||||
presence_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
|
||||
)
|
||||
|
||||
frequency_penalty_field = Field(
|
||||
default=0.0,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
|
||||
)
|
||||
|
||||
mirostat_mode_field = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=2,
|
||||
description="Enable Mirostat constant-perplexity algorithm of the specified version (1 or 2; 0 = disabled)",
|
||||
)
|
||||
|
||||
mirostat_tau_field = Field(
|
||||
default=5.0,
|
||||
ge=0.0,
|
||||
le=10.0,
|
||||
description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text",
|
||||
)
|
||||
|
||||
mirostat_eta_field = Field(
|
||||
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
|
||||
)
|
||||
|
||||
grammar = Field(
|
||||
default=None,
|
||||
description="A CBNF grammar (as string) to be used for formatting the model's output.",
|
||||
)
|
||||
|
||||
|
||||
class CreateCompletionRequest(BaseModel):
    prompt: Union[str, List[str]] = Field(
        default="", description="The prompt to generate completions for."
    )
    suffix: Optional[str] = Field(
        default=None,
        description="A suffix to append to the generated text. If None, no suffix is appended. Useful for chatbots.",
    )
    max_tokens: Optional[int] = Field(
        default=16, ge=0, description="The maximum number of tokens to generate."
    )
    min_tokens: int = min_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    min_p: float = min_p_field
    echo: bool = Field(
        default=False,
        description="Whether to echo the prompt in the generated text. Useful for chatbots.",
    )
    stop: Optional[Union[str, List[str]]] = stop_field
    stream: bool = stream_field
    logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated.",
    )
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(None)
    seed: Optional[int] = Field(None)

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    best_of: Optional[int] = 1
    user: Optional[str] = Field(default=None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    grammar: Optional[str] = None

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "prompt": "\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
                    "stop": ["\n", "###"],
                }
            ]
        }
    }
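
# Illustrative sketch only: how a client might build and serialize the request
# above with pydantic v2. The field values here are example inputs, not server
# defaults.
#
#     request = CreateCompletionRequest(
#         prompt="\n\n### Instructions:\nWhat is the capital of France?\n\n### Response:\n",
#         max_tokens=32,
#         stop=["\n", "###"],
#     )
#     payload = request.model_dump_json(exclude_none=True)  # JSON body to POST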


class CreateEmbeddingRequest(BaseModel):
    model: Optional[str] = model_field
    input: Union[str, List[str]] = Field(description="The input to embed.")
    user: Optional[str] = Field(default=None)

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "input": "The food was delicious and the waiter...",
                }
            ]
        }
    }


class ChatCompletionRequestMessage(BaseModel):
    role: Literal["system", "user", "assistant", "function"] = Field(
        default="user", description="The role of the message."
    )
    content: Optional[str] = Field(
        default="", description="The content of the message."
    )


class CreateChatCompletionRequest(BaseModel):
    messages: List[llama_cpp.ChatCompletionRequestMessage] = Field(
        default=[], description="A list of messages to generate completions for."
    )
    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
        default=None,
        description="A list of functions to apply to the generated completions.",
    )
    function_call: Optional[llama_cpp.ChatCompletionRequestFunctionCall] = Field(
        default=None,
        description="A function to apply to the generated completions.",
    )
    tools: Optional[List[llama_cpp.ChatCompletionTool]] = Field(
        default=None,
        description="A list of tools to apply to the generated completions.",
    )
    tool_choice: Optional[llama_cpp.ChatCompletionToolChoiceOption] = Field(
        default=None,
        description="A tool to apply to the generated completions.",
    )  # TODO: verify
    max_tokens: Optional[int] = Field(
        default=None,
        description="The maximum number of tokens to generate. Defaults to inf",
    )
    min_tokens: int = min_tokens_field
    logprobs: Optional[bool] = Field(
        default=False,
        description="Whether to output the logprobs or not. Defaults to False.",
    )
    top_logprobs: Optional[int] = Field(
        default=None,
        ge=0,
        description="The number of logprobs to generate. If None, no logprobs are generated. Requires logprobs to be set to True.",
    )
    temperature: float = temperature_field
    top_p: float = top_p_field
    min_p: float = min_p_field
    stop: Optional[Union[str, List[str]]] = stop_field
    stream: bool = stream_field
    presence_penalty: Optional[float] = presence_penalty_field
    frequency_penalty: Optional[float] = frequency_penalty_field
    logit_bias: Optional[Dict[str, float]] = Field(None)
    seed: Optional[int] = Field(None)
    response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
        default=None,
    )

    # ignored or currently unsupported
    model: Optional[str] = model_field
    n: Optional[int] = 1
    user: Optional[str] = Field(None)

    # llama.cpp specific parameters
    top_k: int = top_k_field
    repeat_penalty: float = repeat_penalty_field
    logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
    mirostat_mode: int = mirostat_mode_field
    mirostat_tau: float = mirostat_tau_field
    mirostat_eta: float = mirostat_eta_field
    grammar: Optional[str] = None

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        ChatCompletionRequestMessage(
                            role="system", content="You are a helpful assistant."
                        ).model_dump(),
                        ChatCompletionRequestMessage(
                            role="user", content="What is the capital of France?"
                        ).model_dump(),
                    ]
                }
            ]
        }
    }
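
# Illustrative sketch only: the JSON body a client might POST for the example
# above; fields omitted here fall back to the Field defaults declared earlier
# in this module.
#
#     {
#         "messages": [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "What is the capital of France?"}
#         ],
#         "stream": false
#     }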


class ModelData(TypedDict):
    id: str
    object: Literal["model"]
    owned_by: str
    permissions: List[str]


class ModelList(TypedDict):
    object: Literal["list"]
    data: List[ModelData]


class TokenizeInputRequest(BaseModel):
    model: Optional[str] = model_field
    input: str = Field(description="The input to tokenize.")

    model_config = {
        "json_schema_extra": {"examples": [{"input": "How many tokens in this query?"}]}
    }


class TokenizeInputResponse(BaseModel):
    tokens: List[int] = Field(description="A list of tokens.")

    model_config = {"json_schema_extra": {"example": {"tokens": [123, 321, 222]}}}


class TokenizeInputCountResponse(BaseModel):
    count: int = Field(description="The number of tokens in the input.")

    model_config = {"json_schema_extra": {"example": {"count": 5}}}


class DetokenizeInputRequest(BaseModel):
    model: Optional[str] = model_field
    tokens: List[int] = Field(description="A list of tokens to detokenize.")

    model_config = {"json_schema_extra": {"example": [{"tokens": [123, 321, 222]}]}}


class DetokenizeInputResponse(BaseModel):
    text: str = Field(description="The detokenized text.")

    model_config = {
        "json_schema_extra": {"example": {"text": "How many tokens in this query?"}}
    }
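
if __name__ == "__main__":
    # Minimal usage sketch of the tokenize/detokenize models defined above.
    # Assumes model_field (declared earlier in this module) provides a default;
    # the token ids are made-up illustration values.
    req = TokenizeInputRequest(input="How many tokens in this query?")
    print(req.model_dump_json(exclude_none=True))
    resp = TokenizeInputResponse(tokens=[123, 321, 222])
    print(resp.model_dump())
    detok = DetokenizeInputRequest(tokens=[123, 321, 222])
    print(detok.model_dump_json(exclude_none=True))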