import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
DESCRIPTION = """\
# Monlam LLM
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Load the model and tokenizer
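# Note: float16 weights assume a CUDA device is available; fall back to
# torch_dtype=torch.float32 and .to("cpu") if not.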
tokenizer = GemmaTokenizerFast.from_pretrained("TenzinGayche/example")
model = AutoModelForCausalLM.from_pretrained("TenzinGayche/example", torch_dtype=torch.float16).to("cuda")
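# Cap the attention sliding window (used by Gemma-2-style checkpoints) at 4096 tokens.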
model.config.sliding_window = 4096
model.eval()
# Shared stop event: set by the stop handler below, polled by the streaming
# loop in generate() so an in-progress response can be cut short.
stop_event = Event()
def generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
# Clear the stop event before starting a new generation
stop_event.clear()
conversation = chat_history.copy()
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
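    # Trim the prompt to the most recent MAX_INPUT_TOKEN_LENGTH tokens so it
    # fits the model's context budget.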
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
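    # The streamer yields decoded text incrementally as tokens are generated:
    # skip_prompt drops the echoed input, and timeout=20.0 raises in the
    # consuming loop if no new token arrives within 20 seconds.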
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so the streamer can be consumed
    # below while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
outputs = []
for text in streamer:
if stop_event.is_set():
            break  # Stop streaming; see the note after stop_generation() on halting generate() itself.
outputs.append(text)
yield "".join(outputs)
# Signal the streaming loop in generate() to stop emitting output.
def stop_generation():
    stop_event.set()
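# Note: setting the event only stops streaming to the UI; model.generate keeps
# running in its background thread until it finishes. A minimal sketch of
# halting generation itself with the transformers StoppingCriteria API (a
# hypothetical addition, not wired in here):
#
#     from transformers import StoppingCriteria, StoppingCriteriaList
#
#     class StopOnEvent(StoppingCriteria):
#         def __call__(self, input_ids, scores, **kwargs) -> bool:
#             return stop_event.is_set()
#
#     # ...then add stopping_criteria=StoppingCriteriaList([StopOnEvent()])
#     # to generate_kwargs above.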
# Create the chat interface with additional inputs and the stop button
with gr.Blocks(css="style.css", fill_height=True) as demo:
gr.Markdown(DESCRIPTION)
    # Create the chat interface; the sliders expose generate()'s sampling
    # parameters as additional inputs.
    chat_interface = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
            gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
            gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
            gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
            gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
        ],
        examples=[
            ["Hello there! How are you doing?"],
            ["Can you briefly explain what the Python programming language is?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["How many hours does it take a man to eat a helicopter?"],
            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
        ],
        cache_examples=False,
        type="messages",
    )
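    # Stop button: sets the shared event so the streaming loop in generate()
    # exits early (a minimal wiring; label and placement are arbitrary).
    stop_button = gr.Button("Stop generation")
    stop_button.click(fn=stop_generation, queue=False)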
if __name__ == "__main__":
demo.queue(max_size=20).launch(share=True)