import os
import subprocess

# Install FlashAttention at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE tells
# flash-attn's setup to skip compiling the CUDA extension, the usual trick for
# installing it on HF Spaces where no CUDA toolkit is available at build time.
# os.environ is merged in so the subprocess keeps PATH and the rest of its env.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Load the model and tokenizer (HF_TOKEN grants access if the repo is gated).
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(
    "Navid-AI/Mulhem-1-Mini",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    token=os.getenv("HF_TOKEN"),
).to(device)
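
# respond() streams the reply: generation runs in a worker thread while the
# main thread drains a TextIteratorStreamer and yields growing partial strings
# for gr.ChatInterface to render incrementally.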
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    enable_reasoning,
    system_message,
    max_tokens,
    temperature,
    repetition_penalty,
    top_p,
):
    # Rebuild the full conversation: system prompt first, then the alternating
    # user/assistant turns from the history, then the new user message.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg.strip()})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg.strip()})
    messages.append({"role": "user", "content": message})
    print(messages)  # Debug: log the assembled prompt.

    # enable_reasoning is forwarded to the model's chat template, which uses it
    # to toggle the reasoning mode.
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        enable_reasoning=enable_reasoning,
        return_dict=True,
    ).to(device)

    # A fresh streamer per request keeps concurrent generations from
    # interleaving their tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,  # Without this, temperature/top_p would be ignored.
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the growing response so the UI updates token by token.
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()
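
# Build the chat UI. additional_inputs must stay in the same order as the extra
# parameters of respond() after (message, history).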
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Checkbox(label="Enable reasoning", value=False),
        # Default system prompt (Arabic): "You are Mulhem, an AI created by the
        # Navid company to inspire and motivate users to learn, grow, and
        # achieve their goals."
        gr.Textbox(
            value="ุฃู†ุช ู…ูู„ู‡ู…. ุฐูƒุงุก ุงุตุทู†ุงุนูŠ ุชู… ุฅู†ุดุงุคู‡ ู…ู† ุดุฑูƒุฉ ู†ููŠุฏ ู„ุฅู„ู‡ุงู… ูˆุชุญููŠุฒ ุงู„ู…ุณุชุฎุฏู…ูŠู† ุนู„ู‰ ุงู„ุชุนู„ู‘ู…ุŒ ุงู„ู†ู…ูˆุŒ ูˆุชุญู‚ูŠู‚ ุฃู‡ุฏุงูู‡ู….",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.25, step=0.05, label="Repetition penalty"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
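
# Hypothetical smoke test: stream one short reply without the Gradio UI. It is
# guarded by an env flag (MULHEM_SMOKE_TEST is my placeholder, not part of the
# Space) so it never runs in the deployed app; the values are illustrative.
if os.getenv("MULHEM_SMOKE_TEST"):
    last = ""
    for partial in respond(
        "ู…ุฑุญุจู‹ุง",  # "Hello"
        [],  # empty chat history
        False,  # enable_reasoning
        "ุฃู†ุช ู…ูู„ู‡ู….",  # system message: "You are Mulhem."
        64,  # max_tokens
        0.7,  # temperature
        1.25,  # repetition_penalty
        0.95,  # top_p
    ):
        last = partial
    print(last)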
if __name__ == "__main__":
demo.launch()