# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA kernels during install.
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import spaces

# Load the tokenizer and model directly; HF_TOKEN is read from the environment to access the model repo
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
# The streamer yields decoded text chunks as they are generated, skipping the prompt and special tokens
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

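# respond() is the Gradio chat callback; @spaces.GPU asks ZeroGPU to attach a GPU while it runs.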
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    enable_reasoning,
    system_message,
    max_tokens,
    temperature,
    repetition_penalty,
    top_p,
):
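    """Rebuild the conversation, apply the chat template, and stream the model's reply."""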
    messages = [{"role": "system", "content": system_message}]

    # Replay earlier (user, assistant) turns from the chat history
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0].strip()})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1].strip()})

    messages.append({"role": "user", "content": message})
    print(messages)  # debug: log the conversation sent to the model
    # enable_reasoning is forwarded to the model's chat template as an extra kwarg
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning, return_dict=True).to(device)
    # do_sample=True ensures the temperature / top-p / repetition-penalty settings actually affect decoding
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
    # Run generation in a background thread so tokens can be streamed while it is still producing output
    thread = Thread(target=model.generate, kwargs=generation_kwargs)

    thread.start()
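    # Accumulate the streamed chunks and yield the growing reply so the chat updates live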
    response = ""
    for new_text in streamer:
        response += new_text
        yield response


# Chat UI: each additional input below is passed to respond() in order, after (message, history)
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Checkbox(label="Enable reasoning", value=False),
        gr.Textbox(value="أنت مُلهم. ذكاء اصطناعي تم إنشاؤه من شركة نفيد لإلهام وتحفيز المستخدمين على التعلّم، النمو، وتحقيق أهدافهم.", label="System message"),  # Arabic for: "You are Mulhem. An AI created by Navid to inspire and motivate users to learn, grow, and achieve their goals."
        gr.Slider(minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.25, step=0.05, label="Repetition penalty"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()