File size: 1,544 Bytes
5ec7b76
 
 
 
 
 
 
 
 
 
 
 
 
48a378e
5ec7b76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

from airllm import AirLLMLlama2

class AirLLM(LLM):
    """Custom LangChain LLM backed by an AirLLMLlama2 model.

    Tokenizes the prompt (truncated to ``max_len`` tokens), generates up to
    20 new tokens on the GPU, and returns the decoded output string.
    """

    # Maximum token count used when tokenizing/truncating the prompt.
    max_len: int
    # Underlying AirLLM model instance.
    model: AirLLMLlama2

    def __init__(self, llama2_model_id: str, max_len: int, compression: str = ""):
        # `llama2_model_id` may be a Hugging Face model repo id.
        # NOTE(review): `compression` is accepted but not forwarded to
        # AirLLMLlama2, matching the original behavior; re-enable the kwarg
        # below once the installed airllm version supports it.
        # NOTE(review): the pydantic-based LLM base class normally requires
        # super().__init__(...) before attribute assignment — confirm this
        # constructor works against the langchain version in use.
        self.model = AirLLMLlama2(llama2_model_id)  # , compression=compression
        self.max_len = max_len

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: The input text to complete.
            stop: Not supported; passing a value raises.
            run_manager: Unused callback manager (LangChain interface).

        Returns:
            The decoded generation output (includes the prompt tokens).

        Raises:
            ValueError: If ``stop`` sequences are supplied.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        # BUG FIX: original referenced the undefined names `model` and
        # `input_text`; the correct names are `self.model` and `prompt`.
        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=self.max_len,
            padding=True,
        )

        # NOTE(review): assumes a CUDA device is available — confirm before
        # running on CPU-only hosts.
        generation_output = self.model.generate(
            input_tokens["input_ids"].cuda(),
            max_new_tokens=20,
            use_cache=True,
            return_dict_in_generate=True,
        )

        return self.model.tokenizer.decode(generation_output.sequences[0])

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"max_len": self.max_len}