import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os

# Hugging Face Hub login (HF_TOKEN is read from the environment; skip the
# login call when it is unset, e.g. when only public models are used).
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model registry: each entry maps a size label to a Hub repo ID and carries
# UI metadata plus the env var name holding that model's system prompt.
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "🏆 Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "🏆",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },
}

# Default model, loaded once at startup
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]


def load_model(model_name):
    """Load a tokenizer/model pair from the Hub in float32 (CPU-friendly)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return tokenizer, model


tokenizer, model = load_model(default_model)
current_model = default_model  # track the loaded checkpoint so we only reload on change


# Generate a response and append it to the chat history
def generate_response(message, image, history, model_key, model_size,
                      temperature, top_p, max_new_tokens):
    global tokenizer, model, current_model

    # Reload only when the user actually switched to a different checkpoint
    # (comparing against default_model here would reload on every call).
    selected_model = MODELS[model_key]["sizes"][model_size]
    if selected_model != current_model:
        tokenizer, model = load_model(selected_model)
        current_model = selected_model

    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(
        system_prompt_env,
        "You are an advanced AI system. Help the user as best as you can.",
    )

    # Build an Alpaca-style prompt. Note: for "vision" models the image is
    # only mentioned in the prompt text; these causal-LM checkpoints cannot
    # actually consume image pixels.
    if MODELS[model_key]["is_vision"] and image is not None:
        image_info = "An image has been provided as input."
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n{image_info}\n\n### Response:"
    else:
        instruction = f"{system_prompt}\n\n### Instruction:\n{message}\n\n### Response:"

    inputs = tokenizer(instruction, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text generated after the response marker
    response = response.split("### Response:")[-1].strip()

    # gr.Chatbot expects the full history as (user, assistant) pairs,
    # not a bare string, so append the new turn and return the history.
    history = history or []
    history.append((message, response))
    return history


def create_interface():
    with gr.Blocks(title="🌟 Atlas-Pro/Flash/Vision Interface", theme="soft") as iface:
        gr.Markdown(
            "Interact with multiple models like Atlas-Pro, Atlas-Flash, and "
            "AtlasV-Flash (Coming Soon!). Upload images for vision models!"
        )

        model_key_selector = gr.Dropdown(
            label="Model",
            choices=list(MODELS.keys()),
            value=default_model_key,
        )
        model_size_selector = gr.Dropdown(
            label="Model Size",
            choices=list(MODELS[default_model_key]["sizes"].keys()),
            value=default_size,
        )
        image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)
        message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
        temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
        top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
        max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
        chat_output = gr.Chatbot(label="Chatbot")
        submit_button = gr.Button("Submit")

        def update_components(model_key):
            # Refresh the size choices for the selected model and show the
            # image input only when the model is vision-capable.
            model_info = MODELS[model_key]
            new_sizes = list(model_info["sizes"].keys())
            return [
                gr.Dropdown(choices=new_sizes, value=new_sizes[0]),
                gr.Image(visible=model_info["is_vision"]),
            ]

        model_key_selector.change(
            fn=update_components,
            inputs=model_key_selector,
            outputs=[model_size_selector, image_input],
        )

        submit_button.click(
            fn=generate_response,
            inputs=[
                message_input,
                image_input,
                chat_output,
                model_key_selector,
                model_size_selector,
                temperature_slider,
                top_p_slider,
                max_tokens_slider,
            ],
            outputs=chat_output,
        )

    return iface


create_interface().launch()
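
# Usage sketch (assumptions, not part of the original script): the file name
# "app.py" and the prompt strings below are illustrative placeholders; the
# code above only reads HF_TOKEN plus the per-model system-prompt variables
# ATLAS_PRO_0403 and ATLAS_FLASH_1215 from the environment.
#
#   export HF_TOKEN=hf_xxx                          # Hub token (optional for public models)
#   export ATLAS_PRO_0403="You are Atlas-Pro."      # system prompt for Atlas-Pro
#   export ATLAS_FLASH_1215="You are Atlas-Flash."  # system prompt for Atlas-Flash
#   python app.py                                   # then open the printed local URL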