Spaces:
Runtime error
Runtime error
File size: 2,517 Bytes
c152a6e 92360e8 c152a6e 92360e8 c152a6e 92360e8 805fbb1 92360e8 c152a6e 92360e8 c152a6e 92360e8 c152a6e 92360e8 c152a6e 92360e8 c152a6e 92360e8 3bac1fd 92360e8 c152a6e 92360e8 c152a6e 26c120c 965be30 e414796 965be30 c152a6e 92360e8 c152a6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from transformers import AutoTokenizer
import time
import gradio as gr
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from optimum.utils import NormalizedConfigManager
@classmethod
def _new_get_normalized_config_class(cls, model_type):
    """Return the T5 normalized-config class, whatever ``model_type`` is.

    Installed below as ``NormalizedConfigManager.get_normalized_config_class``.
    ``model_type`` is accepted only to keep the replaced method's signature;
    it is ignored.

    NOTE(review): presumably this exists because the loaded (LongT5) model
    type is not registered in optimum's normalized-config table — confirm.
    """
    t5_config_class = cls._conf["t5"]
    return t5_config_class


# Monkey-patch optimum so every model type resolves to the T5 normalized config.
NormalizedConfigManager.get_normalized_config_class = _new_get_normalized_config_class
# Number of previous QA pairs to use for context.
N = 2
# Maximum number of tokens to generate for each answer.
MAX_NEW_TOKENS = 128
# Checkpoint identifier used for both the tokenizer and the ONNX seq2seq model.
MODEL_ID = "tryolabs/long-t5-tglobal-base-blogpost-cqa-onnx"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = ORTModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Document text the model answers questions about; it is placed at the start
# of every prompt. Read as UTF-8 explicitly so decoding does not depend on the
# host's locale.
with open("updated_context.txt", "r", encoding="utf-8") as f:
    context = f.read()
def build_input(question, state=None, *, max_pairs=None, context_text=None):
    """Build the prompt string fed to the conversational QA model.

    The prompt is laid out as::

        "<context> || <Qk> q <Ak> a ... <Q1> q <A1> a <Q> question <A> "

    where the history pairs are the last ``max_pairs`` completed turns,
    oldest first, numbered so that ``<Q1>/<A1>`` is the most recent turn.

    Args:
        question: The new user question.
        state: ``[questions, answers]`` history as two parallel lists, where
            ``questions[k]`` was answered by ``answers[k]``. ``None`` means
            no history.
        max_pairs: Maximum number of previous turns to include. Defaults to
            the module-level ``N``.
        context_text: Document text placed at the start of the prompt.
            Defaults to the module-level ``context``.

    Returns:
        The prompt string.
    """
    if state is None:
        state = [[], []]
    if max_pairs is None:
        max_pairs = N
    if context_text is None:
        context_text = context

    # Pair each question with its own answer (index k with index k); zip also
    # drops a trailing unanswered question if the lists differ in length.
    pairs = list(zip(state[0], state[1]))
    # Guard max_pairs <= 0 explicitly: pairs[-0:] would return every pair.
    recent = pairs[-max_pairs:] if max_pairs > 0 else []

    parts = [f"{context_text} || "]
    for offset, (prev_question, prev_answer) in enumerate(recent):
        i = len(recent) - offset  # oldest gets the highest tag number
        parts.append(f"<Q{i}> {prev_question} <A{i}> {prev_answer} ")
    parts.append(f"<Q> {question} <A> ")
    return "".join(parts)
def get_model_answer(question, state=None):
    """Answer ``question`` with the model and record the turn in ``state``.

    Builds the prompt with ``build_input``, tokenizes it, generates up to
    ``MAX_NEW_TOKENS`` tokens and decodes the answer. The duration of each of
    the three stages is printed.

    Args:
        question: The new user question.
        state: ``[questions, answers]`` history as two parallel lists. It is
            mutated in place (the new question and answer are appended).
            ``None`` starts a fresh history; a new list is created per call so
            histories are never shared between calls.

    Returns:
        ``(responses, state)``: ``responses`` is a list of
        ``(question, answer)`` tuples covering the whole history, and
        ``state`` is the updated history.
    """
    if state is None:
        state = [[], []]

    start = time.perf_counter()
    model_input = build_input(question, state)
    print(f"Build input: {time.perf_counter() - start}")

    start = time.perf_counter()
    # NOTE(review): prompts longer than 7000 tokens are truncated. With the
    # tokenizer's (presumably default) right-side truncation this would cut
    # off the question at the end of the prompt — confirm truncation_side.
    encoded_inputs = tokenizer(model_input, max_length=7000, truncation=True, return_tensors="pt")
    input_ids = encoded_inputs.input_ids
    attention_mask = encoded_inputs.attention_mask
    print(f"Tokenize: {time.perf_counter() - start}")

    start = time.perf_counter()
    encoded_output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    answer = tokenizer.decode(encoded_output[0], skip_special_tokens=True)
    print(f"Generate: {time.perf_counter() - start}")

    state[0].append(question)
    state[1].append(answer)
    responses = list(zip(state[0], state[1]))
    return responses, state
with gr.Blocks() as demo:
    # Conversation history as [questions, answers]; passed to and returned
    # from get_model_answer on every submit.
    state = gr.State([[], []])
    chatbot = gr.Chatbot()
    text = gr.Textbox(label="Ask a question (press enter to submit)", value="How are you?")
    gr.Examples(
        ["What's the name of the dataset that was built?", "what task does it focus on?", "what is that task about?"],
        text,
    )
    # On Enter: answer the question and refresh the chat display and history.
    text.submit(get_model_answer, [text, state], [chatbot, state])
    # On Enter: also clear the textbox.
    text.submit(lambda x: "", text, text)

if __name__ == "__main__":
    demo.launch()