import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os
# Log in to the Hugging Face Hub (only needed for gated or private repos)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Model registry: display name, size -> Hugging Face repo id, and the
# environment variable that holds each model's system prompt.
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🏆 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🏆",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}
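# Example lookup: MODELS["atlas-pro-0403"]["sizes"]["1.5B"]
# -> "Spestly/Atlas-Pro-1.5B-Preview"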
# Load default model
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]
def load_model(model_name):
    """Load a tokenizer/model pair for CPU inference."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,   # full precision: safest on CPU-only hardware
        low_cpu_mem_usage=True
    )
    model.eval()
    return tokenizer, model
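# Note: float32 on CPU is the conservative choice for a free Space; on GPU
# hardware one could instead pass torch_dtype=torch.float16 and
# device_map="auto" (which additionally requires the accelerate package).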
# Load the default model at startup and track which repo is currently active
current_model = default_model
tokenizer, model = load_model(default_model)
# Generate a response and append the (user, assistant) pair to the chat history
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    global tokenizer, model, current_model
    # Reload only when the user actually switches to a different checkpoint
    selected_model = MODELS[model_key]["sizes"][model_size]
    if selected_model != current_model:
        tokenizer, model = load_model(selected_model)
        current_model = selected_model
    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
    # Alpaca-style prompt: system prompt, then the instruction, then a response stub
    if MODELS[model_key]["is_vision"]:
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"
    inputs = tokenizer(instruction, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the response marker
    response = response.split("### Response:")[-1].strip()
    # gr.Chatbot expects a list of (user, bot) pairs, not a bare string
    history = (history or []) + [(message, response)]
    return history
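# Minimal usage sketch (assumes the default checkpoint is loaded):
#   history = generate_response("Hello!", None, [], "atlas-pro-0403", "1.5B",
#                               0.7, 0.9, 256)
# Returns the chat history with the new (user, assistant) pair appended.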
# Build the Gradio Blocks UI
def create_interface():
    with gr.Blocks(title="🌟 Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown("Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Flash (Coming Soon!). Upload images for vision models!")
        model_key_selector = gr.Dropdown(
            label="Model",
            choices=list(MODELS.keys()),
            value=default_model_key
        )
        model_size_selector = gr.Dropdown(
            label="Model Size",
            choices=list(MODELS[default_model_key]["sizes"].keys()),
            value=default_size
        )
        # Hidden unless the selected model is a vision model
        image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
        temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
        chat_output = gr.Chatbot(label="Chatbot")
        submit_button = gr.Button("Submit")
        # Swap the size choices and toggle the image box when the model changes
        def update_components(model_key):
            model_info = MODELS[model_key]
            new_sizes = list(model_info["sizes"].keys())
            return [
                gr.Dropdown(choices=new_sizes, value=new_sizes[0]),
                gr.Image(visible=model_info["is_vision"])
            ]

        model_key_selector.change(
            fn=update_components,
            inputs=model_key_selector,
            outputs=[model_size_selector, image_input]
        )
        submit_button.click(
            fn=generate_response,
            inputs=[
                message_input,
                image_input,
                chat_output,
                model_key_selector,
                model_size_selector,
                temperature_slider,
                top_p_slider,
                max_tokens_slider
            ],
            outputs=chat_output
        )
    return iface

create_interface().launch()
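# launch() blocks and serves the app; a Space needs no extra arguments.
# Run locally, one could pass share=True for a temporary public link.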