# Hugging Face Space status: Running on Zero (ZeroGPU).
import subprocess
import sys

# NOTE(review): `spaces` (Hugging Face ZeroGPU helper) is kept ahead of the
# torch import further down; ZeroGPU is believed to need that order — confirm.
import spaces


def install():
    """Install the Mamba dependency packages from pinned GitHub source tarballs.

    Installs causal-conv1d v1.4.0 and mamba v2.2.2, one ``pip install`` per
    package. Each command runs as ``<current interpreter> -m pip`` so the
    packages land in the same environment this app is running in.

    A failed install is reported on stdout but does not abort start-up; the
    step stays best-effort.
    """
    print("Install personal packages", flush=True)
    packages = (
        "https://github.com/Dao-AILab/causal-conv1d/archive/refs/tags/v1.4.0.tar.gz",
        "https://github.com/state-spaces/mamba/archive/refs/tags/v2.2.2.tar.gz",
    )
    for url in packages:
        # The command is an argument list (no shell), so no string splitting
        # or quoting is needed.
        result = subprocess.run([sys.executable, "-m", "pip", "install", url])
        if result.returncode != 0:
            print(f"pip install failed for {url} (exit code {result.returncode})", flush=True)


# Install packages for mamba before the model libraries below are imported.
install()

# Presumably imported only after install() so that the freshly installed
# packages are visible to transformers — keep this ordering.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
# Hugging Face Hub id of the chat model served by this app.
MODEL = "tiiuae/falcon-mamba-7b-instruct"

# ---- Static page text and styling ------------------------------------------

# HTML heading rendered at the top of the page.
TITLE = "<h1><center>FalconMamba-7b playground</center></h1>"

# HTML blurb rendered under the title. Adjacent literals are concatenated by
# the parser, so this is a single string.
SUB_TITLE = (
    "<center>FalconMamba is a new model released by Technology Innovation "
    "Institute (TII) in Abu Dhabi. The model is open source and available "
    "within the Hugging Face ecosystem for anyone to use it for their "
    "research or application purpose. Refer to "
    '<a href="https://hf.co/blog/falconmamba">the HF release blogpost</a> '
    'or <a href="https://www.tii.ae/news/uaes-technology-innovation-institute-revolutionizes-ai-language-models-new-architecture">'
    "the official announcement</a> for more details. This interface has "
    "been created for quick validation purposes, do not use it for "
    "production.</center>"
)

# Extra stylesheet for the page: a pill-shaped black duplicate button and
# centred <h3> headings.
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

# Markdown notice for a finished conversation.
# NOTE(review): not referenced anywhere else in the visible file.
END_MESSAGE = """
\n
**The conversation has reached to its end, please press "Clear" to restart a new conversation**
"""
# Where the model and its inputs live: "cuda" for GPU usage, "cpu" for CPU usage.
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Weights are loaded in bfloat16 and then moved onto `device`.
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)

# torch.compile is applied only on the GPU path; on CPU the eager model is kept.
model = torch.compile(model) if device == "cuda" else model
@spaces.GPU
def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Stream the model's reply to ``message``, given the earlier chat turns.

    Used as the ``fn`` of a Gradio ``ChatInterface``. The extra parameters are
    fed positionally from the "Parameters" sliders. ``@spaces.GPU`` requests a
    ZeroGPU device for the duration of the call.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_prompt, assistant_answer)`` pairs.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling mass; used only when sampling.
        top_k: Top-k cutoff; used only when sampling.
        penalty: Forwarded to ``generate`` as ``repetition_penalty``.
            Non-positive values, which the slider allows, are treated as
            1.0 (no penalty).

    Yields:
        The accumulated response text after each streamed chunk.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    # Rebuild the chat as role/content messages for the tokenizer's template.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
    # Open the assistant turn by hand so the model continues as the assistant.
    input_text += "<|im_start|>assistant\n"
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Hands decoded text pieces to this generator as generate() produces them.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature != 0
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        # transformers' repetition-penalty processor requires a strictly
        # positive *float*. Sliders may deliver whole numbers as int (e.g. 2),
        # and this slider reaches 0.0, so coerce, and fall back to 1.0
        # (= no penalty) for non-positive values.
        repetition_penalty=float(penalty) if penalty > 0 else 1.0,
        streamer=streamer,
        # NOTE(review): hard-coded pad token id; presumably matches the
        # tokenizer's pad/eos id — confirm against the tokenizer config.
        pad_token_id=10,
    )
    if do_sample:
        # Sampling knobs only apply when sampling. Under greedy decoding,
        # transformers ignores them and merely warns.
        generate_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    # Generation runs in a worker thread while this generator drains the
    # streamer. No torch.no_grad() here: grad mode is thread-local, so it
    # would not reach the worker. transformers' generate() disables
    # gradients itself.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted only at the end of generate(); wait for the
    # worker to finish before returning.
    thread.join()
    print(f'response: {buffer}')
# Chat transcript widget; it is rendered by the ChatInterface below.
chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS, theme="soft") as demo:
    # Page header: title, description and a button to clone the Space.
    gr.HTML(TITLE)
    gr.HTML(SUB_TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Collapsed panel holding the generation controls. It is created first
    # and left unrendered; ChatInterface places it.
    _parameters_panel = gr.Accordion(label="⚙️ Parameters", open=False, render=False)

    # Controls passed positionally to stream_chat after (message, history):
    # temperature, max_new_tokens, top_p, top_k, penalty.
    _generation_controls = [
        gr.Slider(minimum=0, maximum=1, step=0.1, value=0.3, label="Temperature", render=False),
        gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens", render=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
        gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k", render=False),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.2, label="Repetition penalty", render=False),
    ]

    _example_prompts = [
        ["Hello there, can you suggest few places to visit in UAE?"],
        ["What UAE is known for?"],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=_parameters_panel,
        additional_inputs=_generation_controls,
        examples=_example_prompts,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()