import json
import os
import random
from threading import Thread

import gradio as gr
import spaces
import torch
from langchain.schema import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, SecretStr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load the KTO-tuned Llama-13B model in 4-bit so it fits on a single GPU
# (4-bit loading requires bitsandbytes).
tokenizer = AutoTokenizer.from_pretrained("ContextualAI/archangel_sft-kto_llama13b")
model = AutoModelForCausalLM.from_pretrained(
    "ContextualAI/archangel_sft-kto_llama13b", device_map="auto", load_in_4bit=True
)
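# Note: newer transformers releases deprecate passing `load_in_4bit` directly;
# a sketch of the equivalent explicit config (assumes bitsandbytes is installed):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       "ContextualAI/archangel_sft-kto_llama13b",
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#   )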

class OAAPIKey(BaseModel):
    openai_api_key: SecretStr

def set_openai_api_key(api_key: SecretStr):
    os.environ["OPENAI_API_KEY"] = api_key.get_secret_value()
    llm = ChatOpenAI(temperature=1.0, model="gpt-3.5-turbo-0125")
    return llm
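# Note: writing the key to os.environ makes it visible to the whole process;
# if that is a concern, ChatOpenAI also accepts the key directly, e.g.
# ChatOpenAI(api_key=api_key, temperature=1.0, model="gpt-3.5-turbo-0125").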

class StopOnSequence(StoppingCriteria):
    """Stop generation once the most recent tokens match a given sequence."""

    def __init__(self, sequence, tokenizer):
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_len = len(self.sequence_ids)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        for i in range(input_ids.shape[0]):
            if input_ids[i, -self.sequence_len:].tolist() == self.sequence_ids:
                return True
        return False
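# Illustrative: StopOnSequence("<|user|>", tokenizer) halts generation as soon
# as the trailing tokens decode back to the "<|user|>" marker, which keeps the
# model from writing the next user turn itself.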

@spaces.GPU  # request a GPU slot on ZeroGPU Spaces (a no-op elsewhere)
def spaces_model_predict(message: str, history: list[tuple[str, str]]):
    """Stream a response from the local model, yielding the partial text."""
    history_transformer_format = history + [(message, "")]
    stop = StopOnSequence("<|user|>", tokenizer)

    # Flatten the chat history into the model's <|user|>/<|assistant|> format;
    # the final turn ends with an empty assistant slot for the model to fill.
    messages = "".join(
        f"<|user|>\n{item[0]}\n<|assistant|>\n{item[1]}"
        for item in history_transformer_format
    )

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    # Run generation on a background thread so we can drain the streamer here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
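# Usage sketch: each yield is the cumulative reply so far, so a non-streaming
# caller keeps only the last value, e.g.
#
#   final = ""
#   for partial in spaces_model_predict("How do I proof yeast?", []):
#       final = partial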

def predict(
    message: str,
    chat_history_openai: list[tuple[str, str]],
    chat_history_spaces: list[tuple[str, str]],
    openai_api_key: SecretStr,
):
    openai_key_model = OAAPIKey(openai_api_key=openai_api_key)
    openai_llm = set_openai_api_key(api_key=openai_key_model.openai_api_key)

    # OpenAI
    history_langchain_format_openai = []
    for human, ai in chat_history_openai:
        history_langchain_format_openai.append(HumanMessage(content=human))
        history_langchain_format_openai.append(AIMessage(content=ai))
    history_langchain_format_openai.append(HumanMessage(content=message))
    openai_response = openai_llm.invoke(input=history_langchain_format_openai)

    # Spaces model: spaces_model_predict is a generator yielding cumulative
    # partial replies, so drain it and keep the final value rather than
    # appending the raw generator object to the chat history.
    spaces_model_response = ""
    for partial in spaces_model_predict(message, chat_history_spaces):
        spaces_model_response = partial

    chat_history_openai.append((message, openai_response.content))
    chat_history_spaces.append((message, spaces_model_response))
    return "", chat_history_openai, chat_history_spaces

with open("askbakingtop.json", "r") as file:
    ask_baking_msgs = json.load(file)
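# askbakingtop.json is assumed to be a list of records each carrying a
# "history" string (used to seed the example dropdown below).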

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            openai_api_key = gr.Textbox(
                label="Please enter your OpenAI API key",
                type="password",
                elem_id="lets-chat-openai-api-key",
            )
    with gr.Row():
        options = [ask["history"] for ask in random.sample(ask_baking_msgs, k=3)]
        msg = gr.Dropdown(
            options,
            label="Please enter your message",
            interactive=True,
            multiselect=False,
            allow_custom_value=True,
        )
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_openai = gr.Chatbot(label="OpenAI Chatbot 🟢")
        with gr.Column(scale=1):
            chatbot_spaces = gr.Chatbot(
                label="Your own fine-tuned, preference-optimized Chatbot 💪"
            )
    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        clear = gr.ClearButton([msg])

    def respond(
        message: str,
        chat_history_openai: list[tuple[str, str]],
        chat_history_spaces: list[tuple[str, str]],
        openai_api_key: SecretStr,
    ):
        return predict(
            message=message,
            chat_history_openai=chat_history_openai,
            chat_history_spaces=chat_history_spaces,
            openai_api_key=openai_api_key,
        )

    submit_button.click(
        fn=respond,
        inputs=[
            msg,
            chatbot_openai,
            chatbot_spaces,
            openai_api_key,
        ],
        outputs=[msg, chatbot_openai, chatbot_spaces],
    )

demo.launch()