# AtlasUI / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os
# Hugging Face Hub login (expects HF_TOKEN in the environment, e.g. a Space secret)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
# Model registry: display name, checkpoint repo per size, and the env var
# that holds each model's system prompt.
MODELS = {
"atlas-flash-1215": {
"name": "🦁 Atlas-Flash 1215",
"sizes": {
"1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
},
"emoji": "🦁",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_FLASH_1215",
},
"atlas-pro-0403": {
"name": "πŸ† Atlas-Pro 0403",
"sizes": {
"1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
},
"emoji": "πŸ†",
"experimental": True,
"is_vision": False,
"system_prompt_env": "ATLAS_PRO_0403",
},
}
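# To expose another checkpoint, add an entry above with its Hub repo id(s)
# under "sizes" and the name of the env var holding its system prompt.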
# Load default model
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]
# Track which checkpoint is currently loaded so we can avoid needless reloads.
current_model_name = default_model
def load_model(model_name):
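    # Everything runs on CPU in float32; low_cpu_mem_usage streams the weights
    # from disk so the 1.5B-parameter checkpoints fit in modest RAM.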
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float32,
low_cpu_mem_usage=True
)
model.eval()
return tokenizer, model
tokenizer, model = load_model(default_model)
# Generate response function
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    global tokenizer, model, current_model_name
    # Reload only when the requested checkpoint differs from the one already
    # in memory, so switching between models (and back) always takes effect.
    selected_model = MODELS[model_key]["sizes"][model_size]
    if selected_model != current_model_name:
        tokenizer, model = load_model(selected_model)
        current_model_name = selected_model
# Get the system prompt from the environment
system_prompt_env = MODELS[model_key]["system_prompt_env"]
system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
# Construct instruction
    if MODELS[model_key]["is_vision"]:
        # Vision models: flag that an image was supplied. Only this textual
        # marker reaches the model; the image pixels are not encoded here.
        image_info = "An image has been provided as input."
instruction = (
f"{system_prompt}\n\n"
f"### Instruction:\n{message}\n{image_info}\n\n### Response:"
)
else:
# For non-vision models
instruction = (
f"{system_prompt}\n\n"
f"### Instruction:\n{message}\n\n### Response:"
)
# Tokenize input
inputs = tokenizer(instruction, return_tensors="pt")
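    # Inputs stay on CPU, matching where the model weights live.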
# Generate response
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_return_sequences=1,
temperature=temperature,
top_p=top_p,
do_sample=True
)
# Decode response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("### Response:")[-1].strip()
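    # Everything up to the final "### Response:" marker (system prompt and
    # echoed instruction) is discarded so only the model's answer remains.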
return response
# User interface
def create_interface():
# Define input components
message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
model_key_selector = gr.Dropdown(
label="Model",
choices=list(MODELS.keys()),
value=default_model_key
)
model_size_selector = gr.Dropdown(
label="Model Size",
choices=list(MODELS[default_model_key]["sizes"].keys()),
value=default_size
)
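    # Note: size choices come from the default model and are not refreshed on
    # model change; both registered models currently ship only a 1.5B preview,
    # so the lists happen to coincide.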
temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
    # Show the image input only when the default model accepts images.
    image_input = gr.Image(label="Upload Image (if applicable)", type="filepath",
                           visible=MODELS[default_model_key]["is_vision"])
    # gr.Interface offers no supported way to change a component's visibility
    # after construction, so the image input's visibility is fixed at creation
    # time above; live toggling on model change would need a gr.Blocks layout
    # with a .change() event returning gr.update(visible=...).
# Output components
chat_output = gr.Chatbot(label="Chatbot")
# Function to process inputs and generate output
    def process_inputs(message, image, model_key, model_size, temperature, top_p, max_new_tokens, history=None):
        # Use None instead of a mutable default argument: a shared list would
        # leak chat history across unrelated calls.
        history = history or []
        response = generate_response(
            message=message,
            image=image,
            history=history,
            model_key=model_key,
            model_size=model_size,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens
        )
        history.append((message, response))
        return history
# Interface layout
iface = gr.Interface(
fn=process_inputs,
inputs=[
message_input,
image_input,
model_key_selector,
model_size_selector,
temperature_slider,
top_p_slider,
max_tokens_slider
],
outputs=chat_output,
        title="🌟 Atlas-Pro/Flash/Vision Interface",
        description="Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Pro (Coming Soon!). Upload images for vision models!",
theme="soft",
        live=False  # generation is expensive; run only when the user submits
)
return iface
# Launch the app
if __name__ == "__main__":
    create_interface().launch()