from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

from airllm import AirLLMLlama2


class AirLLM(LLM):
    max_len: int
    model: AirLLMLlama2

    class Config:
        # AirLLMLlama2 is not a pydantic type, so allow it as a field on the LLM model.
        arbitrary_types_allowed = True

    def __init__(self, llama2_model_id: str, max_len: int, compression: str = ""):
        # could use a Hugging Face model repo id:
        model = AirLLMLlama2(llama2_model_id)  # , compression=compression
        # LLM is a pydantic model, so fields must be passed through super().__init__
        # rather than assigned directly.
        super().__init__(model=model, max_len=max_len)

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=self.max_len,
            padding=True,
        )
        generation_output = self.model.generate(
            input_tokens["input_ids"].cuda(),
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )
        output = self.model.tokenizer.decode(generation_output.sequences[0])
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"max_len": self.max_len}
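

# A minimal usage sketch of the wrapper above. The repo id and prompt are only
# illustrative assumptions: any local path or Hugging Face Llama 2 checkpoint
# accepted by AirLLMLlama2 works, and a CUDA device is required because _call
# moves the input ids to the GPU.
llm = AirLLM("meta-llama/Llama-2-7b-hf", max_len=512)

# The LangChain LLM interface lets us call the wrapper directly on a prompt
# string; output length is capped at 20 new tokens by _call above.
print(llm("What is the capital of France?"))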