from __future__ import annotations

import os
import ctypes

from typing import (
    Dict,
    List,
    Tuple,
    Optional,
    Sequence,
    Callable,
    Union,
)
from dataclasses import dataclass, field
from contextlib import ExitStack

import numpy as np
import numpy.typing as npt

from .llama_types import *
from .llama_grammar import LlamaGrammar
from ._utils import suppress_stdout_stderr

import llama_cpp.llama_cpp as llama_cpp


# Python wrappers over llama.h structs


class LlamaModel:
    """Intermediate Python wrapper for a llama.cpp llama_model.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        path_model: str,
        params: llama_cpp.llama_model_params,
        verbose: bool = True,
    ):
        self.path_model = path_model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()

        model = None

        if not os.path.exists(path_model):
            raise ValueError(f"Model path does not exist: {path_model}")

        with suppress_stdout_stderr(disable=verbose):
            model = llama_cpp.llama_model_load_from_file(
                self.path_model.encode("utf-8"), self.params
            )

        if model is None:
            raise ValueError(f"Failed to load model from file: {path_model}")

        vocab = llama_cpp.llama_model_get_vocab(model)

        if vocab is None:
            raise ValueError(f"Failed to get vocab from model: {path_model}")

        self.model = model
        self.vocab = vocab
        self.sampler = None  # LlamaModel doesn't use samplers, but some cleanup code expects this attribute

        def free_model():
            if self.model is None:
                return
            llama_cpp.llama_model_free(self.model)
            self.model = None

        self._exit_stack.callback(free_model)

    def close(self):
        if self.sampler is not None:
            # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
            for i, _ in reversed(self.custom_samplers):
                llama_cpp.llama_sampler_chain_remove(self.sampler, i)
            self.custom_samplers.clear()
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def vocab_type(self) -> int:
        return llama_cpp.llama_vocab_type(self.vocab)

    def n_vocab(self) -> int:
        return llama_cpp.llama_vocab_n_tokens(self.vocab)

    def n_ctx_train(self) -> int:
        return llama_cpp.llama_model_n_ctx_train(self.model)

    def n_embd(self) -> int:
        return llama_cpp.llama_model_n_embd(self.model)

    def rope_freq_scale_train(self) -> float:
        return llama_cpp.llama_model_rope_freq_scale_train(self.model)

    def desc(self) -> str:
        buf = ctypes.create_string_buffer(1024)
        llama_cpp.llama_model_desc(self.model, buf, 1024)
        return buf.value.decode("utf-8")

    def size(self) -> int:
        return llama_cpp.llama_model_size(self.model)

    def n_params(self) -> int:
        return llama_cpp.llama_model_n_params(self.model)

    def get_tensor(self, name: str) -> ctypes.c_void_p:
        raise NotImplementedError("get_tensor is not implemented in llama.cpp")

    # Vocab

    def token_get_text(self, token: int) -> str:
        return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")

    def token_get_score(self, token: int) -> float:
        return llama_cpp.llama_vocab_get_score(self.vocab, token)

    def token_get_attr(self, token: int) -> int:
        return llama_cpp.llama_vocab_get_attr(self.vocab, token)

    # Special tokens

    def token_bos(self) -> int:
        return llama_cpp.llama_vocab_bos(self.vocab)

    def token_eos(self) -> int:
        return llama_cpp.llama_vocab_eos(self.vocab)

    def token_cls(self) -> int:
        return llama_cpp.llama_vocab_cls(self.vocab)

    def token_sep(self) -> int:
        return llama_cpp.llama_vocab_sep(self.vocab)

    def token_nl(self) -> int:
        return llama_cpp.llama_vocab_nl(self.vocab)

    def token_prefix(self) -> int:
        return llama_cpp.llama_vocab_fim_pre(self.vocab)

    def token_middle(self) -> int:
        return llama_cpp.llama_vocab_fim_mid(self.vocab)

    def token_suffix(self) -> int:
        return llama_cpp.llama_vocab_fim_suf(self.vocab)

    def token_eot(self) -> int:
        return llama_cpp.llama_vocab_eot(self.vocab)

    def add_bos_token(self) -> bool:
        return llama_cpp.llama_vocab_get_add_bos(self.vocab)

    def add_eos_token(self) -> bool:
        return llama_cpp.llama_vocab_get_add_eos(self.vocab)

    # Tokenization

    def tokenize(self, text: bytes, add_bos: bool, special: bool):
        n_ctx = self.n_ctx_train()
        tokens = (llama_cpp.llama_token * n_ctx)()
        n_tokens = llama_cpp.llama_tokenize(
            self.vocab, text, len(text), tokens, n_ctx, add_bos, special
        )
        # A negative return value is the negated number of tokens required; retry with a buffer of that size
        if n_tokens < 0:
            n_tokens = abs(n_tokens)
            tokens = (llama_cpp.llama_token * n_tokens)()
            n_tokens = llama_cpp.llama_tokenize(
                self.vocab, text, len(text), tokens, n_tokens, add_bos, special
            )
            if n_tokens < 0:
                raise RuntimeError(
                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
                )
        return list(tokens[:n_tokens])

    def token_to_piece(self, token: int, special: bool = False) -> bytes:
        buf = ctypes.create_string_buffer(32)
        llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
        return bytes(buf)

    def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
        output = b""
        size = 32
        buffer = (ctypes.c_char * size)()
        for token in tokens:
            n = llama_cpp.llama_token_to_piece(
                self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
            )
            assert n <= size
            output += bytes(buffer[:n])
        # NOTE: Llama1 models automatically added a space at the start of the prompt
        # this line removes a leading space if the first token is a beginning of sentence token
        return (
            output[1:]
            if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
            else output
        )

    # Extra
    def metadata(self) -> Dict[str, str]:
        metadata: Dict[str, str] = {}
        buffer_size = 1024
        buffer = ctypes.create_string_buffer(buffer_size)
        # zero the buffer
        buffer.value = b"\0" * buffer_size
        # iterate over model keys
        for i in range(llama_cpp.llama_model_meta_count(self.model)):
            nbytes = llama_cpp.llama_model_meta_key_by_index(
                self.model, i, buffer, buffer_size
            )
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_key_by_index(
                    self.model, i, buffer, buffer_size
                )
            key = buffer.value.decode("utf-8")
            nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                self.model, i, buffer, buffer_size
            )
            if nbytes > buffer_size:
                buffer_size = nbytes + 1
                buffer = ctypes.create_string_buffer(buffer_size)
                nbytes = llama_cpp.llama_model_meta_val_str_by_index(
                    self.model, i, buffer, buffer_size
                )
            value = buffer.value.decode("utf-8")
            metadata[key] = value
        return metadata

    @staticmethod
    def default_params():
        """Get the default llama_model_params."""
        return llama_cpp.llama_model_default_params()
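
# A minimal usage sketch for LlamaModel (the GGUF path is illustrative):
#
#   params = LlamaModel.default_params()
#   model = LlamaModel(path_model="./model.gguf", params=params, verbose=False)
#   tokens = model.tokenize(b"Hello, world!", add_bos=True, special=False)
#   text = model.detokenize(tokens)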


class LlamaContext:
    """Intermediate Python wrapper for a llama.cpp llama_context.
    NOTE: For stability it's recommended you use the Llama class instead."""

    def __init__(
        self,
        *,
        model: LlamaModel,
        params: llama_cpp.llama_context_params,
        verbose: bool = True,
    ):
        self.model = model
        self.params = params
        self.verbose = verbose
        self._exit_stack = ExitStack()

        ctx = llama_cpp.llama_init_from_model(self.model.model, self.params)

        if ctx is None:
            raise ValueError("Failed to create llama_context")

        self.ctx = ctx
        self.memory = llama_cpp.llama_get_memory(self.ctx)
        self.sampler = None  # LlamaContext doesn't manage samplers directly, but some cleanup code expects this attribute

        def free_ctx():
            if self.ctx is None:
                return
            llama_cpp.llama_free(self.ctx)
            self.ctx = None

        self._exit_stack.callback(free_ctx)

    def close(self):
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def n_ctx(self) -> int:
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        return llama_cpp.llama_pooling_type(self.ctx)

    def kv_cache_clear(self):
        llama_cpp.llama_memory_clear(self.memory, True)

    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
        llama_cpp.llama_memory_seq_rm(self.memory, seq_id, p0, p1)

    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
        llama_cpp.llama_memory_seq_cp(self.memory, seq_id_src, seq_id_dst, p0, p1)

    def kv_cache_seq_keep(self, seq_id: int):
        llama_cpp.llama_memory_seq_keep(self.memory, seq_id)

    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
        llama_cpp.llama_memory_seq_add(self.memory, seq_id, p0, p1, shift)

    def get_state_size(self) -> int:
        return llama_cpp.llama_state_get_size(self.ctx)

    # TODO: copy_state_data

    # TODO: set_state_data

    # TODO: llama_load_session_file

    # TODO: llama_save_session_file

    def decode(self, batch: LlamaBatch):
        return_code = llama_cpp.llama_decode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_decode returned {return_code}")

    def encode(self, batch: LlamaBatch):
        return_code = llama_cpp.llama_encode(
            self.ctx,
            batch.batch,
        )
        if return_code != 0:
            raise RuntimeError(f"llama_encode returned {return_code}")

    def set_n_threads(self, n_threads: int, n_threads_batch: int):
        llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)

    def get_logits(self):
        return llama_cpp.llama_get_logits(self.ctx)

    def get_logits_ith(self, i: int):
        return llama_cpp.llama_get_logits_ith(self.ctx, i)

    def get_embeddings(self):
        return llama_cpp.llama_get_embeddings(self.ctx)

    def get_embeddings_ith(self, i: int):
        return llama_cpp.llama_get_embeddings_ith(self.ctx, i)

    def get_embeddings_seq(self, seq_id: int):
        return llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id)

    # Sampling functions - deprecated, use LlamaSampler instead

    def set_rng_seed(self, seed: int):
        raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")

    def sample_repetition_penalties(
        self,
        candidates: "_LlamaTokenDataArray",
        last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]",
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")

    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
        raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")

    def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
        raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")

    def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")

    def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
        raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")

    def sample_typical(
        self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
    ):
        raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")

    def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
        raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")

    def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
        raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")

    def sample_token_mirostat(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        m: int,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")

    def sample_token_mirostat_v2(
        self,
        candidates: "_LlamaTokenDataArray",
        tau: float,
        eta: float,
        mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
    ) -> int:
        raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")

    def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
        raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")

    def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
        raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")

    # Grammar
    def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
        raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")

    def reset_timings(self):
        llama_cpp.llama_perf_context_reset(self.ctx)

    def print_timings(self):
        llama_cpp.llama_perf_context_print(self.ctx)

    # Utility functions
    @staticmethod
    def default_params():
        """Get the default llama_context_params."""
        return llama_cpp.llama_context_default_params()
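
# A minimal usage sketch for LlamaContext (assumes `model` from the example
# above; the parameter values are illustrative):
#
#   ctx_params = LlamaContext.default_params()
#   ctx_params.n_ctx = 2048
#   ctx = LlamaContext(model=model, params=ctx_params, verbose=False)
#   ctx.set_n_threads(4, 4)
#   n_ctx = ctx.n_ctx()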


class LlamaBatch:
    def __init__(
        self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
    ):
        self._n_tokens = n_tokens
        self.embd = embd
        self.n_seq_max = n_seq_max
        self.verbose = verbose
        self._exit_stack = ExitStack()

        batch = llama_cpp.llama_batch_init(self._n_tokens, self.embd, self.n_seq_max)

        if batch is None:
            raise ValueError("Failed to create llama_batch")

        self.batch = batch
        self.sampler = None  # LlamaBatch doesn't use samplers, but some cleanup code expects this attribute

        def free_batch():
            if self.batch is None:
                return
            llama_cpp.llama_batch_free(self.batch)
            self.batch = None

        self._exit_stack.callback(free_batch)

    def close(self):
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def n_tokens(self) -> int:
        return self.batch.n_tokens

    def reset(self):
        self.batch.n_tokens = 0

    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
        n_tokens = len(batch)
        self.batch.n_tokens = n_tokens
        for i in range(n_tokens):
            self.batch.token[i] = batch[i]
            self.batch.pos[i] = n_past + i
            self.batch.seq_id[i][0] = 0
            self.batch.n_seq_id[i] = 1
            self.batch.logits[i] = logits_all
        self.batch.logits[n_tokens - 1] = True

    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
        n_tokens = len(batch)
        n_tokens0 = self.batch.n_tokens
        self.batch.n_tokens += n_tokens
        for i in range(n_tokens):
            j = n_tokens0 + i
            self.batch.token[j] = batch[i]
            self.batch.pos[j] = i
            self.batch.seq_id[j][0] = seq_id
            self.batch.n_seq_id[j] = 1
            self.batch.logits[j] = logits_all
        # Request logits for the last token of the newly appended sequence
        # (offset by n_tokens0 so it points into this sequence, not the start of the batch)
        self.batch.logits[n_tokens0 + n_tokens - 1] = True
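
# A minimal usage sketch for LlamaBatch (assumes `model`, `ctx`, and `tokens`
# from the examples above):
#
#   batch = LlamaBatch(n_tokens=ctx.n_ctx(), embd=0, n_seq_max=1)
#   batch.set_batch(tokens, n_past=0, logits_all=False)
#   ctx.decode(batch)  # logits for the final token are then available from ctx.get_logits_ith(-1)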


class LlamaTokenDataArray:
    def __init__(self, *, n_vocab: int):
        self.n_vocab = n_vocab
        self.candidates_data = np.recarray(
            (self.n_vocab,),
            dtype=np.dtype(
                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
            ),
        )
        self.candidates = llama_cpp.llama_token_data_array(
            data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
            size=self.n_vocab,
            sorted=False,
        )
        self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
        self.default_candidates_data_p = np.zeros(self.n_vocab, dtype=np.single)
        self.sampler = None  # LlamaTokenDataArray doesn't use samplers, but some cleanup code expects this attribute

    def copy_logits(self, logits: npt.NDArray[np.single]):
        self.candidates_data.id[:] = self.default_candidates_data_id
        self.candidates_data.logit[:] = logits
        self.candidates_data.p[:] = self.default_candidates_data_p
        self.candidates.sorted = False
        self.candidates.size = self.n_vocab
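
# A minimal sketch of filling the array from the last decoded position (assumes
# `model` and `ctx` from the examples above and at least one decoded batch):
#
#   arr = LlamaTokenDataArray(n_vocab=model.n_vocab())
#   logits_ptr = ctx.get_logits_ith(-1)  # ctypes pointer to n_vocab floats
#   arr.copy_logits(np.array(logits_ptr[: model.n_vocab()], dtype=np.single))
#   # arr.candidates now mirrors that logits row as a llama_token_data_array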


# Embedding functions


def normalize_embedding(embedding):
    norm = float(np.linalg.norm(embedding))
    if norm == 0.0:
        return embedding
    return [v / norm for v in embedding]
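
# For example, normalize_embedding([3.0, 4.0]) returns [0.6, 0.8]; an all-zero
# vector is returned unchanged to avoid dividing by zero.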


# Python wrappers over common/sampling structs


@dataclass
class LlamaSamplingParams:
    n_prev: int = 64
    n_probs: int = 0
    top_k: int = 40
    top_p: float = 0.95
    min_p: float = 0.05
    tfs_z: float = 1.00
    typical_p: float = 1.00
    temp: float = 0.80
    penalty_last_n: int = 64
    penalty_repeat: float = 1.0
    penalty_freq: float = 0.00
    penalty_present: float = 0.00
    mirostat: int = 0
    mirostat_tau: float = 5.00
    mirostat_eta: float = 0.10
    penalize_nl: bool = True

    grammar: str = ""

    cfg_negative_prompt: str = ""
    cfg_scale: float = 1.00

    logit_bias: dict[int, float] = field(default_factory=dict)


@dataclass
class LlamaSamplingContext:
    params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams)
    mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float)
    grammar: Optional[LlamaGrammar] = None
    # NOTE: Missing parsed_grammar
    prev: list[int] = field(default_factory=list)
    cur: list[llama_cpp.llama_token_data] = field(default_factory=list)

    def reset(self):
        self.prev = []
        self.cur = []
        if self.grammar is not None:
            self.grammar.reset()

    def cp(self):
        return LlamaSamplingContext(
            params=self.params,
            mirostat_mu=self.mirostat_mu,
            grammar=self.grammar,
            prev=self.prev.copy(),
            cur=self.cur.copy(),
        )

    def last(self) -> Optional[int]:
        if len(self.prev) > 0:
            return self.prev[-1]
        else:
            return None

    def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
        return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8")

    def sample(
        self,
        ctx_main: LlamaContext,
        idx: int = 0,
        logits_array: Optional[npt.NDArray[np.single]] = None,
    ):
        # This method is deprecated in favor of using LlamaSampler directly
        raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")

    def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
        self.prev.append(id)


class CustomSampler:
    def __init__(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        self.apply_func = apply_func

        def apply_wrapper(
            sampler: llama_cpp.llama_sampler_p,
            cur_p: llama_cpp.llama_token_data_array_p,
        ):
            self.apply_func(cur_p)

        def free_wrapper(sampler: llama_cpp.llama_sampler_p):
            pass

        sampler_i = llama_cpp.llama_sampler_i()
        sampler_i.apply = llama_cpp.llama_sampler_i_apply(apply_wrapper)
        self._apply_wrapper_ref = apply_wrapper  # keep the ctypes callback alive for the sampler's lifetime

        sampler_i.name = llama_cpp.llama_sampler_i_name(0)
        sampler_i.accept = llama_cpp.llama_sampler_i_accept(0)
        sampler_i.reset = llama_cpp.llama_sampler_i_reset(0)
        sampler_i.clone = llama_cpp.llama_sampler_i_clone(0)
        sampler_i.free = llama_cpp.llama_sampler_i_free(0)

        self.sampler = llama_cpp.llama_sampler()
        self.sampler.iface = ctypes.pointer(sampler_i)
        self.sampler.ctx = None

    def get_sampler(self) -> llama_cpp.llama_sampler_p:
        return ctypes.pointer(self.sampler)
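
# A minimal sketch of a custom sampling step (the callback receives a ctypes
# pointer to a llama_token_data_array; the penalty value is illustrative):
#
#   def penalize_token_zero(cur_p):
#       cur_p.contents.data[0].logit -= 10.0
#
#   # Registered through LlamaSampler.add_custom(penalize_token_zero) below.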


class LlamaSampler:
    def __init__(self):
        params = llama_cpp.llama_sampler_chain_default_params()
        self.sampler = llama_cpp.llama_sampler_chain_init(params)
        self.custom_samplers: List[Tuple[int, CustomSampler]] = []
        self._exit_stack = ExitStack()

        def free_sampler():
            if self.sampler is not None:
                # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
                for i, _ in reversed(self.custom_samplers):
                    llama_cpp.llama_sampler_chain_remove(self.sampler, i)
                llama_cpp.llama_sampler_free(self.sampler)
                self.sampler = None

        self._exit_stack.callback(free_sampler)

    def close(self):
        self._exit_stack.close()

    def __del__(self):
        self.close()

    def add_greedy(self):
        sampler = llama_cpp.llama_sampler_init_greedy()
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_dist(self, seed: int):
        sampler = llama_cpp.llama_sampler_init_dist(seed)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_softmax(self):
        sampler = llama_cpp.llama_sampler_init_softmax()
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_top_k(self, k: int):
        sampler = llama_cpp.llama_sampler_init_top_k(k)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_top_p(self, p: float, min_keep: int = 1):
        sampler = llama_cpp.llama_sampler_init_top_p(p, min_keep)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_min_p(self, p: float, min_keep: int = 1):
        sampler = llama_cpp.llama_sampler_init_min_p(p, min_keep)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_typical(self, p: float, min_keep: int = 1):
        sampler = llama_cpp.llama_sampler_init_typical(p, min_keep)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_temp(self, temp: float):
        sampler = llama_cpp.llama_sampler_init_temp(temp)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_temp_ext(self, t: float, delta: float, exponent: float):
        sampler = llama_cpp.llama_sampler_init_temp_ext(t, delta, exponent)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_xtc(self, p: float, t: float, min_keep: int, seed: int):
        sampler = llama_cpp.llama_sampler_init_xtc(p, t, min_keep, seed)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_top_n_sigma(self, n: float):
        sampler = llama_cpp.llama_sampler_init_top_n_sigma(n)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int):
        sampler = llama_cpp.llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_mirostat_v2(self, seed: int, tau: float, eta: float):
        sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
        sampler = llama_cpp.llama_sampler_init_grammar(
            model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
        )
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_grammar_lazy_patterns(
        self,
        model: LlamaModel,
        grammar: LlamaGrammar,
        trigger_patterns: List[str],
        trigger_tokens: List[int]
    ):
        # Convert patterns to C array
        pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
        for i, pattern in enumerate(trigger_patterns):
            pattern_ptrs[i] = pattern.encode("utf-8")

        # Convert tokens to C array
        token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)

        sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
            model.vocab,
            grammar._grammar.encode("utf-8"),
            grammar._root.encode("utf-8"),
            pattern_ptrs,
            len(trigger_patterns),
            token_array,
            len(trigger_tokens)
        )
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_penalties(
        self,
        penalty_last_n: int,
        penalty_repeat: float,
        penalty_freq: float,
        penalty_present: float,
    ):
        sampler = llama_cpp.llama_sampler_init_penalties(
            penalty_last_n,
            penalty_repeat,
            penalty_freq,
            penalty_present,
        )
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_dry(
        self,
        model: LlamaModel,
        n_ctx_train: int,
        dry_multiplier: float,
        dry_base: float,
        dry_allowed_length: int,
        dry_penalty_last_n: int,
        seq_breakers: List[str]
    ):
        # Convert seq_breakers to C array
        breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
        for i, breaker in enumerate(seq_breakers):
            breaker_ptrs[i] = breaker.encode("utf-8")

        sampler = llama_cpp.llama_sampler_init_dry(
            model.vocab,
            n_ctx_train,
            dry_multiplier,
            dry_base,
            dry_allowed_length,
            dry_penalty_last_n,
            breaker_ptrs,
            len(seq_breakers)
        )
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_logit_bias(
        self,
        n_vocab: int,
        logit_bias: Dict[int, float]
    ):
        # Convert logit_bias dict to C array
        bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
        for i, (token, bias) in enumerate(logit_bias.items()):
            bias_array[i].token = token
            bias_array[i].bias = bias

        sampler = llama_cpp.llama_sampler_init_logit_bias(
            n_vocab,
            len(logit_bias),
            bias_array
        )
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_infill(self, model: LlamaModel):
        sampler = llama_cpp.llama_sampler_init_infill(model.vocab)
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)

    def add_custom(
        self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
    ):
        custom_sampler = CustomSampler(apply_func)
        sampler = custom_sampler.get_sampler()
        llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
        # NOTE: Must remove custom samplers before free or llama.cpp will try to free them
        self.custom_samplers.append(
            (llama_cpp.llama_sampler_chain_n(self.sampler) - 1, custom_sampler)
        )

    def get_seed(self) -> int:
        return llama_cpp.llama_sampler_get_seed(self.sampler)

    def sample(self, ctx: LlamaContext, idx: int = -1) -> int:
        return llama_cpp.llama_sampler_sample(self.sampler, ctx.ctx, idx)

    def accept(self, token: int):
        llama_cpp.llama_sampler_accept(self.sampler, token)

    def reset(self):
        llama_cpp.llama_sampler_reset(self.sampler)

    def clone(self):
        # NOTE: Custom samplers cannot be cloned due to Python callback limitations
        if self.custom_samplers:
            raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")

        cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
        # Create a new wrapper around the cloned sampler
        new_sampler = LlamaSampler.__new__(LlamaSampler)
        new_sampler.sampler = cloned_sampler
        new_sampler.custom_samplers = []
        new_sampler._exit_stack = ExitStack()

        def free_sampler():
            if new_sampler.sampler is not None:
                llama_cpp.llama_sampler_free(new_sampler.sampler)
                new_sampler.sampler = None

        new_sampler._exit_stack.callback(free_sampler)
        return new_sampler
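
# A minimal sketch of building a sampler chain and sampling one token (assumes
# `ctx` from the examples above has decoded at least one batch; the parameter
# values are illustrative):
#
#   sampler = LlamaSampler()
#   sampler.add_penalties(penalty_last_n=64, penalty_repeat=1.1, penalty_freq=0.0, penalty_present=0.0)
#   sampler.add_top_k(40)
#   sampler.add_top_p(0.95, min_keep=1)
#   sampler.add_temp(0.8)
#   sampler.add_dist(seed=1234)
#   token = sampler.sample(ctx, idx=-1)
#   sampler.accept(token)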