import os
import subprocess
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

from ui import css, PLACEHOLDER

# Install flash-attn at runtime; skipping the CUDA build step is required on
# Spaces, where no GPU is attached while the app image is being built.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

# App configuration, injected through environment variables (Space settings)
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")
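# For reference, a hypothetical configuration (illustrative values only, not
# defaults shipped with this app):
#
#   MODEL_ID="mistralai/Mistral-7B-Instruct-v0.2"
#   CHAT_TEMPLATE="Mistral Instruct"    # or "ChatML"
#   CONTEXT_LENGTH="8192"
#   EMOJI="🤖"
#   COLOR="blue"                        # any Gradio primary hue
#   DESCRIPTION="Chat with a quantized instruct model."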

@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format history with the selected chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
        for human, assistant in history:
            instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
        instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = '<s>[INST] ' + system_prompt
        for human, assistant in history:
            instruction += human + ' [/INST] ' + assistant + '</s>[INST]'
        instruction += ' ' + message + ' [/INST]'
    else:
        raise ValueError("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH tokens so the prompt fits the
    # model's context window; the attention mask must be truncated to match
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    # Run generation in a background thread and stream tokens back to the UI
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)


# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)

# Create Gradio interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["Can you solve the equation 2x + 3 = 11 for x?"],
        ["Write an epic poem about Ancient Rome."],
        ["Who was the first person to walk on the Moon?"],
        ["Use a list comprehension to create a list of squares for numbers from 1 to 10."],
        ["Recommend some popular science fiction books."],
        ["Can you write a short story about a time-traveling detective?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    # Order must match the extra parameters of predict(): system_prompt,
    # temperature, max_new_tokens, top_k, repetition_penalty, top_p
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR).set(
        background_fill_primary_dark="#020417",
        background_fill_secondary_dark="#020417",
        body_background_fill_dark="#020417",
        block_background_fill_dark="#020417",
        block_border_width="1px",
        block_title_background_fill_dark="#15172c",
        input_background_fill_dark="#15172c",
        button_secondary_background_fill_dark="#15172c",
        border_color_accent_dark="#15172c",
        border_color_primary_dark="#15172c",
        color_accent_soft_dark="#10132c",
        code_background_fill_dark="#15172c",
    ),
    css=css,
    retry_btn="Retry",
    undo_btn="Undo",
    clear_btn="Clear",
    submit_btn="Send",
    chatbot=gr.Chatbot(
        scale=1,
        placeholder=PLACEHOLDER,
        show_copy_button=True
    )
).queue().launch()
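# Note: @spaces.GPU() requests ZeroGPU hardware per call when this runs as a
# Hugging Face Space; outside Spaces the decorator is a no-op and generation
# simply uses whatever `device` was detected above.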