# Hugging Face Space (status: Sleeping)
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from huggingface_hub import login | |
import torch | |
import os | |
# Hugging Face Hub authentication.
# Only log in when a token is actually configured: login(token=None) falls back
# to an interactive prompt (widget or terminal), which a headless Space cannot
# answer. Public checkpoints download fine without logging in.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Registry of selectable chat models, keyed by the identifier offered in the
# "Model" dropdown; insertion order is the dropdown order. Per-entry fields:
#   name              -- display name (not read by the code visible in this file)
#   sizes             -- size label -> Hugging Face Hub repo id
#   emoji             -- decorative glyph
#   experimental      -- preview flag (not read by the code visible in this file)
#   is_vision         -- whether the prompt mentions an uploaded image
#   system_prompt_env -- environment variable holding this model's system prompt
# NOTE(review): the emoji in `name`/`emoji` look mis-encoded (mojibake) — confirm
# the intended glyphs.
MODELS = {}
MODELS["atlas-flash-1215"] = dict(
    name="π¦ Atlas-Flash 1215",
    sizes={"1.5B": "Spestly/Atlas-Flash-1.5B-Preview"},
    emoji="π¦",
    experimental=True,
    is_vision=False,
    system_prompt_env="ATLAS_FLASH_1215",
)
MODELS["atlas-pro-0403"] = dict(
    name="π Atlas-Pro 0403",
    sizes={"1.5B": "Spestly/Atlas-Pro-1.5B-Preview"},
    emoji="π",
    experimental=True,
    is_vision=False,
    system_prompt_env="ATLAS_PRO_0403",
)
# Startup selection: a (MODELS key, size label) pair, resolved to its Hub repo id.
default_model_key, default_size = "atlas-pro-0403", "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]
def load_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    Args:
        model_name: Hugging Face Hub repo id (one of the values in MODELS[...]["sizes"]).

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded in float32 with
        ``low_cpu_mem_usage=True`` and switched to inference mode via ``eval()``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    # nn.Module.eval() returns the module itself, so it can be chained here.
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float32,
    ).eval()
    return loaded_tokenizer, loaded_model
# Load the default checkpoint at import time into module-level globals;
# generate_response() rebinds these globals when a different model is chosen.
tokenizer, model = load_model(model_name=default_model)
# Generate response function

# Hub repo id of the checkpoint currently bound to the module-level
# `tokenizer` / `model` globals. The default checkpoint is loaded at import
# time, so tracking starts from `default_model`.
_loaded_model_name = default_model


def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    """Generate a single reply from the selected model.

    The selected checkpoint is compared against the one currently loaded (not
    against the startup default). This avoids reloading a non-default model on
    every call, and it ensures that switching back to the default model takes effect.

    Args:
        message: User text for this turn.
        image: Uploaded image value. It is not passed to the model; for vision
            models only a placeholder sentence is added to the prompt.
        history: Previous turns. Accepted for interface compatibility but not
            used when building the prompt.
        model_key: Key into MODELS.
        model_size: Key into MODELS[model_key]["sizes"].
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_new_tokens: Upper bound on generated tokens (coerced to int, since
            UI sliders may deliver floats).

    Returns:
        The decoded text generated after the prompt, stripped of surrounding whitespace.

    Raises:
        KeyError: If ``model_key`` or ``model_size`` is not present in MODELS.
    """
    global tokenizer, model, _loaded_model_name

    selected_model = MODELS[model_key]["sizes"][model_size]
    if selected_model != _loaded_model_name:
        # Update the tracker only after a successful load, so a failed load
        # leaves the previous model and tracker consistent.
        tokenizer, model = load_model(selected_model)
        _loaded_model_name = selected_model

    # Per-model system prompt from the environment, with a generic fallback.
    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")

    # Instruction-style prompt. For vision models only a textual note about the
    # image is appended; the image data itself is not fed to the model.
    if MODELS[model_key]["is_vision"]:
        image_info = "An image has been provided as input."
        instruction = (
            f"{system_prompt}\n\n"
            f"### Instruction:\n{message}\n{image_info}\n\n### Response:"
        )
    else:
        instruction = (
            f"{system_prompt}\n\n"
            f"### Instruction:\n{message}\n\n### Response:"
        )

    inputs = tokenizer(instruction, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # Decode only the newly generated tokens. This drops the echoed prompt
    # without relying on splitting at a "### Response:" marker, which the
    # model itself may also emit.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# User interface
def create_interface():
    """Build the Gradio Interface for chatting with the models in MODELS.

    Each session keeps its own chat transcript in ``gr.State``. The interface is
    not live, so the model runs only when the user submits.

    Returns:
        A configured ``gr.Interface`` that has not been launched yet.
    """
    message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
    model_key_selector = gr.Dropdown(
        label="Model",
        choices=list(MODELS.keys()),
        value=default_model_key,
    )
    # NOTE(review): size choices come from the default model only. Every model in
    # MODELS currently offers "1.5B", but a model lacking the chosen size would
    # raise KeyError in generate_response.
    model_size_selector = gr.Dropdown(
        label="Model Size",
        choices=list(MODELS[default_model_key]["sizes"].keys()),
        value=default_size,
    )
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
    top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
    max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)

    def toggle_image_input(model_key):
        """Return whether the image upload should be shown for ``model_key``."""
        return MODELS[model_key]["is_vision"]

    # Visibility is decided when the component is built. Gradio components have
    # no set_visibility() method, so calling it after construction raised
    # AttributeError at startup. Re-toggling visibility on dropdown change would
    # require gr.Blocks event listeners, which gr.Interface does not provide.
    image_input = gr.Image(
        label="Upload Image (if applicable)",
        type="filepath",
        visible=toggle_image_input(default_model_key),
    )

    chat_output = gr.Chatbot(label="Chatbot")

    def process_inputs(message, image, model_key, model_size, temperature, top_p, max_new_tokens, history=None):
        """Run one chat turn and return the updated transcript.

        ``history`` is this session's transcript supplied by ``gr.State``. The
        result is returned twice: once for the Chatbot display and once to
        update the session state.
        """
        # Work on a fresh list. Never use a mutable default here: a `history=[]`
        # default would be a single list shared by every call and every user.
        history = list(history) if history else []
        # Skip generation for blank submissions.
        if not message or not message.strip():
            return history, history
        response = generate_response(
            message=message,
            image=image,
            history=history,
            model_key=model_key,
            model_size=model_size,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
        )
        history.append((message, response))
        return history, history

    # NOTE(review): the leading glyph in the title looks mis-encoded (mojibake) —
    # confirm the intended emoji.
    iface = gr.Interface(
        fn=process_inputs,
        inputs=[
            message_input,
            image_input,
            model_key_selector,
            model_size_selector,
            temperature_slider,
            top_p_slider,
            max_tokens_slider,
            gr.State(value=[]),  # per-session chat transcript (last input)
        ],
        outputs=[chat_output, gr.State()],  # updated transcript: display + state
        title="π Atlas-Pro/Flash/Vision Interface",
        description="Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Pro (Coming Soon!). Upload images for vision models!",
        theme="soft",
        # Intentionally not live: a live interface re-runs the LLM and appends a
        # new chat turn on every keystroke or slider change.
    )
    return iface
# Build the UI and start the Gradio server. The Interface is bound to `demo`
# so it can also be referenced by name at module level.
demo = create_interface()
demo.launch()