from typing import Callable, Generator, Iterator, List, Optional, Union
import ctypes
from ctypes import (
    c_bool,
    c_char_p,
    c_int,
    c_int8,
    c_int32,
    c_uint8,
    c_uint32,
    c_size_t,
    c_float,
    c_double,
    c_void_p,
    POINTER,
    _Pointer,  # type: ignore
    Structure,
    Array,
)
import pathlib
import os
import sys


# Load the library
def _load_shared_library(lib_base_name: str):
    # Construct the paths to the possible shared library names
    _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
    # Search next to this file for "lib{lib_base_name}.so"
    # (i.e. "libllama2.so", the default shared-library name built for llama2.cu)
    _lib_paths: List[pathlib.Path] = []
    # Determine the file extension based on the platform
    if sys.platform.startswith("linux"):
        _lib_paths += [
            _base_path / f"lib{lib_base_name}.so",
        ]
    else:
        raise RuntimeError("Unsupported platform")

    # Allow overriding the library location via the LLAMA2_CU_LIB environment variable
    if "LLAMA2_CU_LIB" in os.environ:
        lib_base_name = os.environ["LLAMA2_CU_LIB"]
        _lib = pathlib.Path(lib_base_name)
        _base_path = _lib.parent.resolve()
        _lib_paths = [_lib.resolve()]

    cdll_args = dict()  # type: ignore
    # Add the library directory to the DLL search path on Windows (if needed)

    # Try to load the shared library, handling potential errors
    for _lib_path in _lib_paths:
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path), **cdll_args)
            except Exception as e:
                raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )


# Specify the base name of the shared library to load
_lib_base_name = "llama2"

# Load the library
_lib = _load_shared_library(_lib_base_name)


# Create a native context from a model checkpoint and tokenizer file
def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
    return _lib.llama2_init(model_path.encode("utf-8"), tokenizer_path.encode("utf-8"))


_lib.llama2_init.argtypes = [c_char_p, c_char_p]
_lib.llama2_init.restype = c_void_p


# Release a context previously returned by llama2_init
def llama2_free(ctx: c_void_p) -> None:
    _lib.llama2_free(ctx)


_lib.llama2_free.argtypes = [c_void_p]
_lib.llama2_free.restype = None


# Start generation for a prompt; returns 0 on success, non-zero on failure
def llama2_generate(
    ctx: c_void_p,
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    seed: int,
) -> int:
    return _lib.llama2_generate(
        ctx, prompt.encode("utf-8"), max_tokens, temperature, top_p, seed
    )


_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]
_lib.llama2_generate.restype = c_int


# Fetch the most recently generated piece as raw bytes, or None when generation is done
def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    return _lib.llama2_get_last(ctx)  # bytes or None


_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p


# Tokenize text into ids; the output buffer is sized len(text) + 3 to leave room
# for the optional BOS/EOS tokens
def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    tokens = (c_int * (len(text) + 3))()
    n_tokens = (c_int * 1)()
    _lib.llama2_tokenize(ctx, text.encode("utf-8"), add_bos, add_eos, tokens, n_tokens)
    return tokens[: n_tokens[0]]


_lib.llama2_tokenize.argtypes = [
    c_void_p,
    c_char_p,
    c_int8,
    c_int8,
    POINTER(c_int),
    POINTER(c_int),
]
_lib.llama2_tokenize.restype = None


class Llama2:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: str = "tokenizer.bin",
        n_ctx: int = 0,
        n_batch: int = 0,
    ) -> None:
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.llama2_ctx = llama2_init(model_path, tokenizer_path)

    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)

    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        if seed is None:
            seed = 42
        ret = llama2_generate(
            self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed
        )
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
        # Store generated bytes until they decode cleanly (a piece may end in the
        # middle of a multi-byte UTF-8 character)
        bytes_buffer = b""
        while True:
            result = llama2_get_last(self.llama2_ctx)
            if result is None:
                break
            bytes_buffer += result
            try:
                string = bytes_buffer.decode("utf-8")
            except UnicodeDecodeError:
                pass  # incomplete character; keep accumulating
            else:
                bytes_buffer = b""
                yield string
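

# Minimal usage sketch (illustrative only): "model.bin" and "tokenizer.bin" below are
# placeholder paths for a llama2.cu-format checkpoint and tokenizer, and the compiled
# libllama2.so is assumed to sit next to this module.
if __name__ == "__main__":
    llm = Llama2("model.bin", tokenizer_path="tokenizer.bin")
    print(llm.tokenize("Hello world"))  # token ids (add_bos=True by default)
    # __call__ returns a generator that yields decoded UTF-8 pieces as the
    # native side produces them.
    for piece in llm("Once upon a time", max_tokens=64, temperature=0.8, seed=1):
        print(piece, end="", flush=True)
    print()
    llama2_free(llm.llama2_ctx)  # the class does not free the native context itself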