Spaces: Running on Zero
File size: 2,686 Bytes
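# Gradio chat demo for Navid-AI/Mulhem-1-Mini on a ZeroGPU Space.
# flash-attn is installed at startup with its CUDA build step skipped
# (FLASH_ATTENTION_SKIP_CUDA_BUILD), a common workaround when no GPU is
# available to compile the kernels at build time.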
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import spaces
# Pick a device and load the tokenizer, the model, and a streamer for incremental decoding;
# HF_TOKEN authenticates access to the model repo.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("Navid-AI/Mulhem-1-Mini", token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained("Navid-AI/Mulhem-1-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", token=os.getenv("HF_TOKEN")).to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
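# ZeroGPU: a GPU is attached only for the duration of each call to respond().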
@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    enable_reasoning,
    system_message,
    max_tokens,
    temperature,
    repetition_penalty,
    top_p,
):
    # Rebuild the conversation: system prompt first, then the past user/assistant turns.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg.strip()})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg.strip()})
    messages.append({"role": "user", "content": message})
    print(messages)  # log the full prompt for debugging

    # enable_reasoning is forwarded to the model's chat template.
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, enable_reasoning=enable_reasoning, return_dict=True).to(device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)

    # Generate in a background thread and stream partial text back to the UI.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
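# Build the chat UI; the checkbox, system prompt, and sampling controls are exposed
# as additional inputs beneath the chat box.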
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Checkbox(label="Enable reasoning", value=False),
        # Default system prompt (Arabic): "You are Mulhem. An artificial intelligence created
        # by Navid to inspire and motivate users to learn, grow, and achieve their goals."
        gr.Textbox(value="أنت مُلهم. ذكاء اصطناعي تم إنشاؤه من شركة نفيد لإلهام وتحفيز المستخدمين على التعلم، النمو، وتحقيق أهدافهم.", label="System message"),
        gr.Slider(minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.25, step=0.05, label="Repetition penalty"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
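# Launch the Gradio app when this file is executed directly (the usual entrypoint for a Space).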
if __name__ == "__main__":
    demo.launch()