import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import torch
import os

# Hugging Face token login (HF_TOKEN should be set in the environment,
# e.g. as a Space secret); skip login if it is not configured
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Define models
MODELS = {
    "atlas-flash-1215": {
        "name": "🦁 Atlas-Flash 1215",
        "sizes": {
            "1.5B": "Spestly/Atlas-Flash-1.5B-Preview",
        },
        "emoji": "🦁",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_FLASH_1215",
    },
    "atlas-pro-0403": {
        "name": "πŸ† Atlas-Pro 0403",
        "sizes": {
            "1.5B": "Spestly/Atlas-Pro-1.5B-Preview",
        },
        "emoji": "πŸ†",
        "experimental": True,
        "is_vision": False,
        "system_prompt_env": "ATLAS_PRO_0403",
    },

}
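
# Adding another model is a matter of appending an entry with the same schema.
# A sketch of what that might look like (hypothetical name and repo id, for
# illustration only):
# "atlasv-pro-0101": {
#     "name": "🌟 AtlasV-Pro 0101",
#     "sizes": {"2B": "Spestly/AtlasV-Pro-2B-Preview"},  # hypothetical repo id
#     "emoji": "🌟",
#     "experimental": True,
#     "is_vision": True,  # vision models get the image input shown
#     "system_prompt_env": "ATLASV_PRO_0101",
# },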

# Load default model
default_model_key = "atlas-pro-0403"
default_size = "1.5B"
default_model = MODELS[default_model_key]["sizes"][default_size]

def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True
    )
    model.eval()
    return tokenizer, model
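
# Note: float32 on CPU keeps the Space runnable without a GPU. If GPU memory
# is available, a half-precision variant along these lines (untested sketch;
# device_map requires the `accelerate` package) should be much faster:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16, device_map="auto"
# )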

tokenizer, model = load_model(default_model)
current_model_name = default_model  # tracks which checkpoint is in memory

# Generate a response for a single message (history and image are accepted
# for interface compatibility; history is not yet woven into the prompt)
def generate_response(message, image, history, model_key, model_size, temperature, top_p, max_new_tokens):
    global tokenizer, model, current_model_name
    # Reload only when the requested checkpoint differs from the one already
    # loaded (comparing against the default model would skip switching back)
    selected_model = MODELS[model_key]["sizes"][model_size]
    if selected_model != current_model_name:
        tokenizer, model = load_model(selected_model)
        current_model_name = selected_model
    
    # Get the system prompt from the environment
    system_prompt_env = MODELS[model_key]["system_prompt_env"]
    system_prompt = os.getenv(system_prompt_env, "You are an advanced AI system. Help the user as best as you can.")
    
    # Construct the prompt (Alpaca-style "### Instruction / ### Response" template)
    if MODELS[model_key]["is_vision"]:
        # If a vision model, include the image information
        image_info = "An image has been provided as input."
        instruction = (
            f"{system_prompt}\n\n"
            f"### Instruction:\n{message}\n{image_info}\n\n### Response:"
        )
    else:
        # For non-vision models
        instruction = (
            f"{system_prompt}\n\n"
            f"### Instruction:\n{message}\n\n### Response:"
        )
    
    # Tokenize input
    inputs = tokenizer(instruction, return_tensors="pt")
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id  # avoid missing-pad-token warnings
        )
    
    # Decode response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response.split("### Response:")[-1].strip()
    return response
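
# Example of calling generate_response directly, outside the UI (the image
# argument is unused for non-vision models, so None is fine here):
# reply = generate_response(
#     "Summarise beam search in one sentence.", None, [],
#     "atlas-pro-0403", "1.5B", 0.7, 0.9, 256,
# )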

# User interface
def create_interface():
    # Define input components
    message_input = gr.Textbox(label="Message", placeholder="Type your message here...")
    model_key_selector = gr.Dropdown(
        label="Model",
        choices=list(MODELS.keys()),
        value=default_model_key
    )
    model_size_selector = gr.Dropdown(
        label="Model Size",
        choices=list(MODELS[default_model_key]["sizes"].keys()),
        value=default_size
    )
    temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7, step=0.1)
    top_p_slider = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
    max_tokens_slider = gr.Slider(label="Max New Tokens", minimum=50, maximum=2000, value=1000, step=50)
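    # Temperature rescales the token distribution (lower = more deterministic);
    # top-p samples from the smallest token set whose cumulative probability
    # exceeds the threshold; max new tokens caps the length of the reply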
    image_input = gr.Image(label="Upload Image (if applicable)", type="filepath", visible=False)

    # Show the image input only for vision models
    def toggle_image_input(model_key):
        return gr.update(visible=MODELS[model_key]["is_vision"])

    # Output components, plus per-session state so the chat history
    # accumulates across turns within a session
    chat_output = gr.Chatbot(label="Chatbot")
    history_state = gr.State([])

    # Process one turn: generate a reply and append it to the session history
    # (history arrives via gr.State rather than a mutable default argument,
    # which would be shared across all sessions)
    def process_inputs(message, image, model_key, model_size, temperature, top_p, max_new_tokens, history):
        history = history or []
        response = generate_response(
            message=message,
            image=image,
            history=history,
            model_key=model_key,
            model_size=model_size,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens
        )
        history.append((message, response))
        return history, history

    # Interface layout
    iface = gr.Interface(
        fn=process_inputs,
        inputs=[
            message_input,
            image_input,
            model_key_selector,
            model_size_selector,
            temperature_slider,
            top_p_slider,
            max_tokens_slider,
            history_state
        ],
        outputs=[chat_output, history_state],
        title="🌟 Atlas-Pro/Flash/Vision Interface",
        description="Interact with multiple models like Atlas-Pro, Atlas-Flash, and AtlasV-Pro (Coming Soon!). Upload images for vision models!",
        theme="soft"
    )

    # Toggle image-input visibility when the model selection changes.
    # gr.Interface subclasses gr.Blocks, so (assuming context re-entry is
    # supported by the installed Gradio version) an extra event listener
    # can be attached here; Component.set_visibility is not a Gradio API.
    with iface:
        model_key_selector.change(
            fn=toggle_image_input,
            inputs=model_key_selector,
            outputs=image_input,
        )
    return iface

# Launch the app
create_interface().launch()
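
# When running outside Hugging Face Spaces, a temporary public link can be
# requested with create_interface().launch(share=True) instead.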