# Spaces: Running
from __future__ import annotations as _annotations | |
import json | |
import os | |
from dataclasses import dataclass | |
from typing import Any | |
import gradio as gr | |
from httpx import AsyncClient | |
from groq import Groq | |
import numpy as np | |
from gradio_webrtc import WebRTC, AdditionalOutputs, ReplyOnPause, audio_to_bytes | |
from pydantic_ai import Agent, ModelRetry, RunContext | |
from pydantic_ai.messages import ModelStructuredResponse, ToolReturn, ModelTextResponse | |
from dotenv import load_dotenv | |
load_dotenv() | |
@dataclass
class Deps:
    """Dependencies handed to the agent's tools through ``RunContext.deps``.

    Declared as a dataclass so it can be built with keyword arguments
    (``Deps(client=..., weather_api_key=..., geo_api_key=...)``).
    """

    # HTTP client the tools use for their outbound ``get`` requests.
    client: AsyncClient
    # Key sent as ``apikey`` to the weather endpoint; ``None`` makes
    # ``get_weather`` return a canned response instead of calling out.
    weather_api_key: str | None
    # Key sent as ``api_key`` to the geocoding endpoint; ``None`` makes
    # ``get_lat_lng`` return a canned response instead of calling out.
    geo_api_key: str | None
# The packing assistant: a GPT-4o backed agent whose tools receive a `Deps`
# instance. The prompt is split across adjacent literals purely for line
# length; the concatenated text is unchanged.
weather_agent = Agent(
    'openai:gpt-4o',
    deps_type=Deps,
    retries=2,
    system_prompt=(
        'You are an expert packer. '
        'A user will ask you for help packing for a trip given a destination. '
        'Use your weather tools to provide a concise and effective packing list. '
        'Also ask follow up questions if neccessary.'
    ),
)
# Registered on the agent so the model can actually call it; without the
# decorator the function is just an unused coroutine.
# NOTE(review): pydantic-ai builds the tool description from the docstring,
# so its wording is left as-is — confirm before rewording.
@weather_agent.tool
async def get_lat_lng(
    ctx: RunContext[Deps], location_description: str
) -> dict[str, float]:
    """Get the latitude and longitude of a location.

    Args:
        ctx: The context.
        location_description: A description of a location.
    """
    if ctx.deps.geo_api_key is None:
        # if no API key is provided, return a dummy response (London)
        return {'lat': 51.1, 'lng': -0.1}

    params = {
        'q': location_description,
        'api_key': ctx.deps.geo_api_key,
    }
    r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
    r.raise_for_status()
    data = r.json()

    if not data:
        # An empty result list means nothing matched; let the model retry
        # with a different description.
        raise ModelRetry('Could not find the location')

    # Only the first match is used. Convert to float so the result matches
    # the annotated return type and the dummy response above.
    # (presumably the service returns coordinates as strings — verify)
    best_match = data[0]
    return {'lat': float(best_match['lat']), 'lng': float(best_match['lon'])}
# Registered on the agent so the model can actually call it; without the
# decorator the function is just an unused coroutine.
# NOTE(review): pydantic-ai builds the tool description from the docstring,
# so its wording is left as-is — confirm before rewording.
@weather_agent.tool
async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
    """Get the weather at a location.

    Args:
        ctx: The context.
        lat: Latitude of the location.
        lng: Longitude of the location.
    """
    if ctx.deps.weather_api_key is None:
        # if no API key is provided, return a dummy response
        return {'temperature': '21 °C', 'description': 'Sunny'}

    params = {
        'apikey': ctx.deps.weather_api_key,
        'location': f'{lat},{lng}',
        'units': 'metric',
    }
    r = await ctx.deps.client.get(
        'https://api.tomorrow.io/v4/weather/realtime', params=params
    )
    r.raise_for_status()
    data = r.json()
    values = data['data']['values']

    # https://docs.tomorrow.io/reference/data-layers-weather-codes
    code_lookup = {
        1000: 'Clear, Sunny',
        1100: 'Mostly Clear',
        1101: 'Partly Cloudy',
        1102: 'Mostly Cloudy',
        1001: 'Cloudy',
        2000: 'Fog',
        2100: 'Light Fog',
        4000: 'Drizzle',
        4001: 'Rain',
        4200: 'Light Rain',
        4201: 'Heavy Rain',
        5000: 'Snow',
        5001: 'Flurries',
        5100: 'Light Snow',
        5101: 'Heavy Snow',
        6000: 'Freezing Drizzle',
        6001: 'Freezing Rain',
        6200: 'Light Freezing Rain',
        6201: 'Heavy Freezing Rain',
        7000: 'Ice Pellets',
        7101: 'Heavy Ice Pellets',
        7102: 'Light Ice Pellets',
        8000: 'Thunderstorm',
    }
    return {
        # Reports the "apparent" temperature field, rounded to whole degrees.
        'temperature': f'{values["temperatureApparent"]:0.0f}°C',
        # Codes missing from the table degrade to 'Unknown' rather than raising.
        'description': code_lookup.get(values['weatherCode'], 'Unknown'),
    }
# Labels shown in the chat UI for each tool call, keyed by tool name.
TOOL_TO_DISPLAY_NAME = {
    'get_lat_lng': 'Geocoding API',
    'get_weather': 'Weather API',
}

# Clients created once at import time.
# NOTE(review): groq_client is not referenced anywhere in the visible code.
groq_client = Groq()
client = AsyncClient()

# Both keys are optional; they are read from the environment and may be None.
weather_api_key = os.environ.get('WEATHER_API_KEY')
# create a free API key at https://geocode.maps.co/
geo_api_key = os.environ.get('GEO_API_KEY')

# Single dependency bundle passed to every agent run.
deps = Deps(
    client=client,
    weather_api_key=weather_api_key,
    geo_api_key=geo_api_key,
)
async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
    """Run one agent turn and stream its progress into the chat UI.

    This is an async generator. Each yielded triple lines up with the
    ``[prompt, chatbot, past_messages]`` outputs wired in ``prompt.submit``;
    ``gr.skip()`` is yielded for the slots that should not be updated.

    Args:
        prompt: Text the user submitted.
        chatbot: Chat history as a list of message dicts; mutated in place.
        past_messages: Agent message history from earlier turns; mutated in
            place and yielded back at the end.

    Yields:
        ``(prompt_value, chatbot_value, past_messages_value)`` tuples.
    """
    # Clear the input box first.
    yield "", gr.skip(), gr.skip()

    chatbot.append({'role': 'user', 'content': prompt})
    yield gr.skip(), chatbot, gr.skip()

    async with weather_agent.run_stream(
        prompt, deps=deps, message_history=past_messages
    ) as result:
        for message in result.new_messages():
            past_messages.append(message)

            if isinstance(message, ModelStructuredResponse):
                # One placeholder bubble per tool call, tagged with the call
                # id so the matching ToolReturn can fill it in.
                for call in message.calls:
                    # Fall back to the raw tool name so an unmapped tool
                    # cannot abort the stream with a KeyError.
                    display_name = TOOL_TO_DISPLAY_NAME.get(
                        call.tool_name, call.tool_name
                    )
                    chatbot.append(
                        {
                            "role": "assistant",
                            "content": "",
                            "metadata": {
                                "title": f"### 🛠️ Using {display_name}",
                                "id": call.tool_id,
                            },
                        }
                    )

            if isinstance(message, ToolReturn):
                for gr_message in chatbot:
                    # `or {}` also covers a 'metadata' key that is present
                    # but set to None.
                    metadata = gr_message.get('metadata') or {}
                    if metadata.get('id', "") == message.tool_id:
                        gr_message['content'] = f"Output: {json.dumps(message.content)}"

            yield gr.skip(), chatbot, gr.skip()

        # Final answer bubble; each streamed item replaces its content.
        chatbot.append({'role': 'assistant', 'content': ""})
        async for message in result.stream_text():
            chatbot[-1]["content"] = message
            yield gr.skip(), chatbot, gr.skip()

        # Record the completed reply so the next turn sees it as history.
        data = await result.get_data()
        past_messages.append(ModelTextResponse(content=data))
        yield gr.skip(), gr.skip(), past_messages
# Page layout: banner, per-session agent history, chat window and input box.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
            <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
            <div>
                <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
                <h3 style="margin: 0 0 0.5rem 0">
                    This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
                </h3>
                <h3 style="margin: 0">
                    Feel free to ask for help with any other questions you have about your trip!
                </h3>
            </div>
        </div>
        """
    )

    # Agent-side message history, kept separately from the rendered chat.
    past_messages = gr.State([])

    chatbot = gr.Chatbot(
        label="Packing Assistant",
        type="messages",
        avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
    )
    prompt = gr.Textbox(
        lines=1,
        label="Enter your destination or follow-up question",
        placeholder="Miami, Florida",
    )

    # The same three components are both inputs and outputs, so the handler
    # can clear the box, update the chat and persist history as it streams.
    io_components = [prompt, chatbot, past_messages]
    prompt.submit(stream_from_agent, inputs=io_components, outputs=io_components)


if __name__ == '__main__':
    demo.launch()