from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline

# Load the weights in half precision to roughly halve GPU memory usage.
dtype = torch.float16


class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and model from the repository path;
        # device_map="auto" places the weights on the available device(s).
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=dtype
        )

        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Endpoint payloads arrive as {"inputs": "<prompt>"}; fall back to the
        # raw payload if the "inputs" key is absent.
        inputs = data.pop("inputs", data)
        generation_config = GenerationConfig(
            max_new_tokens=250, min_new_tokens=60,
            do_sample=True, top_k=50, temperature=0.8,
            repetition_penalty=1.2, num_return_sequences=1,
            pad_token_id=2,  # model-specific pad token id, hard-coded here
            return_full_text=False,  # pipeline option: return only the new tokens
        )

        # to_dict() expands the config into keyword arguments for the pipeline,
        # which returns a list of {"generated_text": ...} dicts.
        output = self.pipeline(inputs, **generation_config.to_dict())

        return output
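

# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract): instantiate the handler against a checkpoint and call it with the
# same payload shape the endpoint would send. "path/to/model" is a placeholder.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/model")
    result = handler({"inputs": "Once upon a time"})
    print(result)  # e.g. [{"generated_text": "..."}]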