import os
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Constants for model behavior
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
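# Example (hypothetical invocation, not part of this file): the input budget
# can be overridden per deployment through the environment, e.g.
#   MAX_INPUT_TOKEN_LENGTH=2048 python app.py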

# Model selection
MODELS = {
    "Fast-Model": "Artples/L-MChat-Small",
    "Quality-Model": "Artples/L-MChat-7b"
}

# Description for the application
DESCRIPTION = """\
# L-MChat
This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-mchat-663265a8351231c428318a8f) by L-AI.
"""

# Check for GPU availability
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU! This demo requires a GPU and will not work on CPU.</p>"

# Load models and tokenizers
models = {name: AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") for name, model_id in MODELS.items()}
tokenizers = {name: AutoTokenizer.from_pretrained(model_id) for name, model_id in MODELS.items()}
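# Note: both checkpoints are loaded eagerly at startup in full precision. A
# possible memory-saving variant (an assumption, not what this file does) is
# half-precision loading:
#   AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")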

@spaces.GPU(duration=90)
def generate(
    model_choice: str,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
    model = models[model_choice]
    tokenizer = tokenizers[model_choice]

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
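    # `conversation` is now a list of chat-template messages, e.g.:
    #   [{"role": "system", ...}, {"role": "user", ...},
    #    {"role": "assistant", ...}, ..., {"role": "user", "content": message}]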

    # Apply the model's chat template, keeping only the most recent tokens
    # if the prompt exceeds the input budget.
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:].to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # required for temperature/top-p/top-k to take effect
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_return_sequences=1,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
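# Quick smoke test (hypothetical usage, not part of the Space): exercises the
# full generation pipeline once without the UI.
#   print(generate("Fast-Model", "Hello there!", [], "You are a helpful assistant."))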

# Gradio Interface
chat_interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(label="Choose Model", choices=list(MODELS.keys()), default="Quality-Model"),
        gr.ChatBox(),
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    outputs="text",
    theme='ehristoforu/RE_Theme',
    examples=[
        ["Quality-Model", "Hello there! How are you doing?", [], "Let's start the conversation.", 1024, 0.6, 0.9, 50, 1.2]
    ]
)

# Page layout: description banner followed by the chat interface
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()