freddyaboulton HF Staff commited on
Commit
b2e5de5
·
1 Parent(s): a405624
Files changed (2) hide show
  1. app.py +193 -0
  2. requirements.txt +293 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations as _annotations
2
+
3
+ import json
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Any
7
+ import gradio as gr
8
+ from httpx import AsyncClient
9
+ from groq import Groq
10
+ import numpy as np
11
+ from gradio_webrtc import WebRTC, AdditionalOutputs, ReplyOnPause, audio_to_bytes
12
+
13
+ from pydantic_ai import Agent, ModelRetry, RunContext
14
+ from pydantic_ai.messages import ModelStructuredResponse, ToolReturn, ModelTextResponse
15
+
16
+ from dotenv import load_dotenv
17
+
18
+ load_dotenv()
19
+
20
+ @dataclass
21
+ class Deps:
22
+ client: AsyncClient
23
+ weather_api_key: str | None
24
+ geo_api_key: str | None
25
+
26
+
27
+ weather_agent = Agent(
28
+ 'openai:gpt-4o',
29
+ system_prompt='You are an expert packer. A user will ask you for help packing for a trip given a destination. Use your weather tools to provide a concise and effective packing list. Also ask follow up questions if neccessary.',
30
+ deps_type=Deps,
31
+ retries=2,
32
+ )
33
+
34
+
35
+ @weather_agent.tool
36
+ async def get_lat_lng(
37
+ ctx: RunContext[Deps], location_description: str
38
+ ) -> dict[str, float]:
39
+ """Get the latitude and longitude of a location.
40
+
41
+ Args:
42
+ ctx: The context.
43
+ location_description: A description of a location.
44
+ """
45
+ if ctx.deps.geo_api_key is None:
46
+ # if no API key is provided, return a dummy response (London)
47
+ return {'lat': 51.1, 'lng': -0.1}
48
+
49
+ params = {
50
+ 'q': location_description,
51
+ 'api_key': ctx.deps.geo_api_key,
52
+ }
53
+ r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
54
+ r.raise_for_status()
55
+ data = r.json()
56
+
57
+ if data:
58
+ return {'lat': data[0]['lat'], 'lng': data[0]['lon']}
59
+ else:
60
+ raise ModelRetry('Could not find the location')
61
+
62
+
63
+ @weather_agent.tool
64
+ async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
65
+ """Get the weather at a location.
66
+
67
+ Args:
68
+ ctx: The context.
69
+ lat: Latitude of the location.
70
+ lng: Longitude of the location.
71
+ """
72
+ if ctx.deps.weather_api_key is None:
73
+ # if no API key is provided, return a dummy response
74
+ return {'temperature': '21 °C', 'description': 'Sunny'}
75
+
76
+ params = {
77
+ 'apikey': ctx.deps.weather_api_key,
78
+ 'location': f'{lat},{lng}',
79
+ 'units': 'metric',
80
+ }
81
+ r = await ctx.deps.client.get(
82
+ 'https://api.tomorrow.io/v4/weather/realtime', params=params
83
+ )
84
+ r.raise_for_status()
85
+ data = r.json()
86
+
87
+ values = data['data']['values']
88
+ # https://docs.tomorrow.io/reference/data-layers-weather-codes
89
+ code_lookup = {
90
+ 1000: 'Clear, Sunny',
91
+ 1100: 'Mostly Clear',
92
+ 1101: 'Partly Cloudy',
93
+ 1102: 'Mostly Cloudy',
94
+ 1001: 'Cloudy',
95
+ 2000: 'Fog',
96
+ 2100: 'Light Fog',
97
+ 4000: 'Drizzle',
98
+ 4001: 'Rain',
99
+ 4200: 'Light Rain',
100
+ 4201: 'Heavy Rain',
101
+ 5000: 'Snow',
102
+ 5001: 'Flurries',
103
+ 5100: 'Light Snow',
104
+ 5101: 'Heavy Snow',
105
+ 6000: 'Freezing Drizzle',
106
+ 6001: 'Freezing Rain',
107
+ 6200: 'Light Freezing Rain',
108
+ 6201: 'Heavy Freezing Rain',
109
+ 7000: 'Ice Pellets',
110
+ 7101: 'Heavy Ice Pellets',
111
+ 7102: 'Light Ice Pellets',
112
+ 8000: 'Thunderstorm',
113
+ }
114
+ return {
115
+ 'temperature': f'{values["temperatureApparent"]:0.0f}°C',
116
+ 'description': code_lookup.get(values['weatherCode'], 'Unknown'),
117
+ }
118
+
119
+
120
+ TOOL_TO_DISPLAY_NAME = {
121
+ 'get_lat_lng': 'Geocoding API',
122
+ "get_weather": "Weather API"
123
+ }
124
+
125
+
126
+ groq_client = Groq()
127
+
128
+ client = AsyncClient()
129
+ weather_api_key = os.getenv('WEATHER_API_KEY')
130
+ # create a free API key at https://geocode.maps.co/
131
+ geo_api_key = os.getenv('GEO_API_KEY')
132
+ deps = Deps(
133
+ client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key
134
+ )
135
+
136
+
137
+ async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
138
+ yield "", gr.skip(), gr.skip()
139
+ chatbot.append({'role': 'user', 'content': prompt})
140
+ yield gr.skip(), chatbot, gr.skip()
141
+ async with weather_agent.run_stream(prompt, deps=deps, message_history=past_messages) as result:
142
+ for message in result.new_messages():
143
+ past_messages.append(message)
144
+ if isinstance(message, ModelStructuredResponse):
145
+ for call in message.calls:
146
+ gr_message = {"role": "assistant",
147
+ "content": "",
148
+ "metadata": {"title": f"### 🛠️ Using {TOOL_TO_DISPLAY_NAME[call.tool_name]}",
149
+ "id": call.tool_id}
150
+ }
151
+ chatbot.append(gr_message)
152
+ if isinstance(message, ToolReturn):
153
+ for gr_message in chatbot:
154
+ if gr_message.get('metadata', {}).get('id', "") == message.tool_id:
155
+ gr_message['content'] = f"Output: {json.dumps(message.content)}"
156
+ yield gr.skip(), chatbot, gr.skip()
157
+ chatbot.append({'role': 'assistant', 'content': ""})
158
+ async for message in result.stream_text():
159
+ chatbot[-1]["content"] = message
160
+ yield gr.skip(), chatbot, gr.skip()
161
+ data = await result.get_data()
162
+ past_messages.append(ModelTextResponse(content=data))
163
+ yield gr.skip(), gr.skip(), past_messages
164
+
165
+
166
+ with gr.Blocks() as demo:
167
+ gr.HTML(
168
+ """
169
+ <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
170
+ <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
171
+ <div>
172
+ <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
173
+ <h3 style="margin: 0 0 0.5rem 0">
174
+ This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
175
+ </h3>
176
+ <h3 style="margin: 0">
177
+ Feel free to ask for help with any other questions you have about your trip!
178
+ </h3>
179
+ </div>
180
+ </div>
181
+ """
182
+ )
183
+ past_messages = gr.State([])
184
+ chatbot = gr.Chatbot(label="Packing Assistant", type="messages",
185
+ avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"))
186
+ prompt = gr.Textbox(lines=1, label="Enter your destination or follow-up question", placeholder="Miami, Florida")
187
+ prompt.submit(stream_from_agent, inputs=[prompt, chatbot, past_messages],
188
+ outputs=[prompt, chatbot, past_messages])
189
+
190
+
191
+
192
+ if __name__ == '__main__':
193
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile requirements.in -o requirements.txt
3
+ aiofiles==23.2.1
4
+ # via gradio
5
+ aioice==0.9.0
6
+ # via aiortc
7
+ aiortc==1.9.0
8
+ # via gradio-webrtc
9
+ annotated-types==0.7.0
10
+ # via pydantic
11
+ anyio==4.6.2.post1
12
+ # via
13
+ # gradio
14
+ # groq
15
+ # httpx
16
+ # openai
17
+ # starlette
18
+ audioread==3.0.1
19
+ # via librosa
20
+ av==12.3.0
21
+ # via aiortc
22
+ cachetools==5.5.0
23
+ # via google-auth
24
+ certifi==2024.8.30
25
+ # via
26
+ # httpcore
27
+ # httpx
28
+ # requests
29
+ cffi==1.17.1
30
+ # via
31
+ # aiortc
32
+ # cryptography
33
+ # pylibsrtp
34
+ # soundfile
35
+ charset-normalizer==3.4.0
36
+ # via requests
37
+ click==8.1.7
38
+ # via
39
+ # typer
40
+ # uvicorn
41
+ colorama==0.4.6
42
+ # via griffe
43
+ coloredlogs==15.0.1
44
+ # via onnxruntime
45
+ cryptography==44.0.0
46
+ # via
47
+ # aiortc
48
+ # pyopenssl
49
+ decorator==5.1.1
50
+ # via librosa
51
+ distro==1.9.0
52
+ # via
53
+ # groq
54
+ # openai
55
+ dnspython==2.7.0
56
+ # via aioice
57
+ eval-type-backport==0.2.0
58
+ # via pydantic-ai-slim
59
+ fastapi==0.115.5
60
+ # via gradio
61
+ ffmpy==0.4.0
62
+ # via gradio
63
+ filelock==3.16.1
64
+ # via huggingface-hub
65
+ flatbuffers==24.3.25
66
+ # via onnxruntime
67
+ fsspec==2024.10.0
68
+ # via
69
+ # gradio-client
70
+ # huggingface-hub
71
+ google-auth==2.36.0
72
+ # via pydantic-ai-slim
73
+ google-crc32c==1.6.0
74
+ # via aiortc
75
+ gradio==5.7.1
76
+ # via gradio-webrtc
77
+ gradio-client==1.5.0
78
+ # via gradio
79
+ gradio-webrtc==0.0.16rc1
80
+ # via -r requirements.in
81
+ griffe==1.5.1
82
+ # via pydantic-ai-slim
83
+ groq==0.13.0
84
+ # via
85
+ # -r requirements.in
86
+ # pydantic-ai-slim
87
+ h11==0.14.0
88
+ # via
89
+ # httpcore
90
+ # uvicorn
91
+ httpcore==1.0.7
92
+ # via httpx
93
+ httpx==0.28.0
94
+ # via
95
+ # gradio
96
+ # gradio-client
97
+ # groq
98
+ # openai
99
+ # pydantic-ai-slim
100
+ # safehttpx
101
+ huggingface-hub==0.26.3
102
+ # via
103
+ # gradio
104
+ # gradio-client
105
+ humanfriendly==10.0
106
+ # via coloredlogs
107
+ idna==3.10
108
+ # via
109
+ # anyio
110
+ # httpx
111
+ # requests
112
+ ifaddr==0.2.0
113
+ # via aioice
114
+ jinja2==3.1.4
115
+ # via gradio
116
+ jiter==0.8.0
117
+ # via openai
118
+ joblib==1.4.2
119
+ # via
120
+ # librosa
121
+ # scikit-learn
122
+ lazy-loader==0.4
123
+ # via librosa
124
+ librosa==0.10.2.post1
125
+ # via gradio-webrtc
126
+ llvmlite==0.43.0
127
+ # via numba
128
+ logfire-api==2.6.0
129
+ # via pydantic-ai-slim
130
+ markdown-it-py==3.0.0
131
+ # via rich
132
+ markupsafe==2.1.5
133
+ # via
134
+ # gradio
135
+ # jinja2
136
+ mdurl==0.1.2
137
+ # via markdown-it-py
138
+ mpmath==1.3.0
139
+ # via sympy
140
+ msgpack==1.1.0
141
+ # via librosa
142
+ numba==0.60.0
143
+ # via
144
+ # -r requirements.in
145
+ # librosa
146
+ numpy==2.0.2
147
+ # via
148
+ # gradio
149
+ # librosa
150
+ # numba
151
+ # onnxruntime
152
+ # pandas
153
+ # scikit-learn
154
+ # scipy
155
+ # soxr
156
+ onnxruntime==1.20.1
157
+ # via gradio-webrtc
158
+ openai==1.56.1
159
+ # via pydantic-ai-slim
160
+ orjson==3.10.12
161
+ # via gradio
162
+ packaging==24.2
163
+ # via
164
+ # gradio
165
+ # gradio-client
166
+ # huggingface-hub
167
+ # lazy-loader
168
+ # onnxruntime
169
+ # pooch
170
+ pandas==2.2.3
171
+ # via gradio
172
+ pillow==11.0.0
173
+ # via gradio
174
+ platformdirs==4.3.6
175
+ # via pooch
176
+ pooch==1.8.2
177
+ # via librosa
178
+ protobuf==5.29.0
179
+ # via onnxruntime
180
+ pyasn1==0.6.1
181
+ # via
182
+ # pyasn1-modules
183
+ # rsa
184
+ pyasn1-modules==0.4.1
185
+ # via google-auth
186
+ pycparser==2.22
187
+ # via cffi
188
+ pydantic==2.10.3
189
+ # via
190
+ # fastapi
191
+ # gradio
192
+ # groq
193
+ # openai
194
+ # pydantic-ai-slim
195
+ pydantic-ai==0.0.8
196
+ # via -r requirements.in
197
+ pydantic-ai-slim==0.0.8
198
+ # via pydantic-ai
199
+ pydantic-core==2.27.1
200
+ # via pydantic
201
+ pydub==0.25.1
202
+ # via gradio
203
+ pyee==12.1.1
204
+ # via aiortc
205
+ pygments==2.18.0
206
+ # via rich
207
+ pylibsrtp==0.10.0
208
+ # via aiortc
209
+ pyopenssl==24.3.0
210
+ # via aiortc
211
+ python-dateutil==2.9.0.post0
212
+ # via pandas
213
+ python-dotenv==1.0.1
214
+ # via -r requirements.in
215
+ python-multipart==0.0.12
216
+ # via gradio
217
+ pytz==2024.2
218
+ # via pandas
219
+ pyyaml==6.0.2
220
+ # via
221
+ # gradio
222
+ # huggingface-hub
223
+ requests==2.32.3
224
+ # via
225
+ # huggingface-hub
226
+ # pooch
227
+ # pydantic-ai-slim
228
+ rich==13.9.4
229
+ # via typer
230
+ rsa==4.9
231
+ # via google-auth
232
+ ruff==0.8.1
233
+ # via gradio
234
+ safehttpx==0.1.6
235
+ # via gradio
236
+ scikit-learn==1.5.2
237
+ # via librosa
238
+ scipy==1.14.1
239
+ # via
240
+ # librosa
241
+ # scikit-learn
242
+ semantic-version==2.10.0
243
+ # via gradio
244
+ shellingham==1.5.4
245
+ # via typer
246
+ six==1.16.0
247
+ # via python-dateutil
248
+ sniffio==1.3.1
249
+ # via
250
+ # anyio
251
+ # groq
252
+ # openai
253
+ soundfile==0.12.1
254
+ # via librosa
255
+ soxr==0.5.0.post1
256
+ # via librosa
257
+ starlette==0.41.3
258
+ # via
259
+ # fastapi
260
+ # gradio
261
+ sympy==1.13.3
262
+ # via onnxruntime
263
+ threadpoolctl==3.5.0
264
+ # via scikit-learn
265
+ tomlkit==0.12.0
266
+ # via gradio
267
+ tqdm==4.67.1
268
+ # via
269
+ # huggingface-hub
270
+ # openai
271
+ typer==0.15.0
272
+ # via gradio
273
+ typing-extensions==4.12.2
274
+ # via
275
+ # fastapi
276
+ # gradio
277
+ # gradio-client
278
+ # groq
279
+ # huggingface-hub
280
+ # librosa
281
+ # openai
282
+ # pydantic
283
+ # pydantic-core
284
+ # pyee
285
+ # typer
286
+ tzdata==2024.2
287
+ # via pandas
288
+ urllib3==2.2.3
289
+ # via requests
290
+ uvicorn==0.32.1
291
+ # via gradio
292
+ websockets==12.0
293
+ # via gradio-client