# Spaces:
# Running
# Running
# (Status banner captured from the Hugging Face Space page — not part of the program.)
import os

import gradio as gr
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with the Hugging Face Hub so gated/private checkpoints can be
# downloaded. The token is expected in the HF_TOKEN environment variable
# (typically configured as a Space secret).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:  # skip login when no token is configured instead of failing
    login(token=HF_TOKEN)
# Define models
#
# Registry of selectable models. Each key maps to:
#   name     — display label shown in the UI
#   sizes    — parameter-size label -> Hugging Face repo id
#   emoji    — short label used in the UI
#   experimental — whether the checkpoint is a preview release
#   is_vision    — whether the model accepts an image input
#   system_prompt_env — env var holding the model's system prompt override
#
# NOTE(review): the "name"/"emoji" values look like mojibake (mis-decoded
# emoji). They are kept byte-for-byte because the intended characters cannot
# be recovered from this file — confirm against the original source.
MODELS = {
    "atlas-flash-1215": {
        "name": "π¦ Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "π¦",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "π Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "π",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}

# Load default model
# Checkpoint loaded eagerly at startup (see load_model below).
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]
def load_model(model_name):
    """Load a causal-LM checkpoint and its tokenizer from the Hub.

    Args:
        model_name: Hugging Face repo id, e.g. "Spestly/Atlas-Pro-1.5B-Preview".

    Returns:
        (tokenizer, model) tuple with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # full precision — presumably CPU inference
        low_cpu_mem_usage=True,
    )
    model.eval()  # disable dropout etc. for inference
    return tokenizer, model


# Eagerly load the default checkpoint so the first request has no cold start.
tokenizer, model = load_model(default_model)
# Generate response function
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    """Run one chat turn and return the updated chat history.

    Args:
        message: user's text prompt.
        image: optional uploaded image filepath (only acknowledged for vision
            models; these text-only checkpoints cannot consume pixels).
        history: list of [user, assistant] pairs from the gr.Chatbot component
            (may be None on the first turn).
        model_key: key into MODELS selecting the model family.
        model_size: key into MODELS[model_key]["sizes"].
        temperature, top_p, max_new_tokens: sampling parameters passed to
            model.generate().

    Returns:
        The history extended with the new [message, response] pair — the
        gr.Chatbot output component expects a list of pairs, not a bare string.
    """
    global tokenizer, model, default_model
    selected_model = MODELS[model_key]["sizes"][model_size]
    # Swap checkpoints only when the selection actually changed. default_model
    # doubles as "currently loaded repo id"; without updating it, every call
    # with a non-default selection would reload the model from scratch.
    if selected_model != default_model:
        tokenizer, model = load_model(selected_model)
        default_model = selected_model
    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
    if MODELS[model_key]["is_vision"]:
        # Text-only pipeline: just tell the model an image was supplied.
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
    inputs = tokenizer(instruction, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradient buffers
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the full prompt; keep only what follows the marker.
    response = response.split("### Response:")[-1].strip()
    return (history or []) + [[message, response]]
def create_interface():
    """Build and return the Gradio Blocks UI.

    Lays out model/size pickers, sampling sliders, an (initially hidden)
    image upload, and a chatbot pane wired to generate_response.
    """
    with gr.Blocks(title="π Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown("Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). Upload images for vision models!")
        model_key_selector = gr.Dropdown(
            label="Model",
            choices=list(MODELS.keys()),
            value=default_model_key,
        )
        model_size_selector = gr.Dropdown(
            label="Model Size",
            choices=list(MODELS[default_model_key]["sizes"].keys()),
            value=default_size,
        )
        # Hidden unless the selected model is a vision model (see update_components).
        image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
        temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
        chat_output = gr.Chatbot(label="Chatbot")
        submit_button = gr.Button("Submit")

        def update_components(model_key):
            """Refresh the size dropdown and image visibility for a new model."""
            model_info = MODELS[model_key]
            new_sizes = list(model_info["sizes"].keys())
            return [
                gr.Dropdown(choices=new_sizes, value=new_sizes[0]),
                gr.Image(visible=model_info["is_vision"]),
            ]

        model_key_selector.change(
            fn=update_components,
            inputs=model_key_selector,
            outputs=[model_size_selector, image_input],
        )
        submit_button.click(
            fn=generate_response,
            inputs=[
                message_input,
                image_input,
                chat_output,
                model_key_selector,
                model_size_selector,
                temperature_slider,
                top_p_slider,
                max_tokens_slider,
            ],
            outputs=chat_output,
        )
    return iface


create_interface().launch()