File size: 3,924 Bytes
717452a
 
 
 
e9bec21
717452a
e9bec21
717452a
 
e7e3941
717452a
d793c60
717452a
 
 
 
 
d117130
 
717452a
 
 
 
 
 
 
 
 
 
e573bba
717452a
 
 
 
f3bf7cd
717452a
 
 
 
e7e3941
a079f79
717452a
e7e3941
 
 
 
 
 
 
 
 
 
 
717452a
 
 
 
 
 
 
 
a079f79
717452a
e9bec21
717452a
e9bec21
717452a
ce6fe13
 
 
e9bec21
 
a079f79
e9bec21
 
 
717452a
 
 
 
 
 
 
a079f79
 
 
 
 
717452a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9bec21
717452a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from threading import Thread, Event
from typing import Iterator

import gradio as gr

import torch
from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
# Markdown shown at the top of the demo page.
DESCRIPTION = "# Monlam LLM v2.0.1\n"

# Identifier handed to ``from_pretrained`` for both the tokenizer and the model.
path = "TenzinGayche/tpo_v1.0.0_202_ft"

# Token budgets. The prompt limit can be overridden through the
# MAX_INPUT_TOKEN_LENGTH environment variable (default 4096).
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Load the tokenizer, then the float16 model, and move the model to the GPU.
tokenizer = GemmaTokenizerFast.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
model = model.to("cuda")

# Override the sliding-window size in the loaded config and switch to eval mode.
model.config.sliding_window = 4096
model.eval()

# Module-wide flag polled by generate() and set by stop_generation().
stop_event = Event()

def _event_stopping_criteria(*events: Event):
    """Return a ``StoppingCriteriaList`` that ends generation once any event is set."""
    # Local import: keeps this block self-contained while relying only on the
    # package the file already depends on.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            should_stop = any(event.is_set() for event in events)
            # Return one flag per sequence in the batch.
            return torch.full(
                (input_ids.shape[0],),
                should_stop,
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 2048,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    do_sample: bool = True
) -> Iterator[str]:
    """Stream the model's reply to ``message`` as a growing string.

    The prompt is built from ``chat_history`` plus the new user message; when
    the history is empty, a fixed priming user/assistant exchange is inserted
    first. Generation runs in a background thread and the text produced so far
    is yielded after every chunk received from the streamer.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
            The list is read but never modified.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Accepted for interface compatibility; currently NOT forwarded
            to ``model.generate``.
        repetition_penalty: Accepted for interface compatibility; currently
            NOT forwarded to ``model.generate``.
        do_sample: Forwarded to ``model.generate``.

    Yields:
        The accumulated response text after each streamed chunk.

    Note:
        ``stop_event`` is a single module-level flag, so clearing it here also
        discards a stop request made on behalf of any other in-flight call.
    """
    # Start from a clean slate: forget any stop request from an earlier run.
    stop_event.clear()

    # Work on a copy so the caller's history list is left untouched.
    conversation = list(chat_history)
    if not conversation:
        # First turn: prepend the fixed priming exchange.
        conversation.extend([
            {
                "role": "user",
                "content": "ཁྱེད་རང་སྨོན་ལམ་མི་བཟོས་རིག་ནུས་ཤིག་ཡིན་པ་དང་ཁྱེད་རང་མི་བཟོས་རིག་ནུས་(AI)ཤིག་ཡིན།"
            },
            {
                "role": "assistant",
                "content": "ལགས་སོ། ང་ཡིས་ཁྱེད་ཀྱི་བཀའ་བཞིན་སྒྲུབ་ཆོག"
            }
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens when the prompt exceeds the budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    stream_timeout = 20.0
    streamer = TextIteratorStreamer(tokenizer, timeout=stream_timeout, skip_prompt=True, skip_special_tokens=True)

    # Per-call flag: set when this generator finishes or is closed early, so
    # the background generation stops instead of running to max_new_tokens.
    cancelled = Event()
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        do_sample=do_sample,
        temperature=temperature,
        stopping_criteria=_event_stopping_criteria(stop_event, cancelled),
    )

    # Run generation in a background thread; this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    response = ""
    try:
        for text in streamer:
            if stop_event.is_set():
                break  # A stop was requested; drop the pending chunk.
            response += text
            yield response
    finally:
        # Runs on normal completion, on a stop request, on a streamer timeout,
        # and when the consumer closes the generator. Signal the worker, then
        # give it a bounded amount of time to wind down.
        cancelled.set()
        worker.join(timeout=stream_timeout)


# Request that the generation currently streaming be stopped
def stop_generation() -> None:
    """Set the module-level ``stop_event`` flag that ``generate`` polls while streaming."""
    stop_event.set()

# Assemble the page: a Markdown header followed by the streaming chat widget.
# No additional inputs or dedicated stop control are declared here, so
# generate() runs with its default sampling arguments and stop_generation()
# is not attached to any component in this block.
# NOTE(review): newer Gradio releases treat `css` as a CSS code string and take
# file paths through `css_paths` -- confirm against the installed version.
with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Chat widget that uses role/content message dicts and streams from generate().
    chat_interface = gr.ChatInterface(
        fn=generate,
        type="messages",
        cache_examples=False,
        examples=[
            ["Hello there! How are you doing?"],
            ["Can you explain briefly to me what is the Python programming language?"],
            ["Explain the plot of Cinderella in a sentence."],
            ["How many hours does it take a man to eat a Helicopter?"],
            ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
        ],
    )


if __name__ == "__main__":
    # Enable the request queue (at most 20 waiting requests), then launch with
    # a public share link.
    demo.queue(max_size=20)
    demo.launch(share=True)