import json
import os
import random
from threading import Thread
import gradio as gr
import spaces
import torch
from langchain.schema import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, SecretStr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
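
# Load the tokenizer and the SFT+KTO-aligned LLaMA-30B model, then move it to the GPU.
# (A 30B checkpoint in the default fp32 precision is very memory-hungry; passing
# torch_dtype=torch.float16 to from_pretrained would roughly halve the footprint.)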
tokenizer = AutoTokenizer.from_pretrained("ContextualAI/archangel_sft-kto_llama30b")
model = AutoModelForCausalLM.from_pretrained("ContextualAI/archangel_sft-kto_llama30b")
model = model.to("cuda:0")


class OAAPIKey(BaseModel):
    openai_api_key: SecretStr


def set_openai_api_key(api_key: SecretStr) -> ChatOpenAI:
    # Expose the key to langchain_openai via the environment, then build the client.
    os.environ["OPENAI_API_KEY"] = api_key.get_secret_value()
    llm = ChatOpenAI(temperature=1.0, model="gpt-3.5-turbo-0125")
    return llm


class StopOnSequence(StoppingCriteria):
    """Stop generation once the tail of the output matches a given string's tokens."""

    def __init__(self, sequence: str, tokenizer):
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_len = len(self.sequence_ids)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Not enough tokens generated yet to contain the stop sequence.
        if input_ids.shape[1] < self.sequence_len:
            return False
        # Compare the last `sequence_len` tokens against the stop sequence.
        return (
            (
                input_ids[0, -self.sequence_len:]
                == torch.tensor(self.sequence_ids, device=input_ids.device)
            )
            .all()
            .item()
        )
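

# spaces.GPU allocates a ZeroGPU device for the duration of each call; the generous
# 512 s budget leaves headroom for the shared model load on first use.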
@spaces.GPU(duration=512)
def spaces_model_predict(message: str, history: list[tuple[str, str]]):
    history_transformer_format = history + [[message, ""]]
    # Stop as soon as the model begins a new human turn; the stop string must match
    # the "<human>:" marker used in the prompt below, or the criterion never fires.
    stop = StopOnSequence("<human>:", tokenizer)

    # Flatten the chat history into the model's plain-text turn format.
    messages = "".join(
        [
            "".join(["\n<human>:" + item[0], "\n<ai>:" + item[1]])
            for item in history_transformer_format
        ]
    )
    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,  # streaming requires greedy/sampled decoding, not beam search
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Run generation in a background thread so tokens can be consumed from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        # Skip the "<" that opens the stop marker so it does not leak into the reply.
        if new_token != "<":
            partial_message += new_token
    return partial_message
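

# Fan the same user message out to both backends (OpenAI via LangChain and the
# local fine-tuned model) so their responses can be compared side by side.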
def predict(
    message: str,
    chat_history_openai: list[tuple[str, str]],
    chat_history_spaces: list[tuple[str, str]],
    openai_api_key: str,
):
    # Gradio passes the textbox value as a plain str; pydantic coerces it to SecretStr.
    openai_key_model = OAAPIKey(openai_api_key=openai_api_key)
    openai_llm = set_openai_api_key(api_key=openai_key_model.openai_api_key)

    # OpenAI: replay the history in LangChain message format, then request a reply.
    history_langchain_format_openai = []
    for human, ai in chat_history_openai:
        history_langchain_format_openai.append(HumanMessage(content=human))
        history_langchain_format_openai.append(AIMessage(content=ai))
    history_langchain_format_openai.append(HumanMessage(content=message))
    openai_response = openai_llm.invoke(input=history_langchain_format_openai)

    # Spaces model
    spaces_model_response = spaces_model_predict(message, chat_history_spaces)

    chat_history_openai.append((message, openai_response.content))
    chat_history_spaces.append((message, spaces_model_response))
    # Clear the input box and return both updated histories.
    return "", chat_history_openai, chat_history_spaces
with open("askbakingtop.json", "r") as file:
ask_baking_msgs = json.load(file)
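
# Each entry's "history" field holds a sample question; three are drawn at random
# to seed the dropdown, and allow_custom_value lets users type their own message.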

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            openai_api_key = gr.Textbox(
                label="Please enter your OpenAI API key",
                type="password",
                elem_id="lets-chat-openai-api-key",
            )
    with gr.Row():
        options = [ask["history"] for ask in random.sample(ask_baking_msgs, k=3)]
        msg = gr.Dropdown(
            options,
            label="Please enter your message",
            interactive=True,
            multiselect=False,
            allow_custom_value=True,
        )
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_openai = gr.Chatbot(label="OpenAI Chatbot 🏒")
        with gr.Column(scale=1):
            chatbot_spaces = gr.Chatbot(
                label="Your own fine-tuned, preference-optimized Chatbot πŸ’ͺ"
            )
    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        clear = gr.ClearButton([msg])

    def respond(
        message: str,
        chat_history_openai: list[tuple[str, str]],
        chat_history_spaces: list[tuple[str, str]],
        openai_api_key: str,
    ):
        return predict(
            message=message,
            chat_history_openai=chat_history_openai,
            chat_history_spaces=chat_history_spaces,
            openai_api_key=openai_api_key,
        )
    submit_button.click(
        fn=respond,
        inputs=[
            msg,
            chatbot_openai,
            chatbot_spaces,
            openai_api_key,
        ],
        outputs=[msg, chatbot_openai, chatbot_spaces],
    )

demo.launch()
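
# On Hugging Face Spaces, launch() serves the app directly; run locally, it starts
# a server on http://127.0.0.1:7860 by default.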