from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from airllm import AirLLMLlama2


class AirLLM(LLM):
    """LangChain wrapper around an AirLLM Llama 2 model."""

    max_len: int
    model: AirLLMLlama2

    class Config:
        # Allow the non-pydantic AirLLMLlama2 type as a field.
        arbitrary_types_allowed = True

    def __init__(self, llama2_model_id: str, max_len: int, compression: str = ""):
        # llama2_model_id can be a Hugging Face model repo id.
        # LLM is a pydantic model, so fields must be set via super().__init__,
        # not by direct attribute assignment.
        model = AirLLMLlama2(llama2_model_id)  # , compression=compression
        super().__init__(model=model, max_len=max_len)

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Tokenize the prompt, truncating to the configured maximum length.
        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=self.max_len,
            padding=True,
        )
        # Generate on the GPU; AirLLM loads the model layer by layer to fit in memory.
        generation_output = self.model.generate(
            input_tokens["input_ids"].cuda(),
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )
        return self.model.tokenizer.decode(generation_output.sequences[0])

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"max_len": self.max_len}