from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, set_peft_model_state_dict, LoraConfig
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import torch
import os

# Read the Hub token from the environment so the gated Llama 2 repo resolves.
token = os.getenv("HUGGINGFACE_HUB_TOKEN")

class EndpointHandler:
    def __init__(self, path=""):
        # Load the base Llama 2 7B tokenizer and model in fp16, letting
        # accelerate place layers across the available devices.
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", token=token)
        base_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",
            torch_dtype=torch.float16,
            device_map="auto",
            token=token
        )

        # LoRA configuration; these values must match the settings the
        # adapter was trained with, or the loaded weights won't line up.
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["q_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Wrap the base model with LoRA layers, then load the trained
        # adapter weights downloaded from the Hub.
        self.model = get_peft_model(base_model, lora_config)
        adapter_path = hf_hub_download(
            repo_id="vignesh0007/Anime-Gen-Llama-2-7B",
            filename="adapter_model.safetensors",
            repo_type="model",
            token=token
        )
        lora_state = load_file(adapter_path)
        # set_peft_model_state_dict remaps the saved adapter keys onto the
        # PEFT-wrapped model; a plain load_state_dict(strict=False) would
        # silently skip every weight because of the adapter-name prefix
        # mismatch in the keys.
        set_peft_model_state_dict(self.model, lora_state)
        self.model.eval()

    def __call__(self, data):
        # Inference Endpoints pass the parsed request body as a dict;
        # the prompt arrives under the "inputs" key.
        inputs = data.get("inputs", "")
        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **tokens,
                max_new_tokens=256,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id  # Llama defines no pad token
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
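

# A minimal local smoke test (a sketch, not part of the deployed handler):
# Inference Endpoints construct EndpointHandler once and then invoke it per
# request with the parsed JSON body, which the call below mirrors. The prompt
# is only an illustrative placeholder.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Write a short anime character description:"})
    print(result)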