import os
from threading import Thread, Event
from typing import Iterator
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
DESCRIPTION = """\
# Monlam LLM v2.0.1
"""
path="TenzinGayche/tpo_v1.0.0_202_ft"
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
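# MAX_INPUT_TOKEN_LENGTH can be overridden via the environment; longer prompts are
# trimmed down to the most recent tokens before generation (see generate() below).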
# Load the model and tokenizer
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
model.config.sliding_window = 4096
model.eval()
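# Note: eval() disables dropout for inference; sliding_window=4096 presumably matches
# the attention window of the underlying Gemma-family checkpoint (an assumption, not
# stated in the original file).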
# Create a shared stop event
stop_event = Event()
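# The Event is shared between the request thread that streams tokens and the stop
# callback below: setting it makes the loop in generate() stop reading from the
# streamer. The background model.generate call itself still runs to completion;
# only its remaining output is discarded.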
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 2048,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = True,
) -> Iterator[str]:
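    """Stream the assistant reply for `message` given `chat_history` (messages-format dicts).

    Yields the accumulated response text after each streamed chunk so the UI can
    render partial output while `model.generate` runs on a background thread.
    """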
    # Clear the stop event before starting a new generation
    stop_event.clear()
    # Work on a copy of the history; seed it with the default Tibetan system exchange
    # if the conversation is empty, then append the user's new message.
    conversation = chat_history.copy()
    if not conversation:
        conversation.extend([
            {
                "role": "user",
                # "You are Monlam's artificial intelligence, and you are an AI."
                "content": "ཁྱེད་རང་སྨོན་ལམ་མི་བཟོས་རིག་ནུས་ཤིག་ཡིན་པ་དང་ཁྱེད་རང་མི་བཟོས་རིག་ནུས་(AI)ཤིག་ཡིན།"
            },
            {
                "role": "assistant",
                # "Very well. I will act according to your instructions."
                "content": "ལགས་སོ། ང་ཡིས་ཁྱེད་ཀྱི་བཀའ་བཞིན་སྒྲུབ་ཆོག"
            }
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Create a streamer to receive the generated response incrementally
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
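    # Iterating the streamer blocks until the next decoded chunk is available; if no
    # text arrives within the 20 s timeout set above, iteration raises instead of
    # hanging the request indefinitely.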
    outputs = []
    for text in streamer:
        if stop_event.is_set():
            break  # Stop if the stop button is pressed
        outputs.append(text)
        yield "".join(outputs)
    # After generation, append the assistant's response to the chat history
    assistant_response = "".join(outputs)
    chat_history.append({"role": "assistant", "content": assistant_response})
# Define a function to stop the generation
def stop_generation():
    stop_event.set()
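# Example usage outside the UI (hypothetical, for local experimentation only):
#   for partial in generate("Hello", []):
#       print(partial)
# Each yielded string is the full response so far; the final one is the complete reply.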
# Create the chat interface with additional inputs and the stop button
with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    # Create the chat interface; the extra controls map onto generate()'s keyword
    # arguments in order (max_new_tokens, temperature, top_p, top_k,
    # repetition_penalty, do_sample). Slider ranges are reasonable defaults, not
    # taken from the original file.
    chat_interface = gr.ChatInterface(
        fn=generate,
        additional_inputs=[
            gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, step=0.05, value=0.6, label="Temperature"),
            gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p (nucleus sampling)"),
            gr.Slider(minimum=1, maximum=200, step=1, value=50, label="Top-k"),
            gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty"),
            gr.Checkbox(value=True, label="Sample (uncheck for greedy decoding)"),
        ],
        examples=[
            ["Hello there! How are you doing?"],
            ["Can you explain briefly to me what is the Python programming language?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["How many hours does it take a man to eat a Helicopter?"],
            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
        ],
        cache_examples=False,
        type="messages",
    )
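    # Stop control: a minimal sketch (the button name and wiring are not from the
    # original file). Clicking sets the shared Event so the streaming loop in
    # generate() exits early; queue=False lets the click bypass the request queue.
    stop_button = gr.Button("Stop generation")
    stop_button.click(fn=stop_generation, queue=False)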
if __name__ == "__main__":
    # queue() enables request queuing (at most 20 waiting requests);
    # share=True additionally exposes a temporary public link.
    demo.queue(max_size=20).launch(share=True)