Spaces:
Runtime error
Runtime error
File size: 1,544 Bytes
5ec7b76 48a378e 5ec7b76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from airllm import AirLLMLlama2
class AirLLM(LLM):
max_len: int
model: AirLLMLlama2
def __init__(self, llama2_model_id : str, max_len : int, compression = ""):
# could use hugging face model repo id:
self.model = AirLLMLlama2(llama2_model_id)#,compression=compression
self.max_len = max_len
@property
def _llm_type(self) -> str:
return "custom"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
input_tokens = model.tokenizer(input_text,
return_tensors="pt",
return_attention_mask=False,
truncation=True,
max_length=self.max_len,
padding=True)
generation_output = model.generate(
input_tokens['input_ids'].cuda(),
max_new_tokens=20,
use_cache=True,
return_dict_in_generate=True)
output = model.tokenizer.decode(generation_output.sequences[0])
return output
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"max_len": self.max_len}
|