from typing import Dict, List, Any
import transformers
import torch
from datetime import datetime
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger("transformers")


class EndpointHandler:
    def __init__(self, path=""):
        logger.info(f"Hugging Face handler path: {path}")
        # Ignore the repository path passed in by the endpoint and always load
        # the instruct-tuned variant of MPT-7B.
        path = 'mosaicml/mpt-7b-instruct'
        # path = 'mosaicml/mpt-7b'
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,   # MPT ships its own modeling code
            torch_dtype=torch.bfloat16,
            max_seq_len=32000         # extend the context window (MPT uses ALiBi)
        )
        # MPT reuses the GPT-NeoX tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
        print("tokenizer created", datetime.now())

        # Stop generating as soon as the model emits the end-of-text token.
        stop_token_ids = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

        class StopOnTokens(StoppingCriteria):
            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                for stop_id in stop_token_ids:
                    if input_ids[0][-1] == stop_id:
                        return True
                return False

        stopping_criteria = StoppingCriteriaList([StopOnTokens()])
        self.generate_text = transformers.pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            stopping_criteria=stopping_criteria,
            task='text-generation',
            return_full_text=True,
            do_sample=True,           # required so temperature/top_p below take effect
            temperature=0.1,
            top_p=0.15,
            top_k=0,                  # 0 disables top-k filtering
            max_new_tokens=2048,
            repetition_penalty=1.1
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        logger.info(f"Received request data: {data}")
        # The endpoint payload is expected to be {"inputs": "..."}; fall back to
        # the raw payload if the "inputs" key is missing.
        inputs = data.pop("inputs", data)
        logger.info(f"Running generation on: {inputs}")
        res = self.generate_text(inputs)
        return res
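

# A minimal local smoke test, not part of the endpoint contract: this sketch
# assumes enough memory to load the bfloat16 weights and that the payload follows
# the Inference Toolkit convention of {"inputs": "..."}. The prompt is illustrative.
if __name__ == "__main__":
    handler = EndpointHandler()
    output = handler({"inputs": "Explain what MPT-7B is in one sentence."})
    print(output)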