MPT7BTest / app.py
danavirtual's picture
added InstructionTextGenerationPipeline
3abff73
raw
history blame
6.61 kB
import gradio as gr
import requests
import torch
import transformers
import einops
###
from typing import Any, Dict, Tuple
import warnings
import datetime
import os
from threading import Event, Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import config
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
# Instruction-following prompt template: the intro blurb, the instruction
# header, a literal ``{instruction}`` placeholder (filled later via
# ``.format(instruction=...)``), and the response header — one per line,
# ending with a trailing newline so generation starts on a fresh line.
PROMPT_FOR_GENERATION_FORMAT = "\n".join(
    (INTRO_BLURB, INSTRUCTION_KEY, "{instruction}", RESPONSE_KEY, "")
)
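# NOTE: Commented-out reference copy of InstructionTextGenerationPipeline.
# The class actually instantiated further down presumably comes from
# `from InstructionTextGenerationPipeline import *`; keep the two in sync.
#class InstructionTextGenerationPipeline: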
# def __init__(
# self,
# model_name,
# torch_dtype=torch.bfloat16,
# trust_remote_code=True,
# use_auth_token=None,
# ) -> None:
# self.model = AutoModelForCausalLM.from_pretrained(
# model_name,
# torch_dtype=torch_dtype,
# trust_remote_code=trust_remote_code,
# use_auth_token=use_auth_token,
# )
#
# tokenizer = AutoTokenizer.from_pretrained(
# model_name,
# trust_remote_code=trust_remote_code,
# use_auth_token=use_auth_token,
# )
# if tokenizer.pad_token_id is None:
# warnings.warn(
# "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
# )
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"
# self.tokenizer = tokenizer
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.model.eval()
# self.model.to(device=device, dtype=torch_dtype)
#
# self.generate_kwargs = {
# "temperature": 0.5,
# "top_p": 0.92,
# "top_k": 0,
# "max_new_tokens": 512,
# "use_cache": True,
# "do_sample": True,
# "eos_token_id": self.tokenizer.eos_token_id,
# "pad_token_id": self.tokenizer.pad_token_id,
# "repetition_penalty": 1.1, # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
# }
#
# def format_instruction(self, instruction):
# return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
#
# def __call__(
# self, instruction: str, **generate_kwargs: Dict[str, Any]
# ) -> Tuple[str, str, float]:
# s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
# input_ids = self.tokenizer(s, return_tensors="pt").input_ids
# input_ids = input_ids.to(self.model.device)
# gkw = {**self.generate_kwargs, **generate_kwargs}
# with torch.no_grad():
# output_ids = self.model.generate(input_ids, **gkw)
# # Slice the output_ids tensor to get only new tokens
# new_tokens = output_ids[0, len(input_ids[0]) :]
# output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
# return output_text
##
from InstructionTextGenerationPipeline import *
from timeit import default_timer as timer
import time
import datetime
from datetime import datetime
import json
# create some interactive controls
import sys
import os
import os.path as osp
import pprint
# Module-level pretty-printer with a 4-space indent.
pp = pprint.PrettyPrinter(indent=4)
# Make the notebook helper library importable. The directory is appended,
# not prepended, so earlier sys.path entries (e.g. installed packages with
# the same module names) still take precedence.
LIBRARY_PATH = "/home/ec2-user/workspace/Notebooks/lib"
# abspath alone normalizes the path; os.path.join with a single argument
# was a no-op.
module_path = os.path.abspath(LIBRARY_PATH)
if module_path not in sys.path:
    sys.path.append(module_path)
print(f"sys.path : {sys.path}")
def complete(state="complete"):
    """Print a timestamped progress marker in notebook-cell style.

    Args:
        state: Label for the stage that just finished; printed verbatim.
    """
    # This module rebinds ``datetime`` to the class (``from datetime import
    # datetime``), so ``datetime.datetime.fromtimestamp`` raised
    # AttributeError. ``time.strftime`` formats the current local time with
    # the same pattern and does not depend on that ambiguous name.
    print(f"\nCell {state} @ {time.strftime('%Y-%m-%d %H:%M:%S')}")
# Progress markers around the model load.
complete(state='imports done')
complete(state="start generate")
# Build the generation wrapper for MPT-7B-Instruct in bfloat16. The class is
# presumably the one brought in by the star import of the
# InstructionTextGenerationPipeline module above; going by the commented-out
# copy in this file, it loads model + tokenizer and picks CUDA when available
# — confirm against that module.
# NOTE(review): trust_remote_code=True executes model code fetched from the
# Hub repository; pin a revision if that repo is not fully trusted.
generate = InstructionTextGenerationPipeline(
"mosaicml/mpt-7b-instruct",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
# Token ids that StopOnTokens (below) treats as end-of-generation.
stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
complete(state="Model generated")
# Define a custom stopping criteria
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on a configured stop token.

    Only the newest token of the first sequence in the batch is inspected,
    and it is compared against the module-level ``stop_token_ids`` list.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when sequence 0's latest token is one of ``stop_token_ids``."""
        newest = input_ids[0][-1]
        return any(newest == candidate for candidate in stop_token_ids)
def process_stream(instruction, temperature, top_p, top_k, max_new_tokens):
    """Generate a completion for *instruction* and return it as one string.

    ``generate.model.generate`` runs on a worker thread that feeds a
    ``TextIteratorStreamer``; this function drains the streamer and joins the
    decoded chunks.

    Args:
        instruction: Raw user instruction; wrapped in the prompt template by
            ``generate.format_instruction``.
        temperature: Sampling temperature. Values below 0.1 turn sampling off
            (``do_sample=False``) and force the temperature to 0.0.
        top_p: Nucleus-sampling probability mass, passed through unchanged.
        top_k: Top-k cutoff; coerced to ``int`` because UI sliders may deliver
            floats and transformers rejects a non-integer ``top_k``.
        max_new_tokens: Maximum number of generated tokens; coerced to ``int``.

    Returns:
        The generated text, with the prompt skipped and special tokens removed.

    Raises:
        Any exception raised by ``generate.model.generate`` on the worker
        thread is re-raised here instead of surfacing as a streamer timeout.
    """
    device = generate.model.device
    encoded = generate.tokenizer(
        generate.format_instruction(instruction), return_tensors="pt"
    )
    input_ids = encoded["input_ids"].to(device)

    # Yields decoded text chunks as tokens are produced (prompt excluded).
    streamer = TextIteratorStreamer(
        generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop = StopOnTokens()

    # Near-zero temperatures are treated as a request for non-sampled decoding.
    if temperature < 0.1:
        temperature = 0.0
        do_sample = False
    else:
        do_sample = True

    gkw = {
        **generate.generate_kwargs,
        "input_ids": input_ids,
        "max_new_tokens": int(max_new_tokens),
        "temperature": temperature,
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": int(top_k),
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList([stop]),
    }
    # Forward the tokenizer's mask explicitly so generate() does not have to
    # infer it from pad/eos ids.
    if "attention_mask" in encoded:
        gkw["attention_mask"] = encoded["attention_mask"].to(device)

    errors = []

    def generate_and_signal_complete():
        # Runs on the worker thread. On failure, remember the exception and
        # close the stream so the consumer loop below ends immediately
        # instead of blocking until the streamer timeout fires.
        try:
            generate.model.generate(**gkw)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=generate_and_signal_complete)
    worker.start()
    chunks = list(streamer)
    worker.join()
    if errors:
        raise errors[0]
    return "".join(chunks)
# Close any Gradio apps already launched in this process (relevant when the
# script is re-run in the same interpreter, e.g. a notebook) before creating
# the interface below.
gr.close_all()
def tester(uPrompt, max_new_tokens, temperature, top_k, top_p):
    """Gradio callback: run the model on *uPrompt* with the UI's settings.

    The parameter order follows the Interface's ``inputs`` list (prompt,
    tokens, temperature, top_k, top_p); the values are re-ordered here to
    match ``process_stream(instruction, temperature, top_p, top_k,
    max_new_tokens)``. Previously unused locals (an echo of the prompt and a
    settings summary string) have been removed.

    Returns:
        The generated text.
    """
    return process_stream(uPrompt, temperature, top_p, top_k, max_new_tokens)
# NOTE(review): this sets an attribute on the project-local ``config`` module
# *after* the model was already loaded above, so it cannot affect that load;
# unless something else reads ``config.init_device`` later it is a no-op.
# (MPT's ``init_device`` is normally set on the model config before
# ``from_pretrained``) — confirm intent.
config.init_device="meta"
# Form-style UI. Input component values are passed positionally to
# ``tester`` in this order: prompt, max_new_tokens, temperature, top_k, top_p.
demo = gr.Interface(
fn=tester,
inputs=[gr.Textbox(label="Prompt",info="Prompt",lines=3,value="Provide Prompt"),
gr.Slider(256, 3072,value=1024, step=256, label="Tokens" ),
gr.Slider(0.0, 1.0, value=0.1, step=0.1, label='temperature:'),
# NOTE(review): top_k can only be 0 or 1 here; in transformers 0 disables
# top-k filtering and 1 is effectively greedy — confirm the intended range.
gr.Slider(0, 1, value=0, step=1, label='top_k:'),
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label='top_p:')
],
outputs=["text"],
)
# NOTE(review): share=True requests a public Gradio share link and
# server_name="0.0.0.0" listens on all network interfaces, so the model is
# reachable by anyone who has the link or can reach this host on port 8081.
demo.launch(share=True,
server_name="0.0.0.0",
server_port=8081
)