from typing import Dict, List, Any
import transformers
import torch
from datetime import datetime
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


class EndpointHandler:

    def __init__(self, path=""):
        logger.info(f"Hugging Face handler path: {path}")
        # The repository path passed in by the endpoint is ignored; the model is
        # always loaded from the pinned checkpoint below.
        path = 'mosaicml/mpt-7b-instruct'  # alternative: 'mosaicml/mpt-7b'
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,     # MPT ships its own modeling code
            torch_dtype=torch.bfloat16,
            max_seq_len=32000           # extend the context window via the remote MPT config
        )

        # MPT-7B uses the GPT-NeoX-20B tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
        logger.info(f"Tokenizer created at {datetime.now()}")

        # Stop generation as soon as the model emits the end-of-text token.
        stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

        class StopOnTokens(StoppingCriteria):
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                for stop_id in stop_token_ids:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        stopping_criteria = StoppingCriteriaList([StopOnTokens()])
        
        self.generate_text = transformers.pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            stopping_criteria=stopping_criteria,
            task='text-generation',
            return_full_text=True,
            do_sample=True,           # without this, the sampling parameters below are ignored
            temperature=0.1,
            top_p=0.15,
            top_k=0,
            max_new_tokens=2048,
            repetition_penalty=1.1
        )


    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        logger.info(f"Received request data: {data}")
        # Inference Endpoints send payloads of the form {"inputs": ...}; fall back
        # to the raw payload if the key is missing.
        inputs = data.pop("inputs", data)
        logger.info(f"Generating for inputs: {inputs}")
        return self.generate_text(inputs)
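

# Local smoke test: a minimal sketch, not part of the Inference Endpoints contract.
# Running this module directly instantiates the handler and sends one payload in
# the same {"inputs": ...} shape the endpoint receives; the prompt text is only
# an illustrative assumption.
if __name__ == "__main__":
    handler = EndpointHandler()
    example = {"inputs": "Explain what MPT-7B-Instruct is in one sentence."}
    print(handler(example))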