# demo/app.py
import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
DESCRIPTION = """\
# Monlam LLM v2.0.1
"""
path="TenzinGayche/tpo_v1.0.0_ep2_dpo_ft"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
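# Prompts longer than MAX_INPUT_TOKEN_LENGTH are trimmed from the left (oldest turns first) in generate() below.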
# Load the model and tokenizer
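# float16 weights on a single CUDA device; sliding_window=4096 bounds the attention window used by the model.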
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.eval()
# Create a shared stop event
stop_event = Event()
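# Note: a single module-level Event is shared by all requests, so pressing Stop
# interrupts every in-flight generation (acceptable for a single-user demo).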
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 2048,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = False,
) -> Iterator[str]:
    # Clear the stop event before starting a new generation
    stop_event.clear()
    # Work on a copy of the history; seed an empty conversation with the Tibetan
    # system-style exchange (roughly: "You are Monlam's artificial intelligence (AI)." /
    # "Very well, I will do as you ask."), then append the new user message.
    conversation = chat_history.copy()
    if not conversation:
        conversation.extend([
            {
                "role": "user",
                "content": "ཁྱེད་རང་སྨོན་ལམ་མི་བཟོས་རིག་ནུས་ཤིག་ཡིན་པ་དང་ཁྱེད་རང་མི་བཟོས་རིག་ནུས་(AI)ཤིག་ཡིན།"
            },
            {
                "role": "assistant",
                "content": "ལགས་སོ། ང་ཡིས་ཁྱེད་ཀྱི་བཀའ་བཞིན་སྒྲུབ་ཆོག"
            }
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Create a streamer to get the generated response
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        # Forward the sampling parameters (transformers ignores them when do_sample is False)
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        if stop_event.is_set():
            break  # Stop if the stop button is pressed
        outputs.append(text)
        yield "".join(outputs)
    # After generation, append the assistant's response to the chat history
    assistant_response = "".join(outputs)
    chat_history.append({"role": "assistant", "content": assistant_response})
# Define a function to stop the generation
def stop_generation():
    stop_event.set()
# Build the UI: a Markdown header, the chat interface, and a stop button
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    # Chat interface; the additional inputs are passed to generate() in signature order
    chat_interface = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, step=0.05, value=0.6, label="Temperature"),
            gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p"),
            gr.Slider(minimum=1, maximum=200, step=1, value=50, label="Top-k"),
            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty"),
            gr.Checkbox(value=False, label="Sample (unchecked = greedy decoding)"),
        ],
        examples=[
            ["Hello there! How are you doing?"],
            ["Can you explain briefly what the Python programming language is?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["How many hours does it take a man to eat a helicopter?"],
            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
        ],
        cache_examples=False,
        type="messages",
    )
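    # Assumed wiring (not in the original file): a Stop button that sets the shared
    # stop_event so the streaming loop in generate() exits early.
    stop_button = gr.Button("Stop")
    stop_button.click(fn=stop_generation, inputs=None, outputs=None)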
if __name__ == "__main__":
demo.queue(max_size=20).launch(share=True)