"""Gradio playground for the FalconMamba-7B instruct model.

Installs the Mamba kernel packages, loads the model and tokenizer, and serves
a streaming chat UI built on ``gr.ChatInterface``.
"""
import shlex
import subprocess
import sys
from threading import Thread

import spaces  # kept ahead of torch: ZeroGPU expects `spaces` to be imported first


# Source archives of the Mamba kernel packages (causal-conv1d, mamba-ssm).
MAMBA_KERNEL_PACKAGES = (
    "https://github.com/Dao-AILab/causal-conv1d/archive/refs/tags/v1.4.0.tar.gz",
    "https://github.com/state-spaces/mamba/archive/refs/tags/v2.2.2.tar.gz",
)


# install packages for mamba
def install():
    """Pip-install the Mamba kernel packages into the running interpreter.

    Best-effort: a failed install is reported on stdout but does not stop the
    app (the model can still be loaded without the optional kernels).
    """
    print("Install personal packages", flush=True)
    for package in MAMBA_KERNEL_PACKAGES:
        # `sys.executable -m pip` installs into this interpreter's environment;
        # a bare `pip` resolves to whatever happens to be first on PATH.
        cmd = [sys.executable, "-m", "pip", "install", package]
        result = subprocess.run(cmd)
        if result.returncode != 0:
            print(
                f"WARNING: `{shlex.join(cmd)}` exited with code {result.returncode}",
                flush=True,
            )


# Install the kernels BEFORE transformers is imported and the model is loaded:
# transformers checks for mamba_ssm / causal_conv1d when it imports the model
# code, so installing them after the model has loaded (as this script used to)
# does not make them available to the running process.
install()

import torch  # noqa: E402
import gradio as gr  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer  # noqa: E402

MODEL = "tiiuae/falcon-mamba-7b-instruct"

# NOTE(review): these header strings contain no HTML tags (and SUB_TITLE
# mentions links that are not present); gr.HTML shows them as plain text.
TITLE = """

FalconMamba-7b playground

"""

SUB_TITLE = """
FalconMamba is a new model released by Technology Innovation Institute (TII) in Abu Dhabi. The model is open source and available within the Hugging Face ecosystem for anyone to use it for their research or application purpose. Refer to the HF release blogpost or the official announcement for more details. This interface has been created for quick validation purposes, do not use it for production.
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

# Not used by the code below.
END_MESSAGE = """
\n **The conversation has reached to its end, please press "Clear" to restart a new conversation**
"""

device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
).to(device)

if device == "cuda":
    # NOTE(review): `model.generate` is reached through the compiled wrapper's
    # attribute forwarding and then calls the original module, so this compile
    # probably does not speed up generation -- benchmark before relying on it.
    model = torch.compile(model)


def _build_conversation(message, history):
    """Turn Gradio (user, assistant) history pairs plus the new message into chat messages."""
    conversation = []
    for prompt, answer in history:
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU
def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Stream the model's reply to ``message``, yielding the accumulated text.

    Args:
        message: The new user message.
        history: Previous turns as (user, assistant) pairs.
        temperature: Sampling temperature; exactly 0 selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k sampling limit (used only when sampling).
        penalty: Repetition penalty forwarded to ``generate``.

    Yields:
        The response generated so far; each yield extends the previous one.
    """
    print(f'message: {message}')
    print(f'history: {history}')

    conversation = _build_conversation(message, history)
    input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
    # apply_chat_template adds no generation prompt by default, so the
    # assistant turn header is appended by hand (assumes a ChatML-style template).
    input_text += "<|im_start|>assistant\n"
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature != 0
    generate_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=penalty,  # was accepted from the UI but never applied
        streamer=streamer,
        pad_token_id=10,  # NOTE(review): hard-coded id; confirm it is the tokenizer's pad token
    )
    if do_sample:
        # Sampling knobs are ignored by greedy decoding (transformers may warn
        # about them), so only pass them when actually sampling.
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    # `generate` already runs under torch.no_grad(); wrapping the Thread start in
    # no_grad had no effect anyway because grad mode is thread-local.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
    print(f'response: {buffer}')


chatbot = gr.Chatbot(height=600)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.HTML(SUB_TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order must match stream_chat's parameters after (message, history).
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.3,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Hello there, can you suggest few places to visit in UAE?"],
            ["What UAE is known for?"],
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()