mpt-7b / handler.py
Itamarl's picture
Update handler.py
2c0e0ca
raw
history blame
2.18 kB
from typing import Dict, List, Any
import transformers
import torch
from datetime import datetime
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers.utils import logging
# Raise the transformers library's log verbosity to INFO (its default is
# WARNING) so INFO-level messages emitted through `logger` are not filtered out.
logging.set_verbosity_info()
# Log through the transformers root logger, so this handler's messages follow
# the same verbosity/handler configuration as the library itself.
logger = logging.get_logger("transformers")
class EndpointHandler:
    """Hugging Face custom inference handler serving an MPT causal LM.

    ``__init__`` loads the model, tokenizer and a text-generation pipeline
    once; ``__call__`` runs a single request through that pipeline.
    """

    def __init__(self, path: str = "", model_id: str = "mosaicml/mpt-7b-instruct"):
        """Load the model, tokenizer and text-generation pipeline.

        Args:
            path: Repository path handed in by the serving runtime. It is only
                logged; the weights are loaded from ``model_id`` instead (the
                handler has always ignored ``path`` for loading).
            model_id: Hub id or local path of the causal-LM weights. Defaults
                to ``mosaicml/mpt-7b-instruct``, the model this handler loaded
                before the parameter existed.
        """
        logger.info("Initializing EndpointHandler")
        logger.info("Hugging face handler path %s", path)

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_id,
            # Allow from_pretrained to execute model code stored in the model repo.
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # Extra kwargs matching config attributes override the loaded config;
            # presumably this raises MPT's context length — TODO confirm the
            # model behaves well at this length.
            max_seq_len=32000,
        )
        # The tokenizer is loaded from a separate repo, not from model_id.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        logger.info("tokenizer created %s", datetime.now())

        stop_token_ids = frozenset(self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"]))

        class StopOnTokens(StoppingCriteria):
            """Stop once the newest token of the first sequence is a stop token."""

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                # Only row 0 is inspected, i.e. this assumes one sequence per
                # generate() call.
                return int(input_ids[0, -1]) in stop_token_ids

        self.generate_text = transformers.pipeline(
            task="text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
            # Prepend the prompt to the completion in the returned text.
            return_full_text=True,
            # NOTE(review): do_sample is not passed, so unless the model's
            # generation config enables sampling, decoding is greedy and
            # temperature/top_p/top_k have no effect. Confirm whether
            # do_sample=True was intended before relying on these values.
            temperature=0.1,
            top_p=0.15,
            top_k=0,
            max_new_tokens=2048,
            repetition_penalty=1.1,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one inference request through the text-generation pipeline.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt; if that key
                is absent, the (remaining) dict itself is used as the input.
                An optional ``data["parameters"]`` dict is forwarded to the
                pipeline call as keyword arguments (e.g. ``max_new_tokens``).
                Both keys are popped, so ``data`` is mutated.

        Returns:
            The pipeline's output for ``inputs``.

        Raises:
            TypeError: If ``parameters`` is present but is not a dict.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # Payloads may contain user text; keep them out of INFO-level logs.
        logger.debug("inputs: %s parameters: %s", inputs, parameters)

        if parameters is None:
            return self.generate_text(inputs)
        if not isinstance(parameters, dict):
            raise TypeError(f"'parameters' must be a dict, got {type(parameters).__name__}")
        return self.generate_text(inputs, **parameters)