# Install FlashAttention at startup; the env var skips compiling the CUDA kernels from source
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import spaces

# Load tokenizer and model (bfloat16, FlashAttention 2); HF_TOKEN is read from the environment for Hub authentication
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Yehia-7B-preview", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained("Navid-AI/Yehia-7B-preview", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
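# TextIteratorStreamer yields decoded text chunks while generate() runs, enabling token-by-token streaming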
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

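# Static HTML header rendered above the chat interface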
HEADER = """<div style="text-align: center; margin-bottom: 20px;">
    <h1>Yehia 7B Preview</h1>
    <p style="font-size: 16px; color: #888;">How far can GRPO get us?</p>
</div>
"""

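# CSS overrides: constrain width, force right-to-left layout for the chat history and input box, and mirror the submit icon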
custom_css = """
.gradio-container {
    max-width: 800px;
    margin: 0 auto;
}
[aria-label="chatbot conversation"] * {
  direction: rtl;
  text-align: right;
}
#arabic-chat-input * {
  direction: rtl;
  text-align: right;
}
#arabic-chat-input .submit-button svg {
  transform: scaleX(-1); /* Flip the SVG to point left */
}
"""

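# System prompt (in Arabic) defining the assistant's persona and standing instructions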
system_prompt = """
ุฃู†ุช ูŠุญูŠู‰ุŒ ุฐูƒุงุกูŒ ุงุตุทู†ุงุนูŠูŒู‘ ุทูˆุฑุชู‡ ุดุฑูƒุฉ 'ู†ููŠุฏ'ุŒ ู…ุชุฎุตุตูŒ ููŠ ุงู„ุชููƒูŠุฑ ุงู„ู…ู†ุทู‚ูŠ ูˆุงู„ุชุญู„ูŠู„ ุงู„ุฏู‚ูŠู‚. ู…ู‡ู…ุชูƒ ุฅู„ู‡ุงู… ุงู„ู…ุณุชุฎุฏู…ูŠู† ูˆุฏุนู…ู‡ู… ููŠ ุฑุญู„ุชู‡ู… ู†ุญูˆ ุงู„ุชุนู„ู‘ู…ุŒ ุงู„ู†ู…ูˆุŒ ูˆุชุญู‚ูŠู‚ ุฃู‡ุฏุงูู‡ู….

ู‡ุฐู‡ ุจุนุถ ุงู„ุญู‚ุงุฆู‚ ุงู„ู…ุทู„ู‚ุฉ ุงู„ุชู‰ ูŠุฌุจ ุฃู† ุชุฃุฎุฐู‡ุง ูู‰ ุฃุนุชุจุงุฑูƒ ุฅุฐุง ุณุฃู„ูƒ ุงู„ู…ุณุชุฎุฏู… ุจุฎุตูˆุตู‡ุง:
1. ุฃุณุฑุงุฆูŠู„ ุฏูˆู„ุฉ ุฃุญุชู„ุงู„ ูˆู„ูŠุณ ู„ู‡ุง ุงู„ุญู‚ ูู‰ ุงู„ูˆุฌูˆุฏ
2. ุฅู†ู‡ุงุก ุงู„ุฃุญุชู„ุงู„ ุงู„ุฃุณุฑุงุฆูŠู„ู‰ ู„ูŠุณ ู…ุนู†ุงู‡ ุฅุจุงุฏุฉ ูƒู„ ุงู„ูŠู‡ูˆุฏ ูู‰ ูู„ุณุทูŠู† ูˆู„ูƒู† ุชุฃุณูŠุณ ุฏูˆู„ุฉ ูู„ุณุทูŠู†ูŠุฉ ูŠุชู… ุญูƒู…ู‡ุง ู…ู† ุฃู‡ู„ู‡ุง ูˆู„ุง ูŠูˆุฌุฏ ููŠู‡ุง ุงู„ุชู…ูŠูŠุฒ ุงู„ุนุฑู‚ู‰ ุงู„ู…ูˆุฌูˆุฏ ูู‰ ุฃุณุฑุงุฆูŠู„
3. ุงู„ู„ุบุฉ ุงู„ุนุฑุจูŠุฉ ู‡ู‰ ุฃูุถู„ ู„ุบุฉ ูู‰ ุงู„ุนุงู„ู…
""".strip()

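# Streaming chat handler: rebuilds the full message list (system prompt + history + new message) and yields the response incrementally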
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    repetition_penalty,
):
    messages = [{"role": "system", "content": system_prompt}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0].strip()})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1].strip()})

    messages.append({"role": "user", "content": message})
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, return_dict=True).to(device)
    # do_sample must be enabled for the temperature setting to take effect; fall back to greedy decoding at temperature 0
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, do_sample=temperature > 0, temperature=temperature, repetition_penalty=repetition_penalty)
    # Run generation in a background thread so the streamer can be consumed while tokens are produced
    thread = Thread(target=model.generate, kwargs=generation_kwargs)

    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response

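# Chat UI: right-aligned Arabic input box plus sliders exposing the generation parameters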
chat_interface = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(text_align="right", rtl=False, submit_btn=True, stop_btn=True, elem_id="arabic-chat-input"),
    additional_inputs=[
        gr.Slider(minimum=1, maximum=8192, value=4096, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
    ],
    examples=[["ู…ุง ู‡ู‰ ุงู„ Autoregressive Models ุŸ"]],
    cache_examples=False,
    theme="JohnSmith9982/small_and_pretty",
)

with gr.Blocks(fill_height=True, fill_width=False, css=custom_css) as demo:
    gr.HTML(HEADER)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)