File size: 1,142 Bytes
8b60cb4
 
 
 
 
38370aa
 
 
 
 
 
 
 
 
 
 
 
 
 
390672f
38370aa
 
390672f
38370aa
 
 
390672f
38370aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
import spaces
import transformers_gradio

# Catalogue of selectable models. Each key is the label offered in the model
# dropdown; each value holds the "name" and "src" arguments passed to gr.load().
models = {
    # Loaded through the transformers_gradio registry.
    "Llama": {
        "name": "allenai/Llama-3.1-Tulu-3-8B",
        "src": transformers_gradio.registry,
    },
    # Loaded with the string source "spaces".
    "OLMo": {
        "name": "akhaliq/olmo-anychat",
        "src": "spaces",
    },
}

def load_model(model_choice):
    """Load and return the Gradio demo registered under ``model_choice``.

    Args:
        model_choice: A key of the module-level ``models`` mapping.

    Returns:
        The object returned by ``gr.load`` for that entry, with the API name
        of each of its registered functions set to ``False``.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``models``.
    """
    entry = models[model_choice]
    loaded = gr.load(name=entry["name"], src=entry["src"])

    # Only the "Llama" entry has its handler wrapped with the spaces.GPU()
    # decorator; every other entry keeps the handler it was loaded with.
    if model_choice == "Llama":
        gpu_decorator = spaces.GPU()
        loaded.fn = gpu_decorator(loaded.fn)

    # Disable API names
    for registered_fn in loaded.fns.values():
        registered_fn.api_name = False

    return loaded

if __name__ == "__main__":
    # Model whose demo is visible when the page first loads.
    default_model = "Llama"

    with gr.Blocks() as interface:
        model_dropdown = gr.Dropdown(
            choices=list(models),
            value=default_model,
            label="Select Model",
        )

        # An event handler's return value is applied as an update to existing
        # output components; it cannot insert a newly loaded demo into a
        # layout block. Each demo is therefore built once, up front, inside
        # its own group, and the dropdown only switches which group is shown.
        # This also makes the default model visible without any interaction.
        demo_containers = []
        for model_choice in models:
            with gr.Group(visible=model_choice == default_model) as container:
                load_model(model_choice)
            demo_containers.append(container)

        def update_demo(model_choice):
            """Return one visibility update per group, showing only the selected model."""
            return [gr.update(visible=name == model_choice) for name in models]

        model_dropdown.change(
            fn=update_demo,
            inputs=[model_dropdown],
            outputs=demo_containers,
        )

    # NOTE(review): handlers of demos built inside this Blocks context are
    # presumably registered on `interface` rather than on the loaded demo, so
    # the API-name disabling done in load_model is repeated here - confirm.
    for registered_fn in interface.fns.values():
        registered_fn.api_name = False

    interface.launch()