samuelemarro commited on
Commit
c9d2369
·
1 Parent(s): 19bcd88

Several UI improvements.

Browse files
Files changed (1) hide show
  1. app.py +81 -32
app.py CHANGED
@@ -2,8 +2,6 @@ import json
2
 
3
  import gradio as gr
4
 
5
- from collections import UserList
6
-
7
  from flow import full_flow
8
 
9
  from utils import use_cost_tracker, get_costs, compute_hash
@@ -40,11 +38,18 @@ def parse_raw_messages(messages_raw):
40
 
41
  def main():
42
  with gr.Blocks() as demo:
43
- gr.Markdown("### Agora Demo")
44
  gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")
45
 
46
- chosen_task = gr.Dropdown(choices=list(SCHEMAS.keys()), label="Schema", value="weather_forecast")
47
- custom_task = gr.Checkbox(label="Custom Task")
 
 
 
 
 
 
 
48
 
49
  STATE_TRACKER = {}
50
 
@@ -52,13 +57,15 @@ def main():
52
  def render(chosen_task, custom_task):
53
  if STATE_TRACKER.get('chosen_task') != chosen_task:
54
  STATE_TRACKER['chosen_task'] = chosen_task
55
- for k, v in SCHEMAS[chosen_task].items():
56
  if isinstance(v, str):
57
  STATE_TRACKER[k] = v
58
  else:
59
  STATE_TRACKER[k] = json.dumps(v, indent=2)
60
 
61
  if custom_task:
 
 
62
  gr.Text(label="Description", value=STATE_TRACKER["description"], interactive=True).change(lambda x: STATE_TRACKER.update({'description': x}))
63
  gr.TextArea(label="Input Schema", value=STATE_TRACKER["input"], interactive=True).change(lambda x: STATE_TRACKER.update({'input': x}))
64
  gr.TextArea(label="Output Schema", value=STATE_TRACKER["output"], interactive=True).change(lambda x: STATE_TRACKER.update({'output': x}))
@@ -68,7 +75,8 @@ def main():
68
  model_options = [
69
  ('GPT 4o (Camel AI)', 'gpt-4o'),
70
  ('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
71
- ('Claude 3 Sonnet (LangChain)', 'claude-3-sonnet'),
 
72
  ('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
73
  ('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
74
  ]
@@ -90,59 +98,100 @@ def main():
90
  with gr.Column(scale=1):
91
  bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")
92
 
93
- button = gr.Button('Start', elem_id='start_button')
94
- gr.Markdown('### Natural Language')
95
 
96
  @gr.render(inputs=[alice_model_dd, bob_model_dd])
97
  def render_with_images(alice_model, bob_model):
 
 
 
98
  avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
99
  chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)
100
 
101
  with gr.Accordion(label="Raw Messages", open=False):
102
  chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
103
 
104
- gr.Markdown('### Negotiation')
105
  chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)
106
 
107
- gr.Markdown('### Protocol')
 
108
  protocol_result = gr.TextArea(interactive=False, label="Protocol")
109
 
110
- gr.Markdown('### Implementation')
111
  with gr.Row():
112
  with gr.Column(scale=1):
113
  alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
114
  with gr.Column(scale=1):
115
  bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")
116
 
117
- gr.Markdown('### Structured Communication')
118
  structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)
119
 
120
  with gr.Accordion(label="Raw Messages", open=False):
121
  structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
 
 
 
 
 
 
 
 
122
 
123
- def respond(chosen_task, custom_task, alice_model, bob_model):
124
- yield gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), \
125
- None, None, None, None, None, None, None, None
126
-
127
- if custom_task:
128
- schema = dict(STATE_TRACKER)
129
- for k, v in schema.items():
130
- if isinstance(v, str):
131
- try:
132
- schema[k] = json.loads(v)
133
- except:
134
- pass
 
 
 
135
  else:
136
- schema = SCHEMAS[chosen_task]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_implementation, bob_implementation in full_flow(schema, alice_model, bob_model):
139
- nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
140
- structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)
141
 
142
- yield gr.update(), gr.update(), gr.update(), nl_messages_clean, nl_messages_agora, negotiation_messages, structured_messages_clean, structured_messages_agora, protocol, alice_implementation, bob_implementation
 
143
 
144
- #yield from full_flow(schema, alice_model, bob_model)
145
- yield gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
146
 
147
  button.click(respond, [chosen_task, custom_task, alice_model_dd, bob_model_dd], [button, alice_model_dd, bob_model_dd, chatbot_nl, chatbot_nl_raw, chatbot_negotiation, structured_communication, structured_communication_raw, protocol_result, alice_implementation, bob_implementation])
148
 
 
2
 
3
  import gradio as gr
4
 
 
 
5
  from flow import full_flow
6
 
7
  from utils import use_cost_tracker, get_costs, compute_hash
 
38
 
39
  def main():
40
  with gr.Blocks() as demo:
41
+ gr.Markdown("# Agora Demo")
42
  gr.Markdown("We will create a new Agora channel and offer it to Alice as a tool.")
43
 
44
+ chosen_task = gr.Dropdown(choices=[
45
+ (v['display_name'], k) for k, v in SCHEMAS.items()
46
+ ], label="Schema", value="weather_forecast")
47
+
48
+ @gr.render(inputs=[chosen_task])
49
+ def render2(chosen_task):
50
+ gr.Markdown('**Description**: ' + SCHEMAS[chosen_task]["description"])
51
+
52
+ custom_task = gr.Checkbox(label="Override Demo Parameters")
53
 
54
  STATE_TRACKER = {}
55
 
 
57
  def render(chosen_task, custom_task):
58
  if STATE_TRACKER.get('chosen_task') != chosen_task:
59
  STATE_TRACKER['chosen_task'] = chosen_task
60
+ for k, v in SCHEMAS[chosen_task]['schema'].items():
61
  if isinstance(v, str):
62
  STATE_TRACKER[k] = v
63
  else:
64
  STATE_TRACKER[k] = json.dumps(v, indent=2)
65
 
66
  if custom_task:
67
+ gr.Markdown('#### Custom Demo Parameters')
68
+ gr.Markdown('You can override the default parameters for the demo. Note: recommended for advanced users only.')
69
  gr.Text(label="Description", value=STATE_TRACKER["description"], interactive=True).change(lambda x: STATE_TRACKER.update({'description': x}))
70
  gr.TextArea(label="Input Schema", value=STATE_TRACKER["input"], interactive=True).change(lambda x: STATE_TRACKER.update({'input': x}))
71
  gr.TextArea(label="Output Schema", value=STATE_TRACKER["output"], interactive=True).change(lambda x: STATE_TRACKER.update({'output': x}))
 
75
  model_options = [
76
  ('GPT 4o (Camel AI)', 'gpt-4o'),
77
  ('GPT 4o-mini (Camel AI)', 'gpt-4o-mini'),
78
+ ('Claude 3 Sonnet (LangChain)', 'claude-3-5-sonnet-latest'),
79
+ ('Claude 3 Haiku (LangChain)', 'claude-3-5-haiku-latest'),
80
  ('Gemini 1.5 Pro (Google GenAI)', 'gemini-1.5-pro'),
81
  ('Llama3 405B (Sambanova + LangChain)', 'llama3-405b')
82
  ]
 
98
  with gr.Column(scale=1):
99
  bob_model_dd = gr.Dropdown(label="Bob Model", choices=model_options, value="gpt-4o")
100
 
101
+
 
102
 
103
  @gr.render(inputs=[alice_model_dd, bob_model_dd])
104
  def render_with_images(alice_model, bob_model):
105
+ button = gr.Button('Start', elem_id='start_button')
106
+ gr.Markdown('## Natural Language')
107
+
108
  avatar_images = [images.get(alice_model, fallback_image), images.get(bob_model, fallback_image)]
109
  chatbot_nl = gr.Chatbot(type="messages", avatar_images=avatar_images)
110
 
111
  with gr.Accordion(label="Raw Messages", open=False):
112
  chatbot_nl_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
113
 
114
+ gr.Markdown('## Negotiation')
115
  chatbot_negotiation = gr.Chatbot(type="messages", avatar_images=avatar_images)
116
 
117
+ gr.Markdown('## Protocol')
118
+ protocol_hash_result = gr.Text(interactive=False, label="Protocol Hash")
119
  protocol_result = gr.TextArea(interactive=False, label="Protocol")
120
 
121
+ gr.Markdown('## Implementation')
122
  with gr.Row():
123
  with gr.Column(scale=1):
124
  alice_implementation = gr.TextArea(interactive=False, label="Alice Implementation")
125
  with gr.Column(scale=1):
126
  bob_implementation = gr.TextArea(interactive=False, label="Bob Implementation")
127
 
128
+ gr.Markdown('## Structured Communication')
129
  structured_communication = gr.Chatbot(type="messages", avatar_images=avatar_images)
130
 
131
  with gr.Accordion(label="Raw Messages", open=False):
132
  structured_communication_raw = gr.Chatbot(type="messages", avatar_images=avatar_images)
133
+
134
+ gr.Markdown('## Cost')
135
+ cost_info = gr.State(value=None)
136
+ #cost_info = gr.TextArea(interactive=False, label="Cost")
137
+
138
+ query_slider = gr.Slider(label="Number of queries", minimum=1, maximum=10_000, step=1, value=50, interactive=True)
139
+ cost_display = gr.Markdown('')
140
+
141
 
142
+ def render_info(query_count, cost_info):
143
+ if not cost_info:
144
+ return ''
145
+ natural_cost = cost_info['conversation'] * query_count
146
+ agora_cost = cost_info['negotiation'] + cost_info['programming']
147
+
148
+ cost_message = f'Cost of one natural language conversation: {cost_info["conversation"]:.4f} USD\n\n'
149
+ cost_message += f'Cost of negotiating the protocol: {cost_info["negotiation"]:.4f} USD\n\n'
150
+ cost_message += f'Cost of implementing the protocol: {cost_info["programming"]:.4f} USD\n\n'
151
+ cost_message += f'Cost of {query_count} queries with natural language: {natural_cost:.4f} USD\n\n'
152
+ cost_message += f'Cost of {query_count} queries with Agora: {agora_cost:.4f} USD\n\n'
153
+
154
+ if natural_cost < agora_cost:
155
+ factor = agora_cost / natural_cost
156
+ cost_message += f'Natural language is {factor:.2f}x cheaper than Agora.'
157
  else:
158
+ factor = natural_cost / agora_cost
159
+ cost_message += f'Agora is {factor:.2f}x cheaper than natural language.'
160
+
161
+ return cost_message
162
+
163
+ cost_info.change(render_info, [query_slider, cost_info], [cost_display])
164
+ query_slider.change(render_info, [query_slider, cost_info], [cost_display])
165
+
166
+
167
+ def respond(chosen_task, custom_task, alice_model, bob_model, query_count):
168
+ with use_cost_tracker():
169
+ yield gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), \
170
+ None, None, None, None, None, None, None, None, None, None, None
171
+
172
+ if custom_task:
173
+ schema = dict(STATE_TRACKER)
174
+ for k, v in schema.items():
175
+ if isinstance(v, str):
176
+ try:
177
+ schema[k] = json.loads(v)
178
+ except:
179
+ pass
180
+ else:
181
+ schema = SCHEMAS[chosen_task]["schema"]
182
+
183
+ for nl_messages_raw, negotiation_messages, structured_messages_raw, protocol, alice_implementation, bob_implementation in full_flow(schema, alice_model, bob_model):
184
+ nl_messages_clean, nl_messages_agora = parse_raw_messages(nl_messages_raw)
185
+ structured_messages_clean, structured_messages_agora = parse_raw_messages(structured_messages_raw)
186
+ protocol_hash = compute_hash(protocol) if protocol else None
187
+ yield gr.update(), gr.update(), gr.update(), None, None, nl_messages_clean, nl_messages_agora, negotiation_messages, structured_messages_clean, structured_messages_agora, protocol, protocol_hash, alice_implementation, bob_implementation
188
 
189
+ #yield from full_flow(schema, alice_model, bob_model)
 
 
190
 
191
+ cost_data = get_costs()
192
+ cost_data_formatted = render_info(query_count, cost_data)
193
 
194
+ yield gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), cost_data, cost_data_formatted, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
 
195
 
196
  button.click(respond, [chosen_task, custom_task, alice_model_dd, bob_model_dd], [button, alice_model_dd, bob_model_dd, chatbot_nl, chatbot_nl_raw, chatbot_negotiation, structured_communication, structured_communication_raw, protocol_result, alice_implementation, bob_implementation])
197