akhaliq HF staff committed on
Commit
38370aa
·
verified ·
1 Parent(s): 02c8f8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -55
app.py CHANGED
@@ -3,61 +3,47 @@ import spaces
3
  import transformers_gradio
4
 
5
  # Load models
6
- llama_demo = gr.load(name="allenai/Llama-3.1-Tulu-3-8B", src=transformers_gradio.registry)
7
- llama_demo.fn = spaces.GPU()(llama_demo.fn)
8
-
9
- # Modified OLMo loading - pass additional parameters to handle chat interface
10
- olmo_demo = gr.load(
11
- name="akhaliq/olmo-anychat",
12
- src="spaces",
13
- api_name=False, # Disable API endpoint
14
- title="OLMo Chat", # These parameters will be passed to ChatInterface
15
- description="A chat interface for the OLMo model"
16
- )
17
-
18
- # Create the interface
19
- with gr.Blocks() as demo:
20
- model_dropdown = gr.Dropdown(
21
- choices=["allenai/Llama-3.1-Tulu-3-8B", "akhaliq/olmo-anychat"],
22
- value="allenai/Llama-3.1-Tulu-3-8B",
23
- label="Select Model"
24
- )
25
-
26
- # Create columns for each model
27
- with gr.Column(visible=True) as llama_column:
28
- gr.ChatInterface(
29
- fn=lambda message, history: next(llama_demo.fn(message)),
30
- title="Llama 3.1 Tulu",
31
- description="A chat interface for the Llama 3.1 Tulu model"
32
- )
33
 
34
- with gr.Column(visible=False) as olmo_column:
35
- # Wrap the olmo_demo with error handling if needed
36
- if hasattr(olmo_demo, 'fn'):
37
- olmo_demo.fn = lambda message, history: next(olmo_demo.fn(message)) if olmo_demo.fn(message) else "I apologize, but I couldn't generate a response. Please try again."
38
- olmo_demo # Just render the loaded interface directly
39
 
40
- # Update visibility when model changes
41
- def update_model(new_model):
42
- return [
43
- gr.Column(visible=new_model == "allenai/Llama-3.1-Tulu-3-8B"),
44
- gr.Column(visible=new_model == "akhaliq/olmo-anychat")
45
- ]
46
 
47
- model_dropdown.change(
48
- fn=update_model,
49
- inputs=model_dropdown,
50
- outputs=[llama_column, olmo_column],
51
- api_name=False,
52
- queue=False,
53
- )
54
-
55
- # Disable API names
56
- for fn in demo.fns.values():
57
- fn.api_name = False
58
-
59
-
60
-
61
-
62
-
63
- demo.launch()
 
 
 
 
 
 
3
  import transformers_gradio
4
 
5
  # Load models
6
+ models = {
7
+ "Llama": {
8
+ "name": "allenai/Llama-3.1-Tulu-3-8B",
9
+ "src": transformers_gradio.registry
10
+ },
11
+ "OLMo": {
12
+ "name": "akhaliq/olmo-anychat",
13
+ "src": "spaces"
14
+ }
15
+ }
16
+
17
+ def load_model(model_choice):
18
+ model_info = models[model_choice]
19
+ demo = gr.load(name=model_info["name"], src=model_info["src"])
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ if model_choice == "Llama":
22
+ demo.fn = spaces.GPU()(demo.fn)
 
 
 
23
 
24
+ # Disable API names
25
+ for fn in demo.fns.values():
26
+ fn.api_name = False
 
 
 
27
 
28
+ return demo
29
+
30
+ if __name__ == "__main__":
31
+ with gr.Blocks() as interface:
32
+ model_dropdown = gr.Dropdown(
33
+ choices=list(models.keys()),
34
+ value="Llama",
35
+ label="Select Model"
36
+ )
37
+
38
+ demo_container = gr.Group()
39
+
40
+ def update_demo(model_choice):
41
+ return load_model(model_choice)
42
+
43
+ model_dropdown.change(
44
+ fn=update_demo,
45
+ inputs=[model_dropdown],
46
+ outputs=[demo_container]
47
+ )
48
+
49
+ interface.launch()