# Page-header residue from where this file was copied ("Spaces: Running");
# it is not part of the program.
from __future__ import annotations as _annotations

import json
import os
from dataclasses import dataclass
from typing import Any

import gradio as gr
from dotenv import load_dotenv
from httpx import AsyncClient
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import ModelStructuredResponse, ModelTextResponse, ToolReturn
# Populate the process environment from a local .env file (python-dotenv)
# before the API keys are read with os.getenv further down in this module.
load_dotenv()
@dataclass
class Deps:
    """Dependencies made available to the agent's tools via ``ctx.deps``.

    The class must be a dataclass: the module builds it with keyword
    arguments (``Deps(client=..., weather_api_key=..., geo_api_key=...)``),
    which requires the generated ``__init__``.
    """

    # HTTP client the tools use for their outbound API requests.
    client: AsyncClient
    # Key for the weather endpoint; when None, get_weather returns canned data.
    weather_api_key: str | None
    # Key for the geocoding endpoint; when None, get_lat_lng returns canned data.
    geo_api_key: str | None
# The packing assistant. Tools attached to this agent receive a
# RunContext[Deps], hence deps_type=Deps.
weather_agent = Agent(
    "openai:gpt-4o",
    # Adjacent literals are concatenated into the single prompt string.
    system_prompt=(
        "You are an expert packer. "
        "A user will ask you for help packing for a trip given a destination. "
        "Use your weather tools to provide a concise and effective packing list. "
        "Also ask follow up questions if neccessary."
    ),
    deps_type=Deps,
    retries=2,
)
# Registered as an agent tool so the model can actually call it; the name
# "get_lat_lng" is what TOOL_TO_DISPLAY_NAME maps to "Geocoding API".
@weather_agent.tool
async def get_lat_lng(
    ctx: RunContext[Deps], location_description: str
) -> dict[str, float]:
    """Get the latitude and longitude of a location.

    Args:
        ctx: The context.
        location_description: A description of a location.
    """
    if ctx.deps.geo_api_key is None:
        # if no API key is provided, return a dummy response (London)
        return {"lat": 51.1, "lng": -0.1}

    params = {
        "q": location_description,
        "api_key": ctx.deps.geo_api_key,
    }
    r = await ctx.deps.client.get("https://geocode.maps.co/search", params=params)
    r.raise_for_status()
    data = r.json()

    if not data:
        # An empty result list: ask the model to retry with another description.
        raise ModelRetry("Could not find the location")

    # Use the first match. Coerce to float so the result matches the declared
    # return type even if the service encodes coordinates as strings
    # (NOTE(review): presumably it does - confirm against the API response).
    best_match = data[0]
    return {"lat": float(best_match["lat"]), "lng": float(best_match["lon"])}
# Registered as an agent tool so the model can actually call it; the name
# "get_weather" is what TOOL_TO_DISPLAY_NAME maps to "Weather API".
@weather_agent.tool
async def get_weather(ctx: RunContext[Deps], lat: float, lng: float) -> dict[str, Any]:
    """Get the weather at a location.

    Args:
        ctx: The context.
        lat: Latitude of the location.
        lng: Longitude of the location.
    """
    if ctx.deps.weather_api_key is None:
        # if no API key is provided, return a dummy response
        return {"temperature": "21 °C", "description": "Sunny"}

    params = {
        "apikey": ctx.deps.weather_api_key,
        "location": f"{lat},{lng}",
        "units": "metric",
    }
    r = await ctx.deps.client.get(
        "https://api.tomorrow.io/v4/weather/realtime", params=params
    )
    r.raise_for_status()
    data = r.json()

    values = data["data"]["values"]
    # Numeric weather codes -> short descriptions.
    # https://docs.tomorrow.io/reference/data-layers-weather-codes
    code_lookup = {
        1000: "Clear, Sunny",
        1100: "Mostly Clear",
        1101: "Partly Cloudy",
        1102: "Mostly Cloudy",
        1001: "Cloudy",
        2000: "Fog",
        2100: "Light Fog",
        4000: "Drizzle",
        4001: "Rain",
        4200: "Light Rain",
        4201: "Heavy Rain",
        5000: "Snow",
        5001: "Flurries",
        5100: "Light Snow",
        5101: "Heavy Snow",
        6000: "Freezing Drizzle",
        6001: "Freezing Rain",
        6200: "Light Freezing Rain",
        6201: "Heavy Freezing Rain",
        7000: "Ice Pellets",
        7101: "Heavy Ice Pellets",
        7102: "Light Ice Pellets",
        8000: "Thunderstorm",
    }
    return {
        # The apparent temperature is reported, rounded to a whole number.
        "temperature": f'{values["temperatureApparent"]:0.0f}°C',
        # Codes missing from the table fall back to "Unknown".
        "description": code_lookup.get(values["weatherCode"], "Unknown"),
    }
# Labels shown in the chat window when the agent calls a tool, keyed by tool name.
TOOL_TO_DISPLAY_NAME = {"get_lat_lng": "Geocoding API", "get_weather": "Weather API"}

# Module-level dependencies passed to every agent run. Either key may be
# unset, in which case the corresponding tool returns its dummy response.
client = AsyncClient()
weather_api_key = os.getenv("WEATHER_API_KEY")
# create a free API key at https://geocode.maps.co/
geo_api_key = os.getenv("GEO_API_KEY")
deps = Deps(client=client, weather_api_key=weather_api_key, geo_api_key=geo_api_key)
async def stream_from_agent(prompt: str, chatbot: list[dict], past_messages: list):
    """Run the agent on ``prompt`` and stream UI updates as it progresses.

    Each yielded value is a 3-tuple matching the event outputs
    ``[prompt, chatbot, past_messages]``; ``gr.skip()`` leaves an output
    untouched.

    Args:
        prompt: The text the user submitted.
        chatbot: Chat history in messages format; mutated in place.
        past_messages: Agent message history from earlier turns; mutated in
            place with the messages produced by this run.
    """
    chatbot.append({"role": "user", "content": prompt})
    # Clear and lock the input box while the agent is working.
    yield gr.Textbox(interactive=False, value=""), chatbot, gr.skip()

    async with weather_agent.run_stream(
        prompt, deps=deps, message_history=past_messages
    ) as result:
        for message in result.new_messages():
            past_messages.append(message)

            if isinstance(message, ModelStructuredResponse):
                # One collapsible chat entry per tool call, tagged with the
                # call id so the matching tool return can fill it in later.
                for call in message.calls:
                    # Fall back to the raw tool name instead of raising
                    # KeyError for a tool that has no display label.
                    display_name = TOOL_TO_DISPLAY_NAME.get(
                        call.tool_name, call.tool_name
                    )
                    gr_message = {
                        "role": "assistant",
                        "content": "",
                        "metadata": {
                            "title": f"### 🛠️ Using {display_name}",
                            "id": call.tool_id,
                        },
                    }
                    chatbot.append(gr_message)

            if isinstance(message, ToolReturn):
                for gr_message in chatbot:
                    # "metadata" can be present but None on plain messages,
                    # so `or {}` is needed in addition to the .get default.
                    metadata = gr_message.get("metadata") or {}
                    if metadata.get("id", "") == message.tool_id:
                        gr_message["content"] = f"Output: {json.dumps(message.content)}"

            yield gr.skip(), chatbot, gr.skip()

        # Placeholder bubble that is overwritten as the answer streams in.
        chatbot.append({"role": "assistant", "content": ""})
        async for message in result.stream_text():
            chatbot[-1]["content"] = message
            yield gr.skip(), chatbot, gr.skip()

        data = await result.get_data()
        past_messages.append(ModelTextResponse(content=data))

        # Unlock the input box and publish the updated agent history.
        yield gr.Textbox(interactive=True), gr.skip(), past_messages
async def handle_retry(chatbot, past_messages: list, retry_data: gr.RetryData):
    """Re-run the agent for the chat message the user asked to retry.

    Truncates both histories at ``retry_data.index``, then streams a fresh
    answer for the prompt stored at that index.

    Args:
        chatbot: Current chat history in messages format.
        past_messages: Agent message history.
        retry_data: Retry event payload; ``index`` is the position in
            ``chatbot`` of the message to retry.
    """
    new_history = chatbot[: retry_data.index]
    previous_prompt = chatbot[retry_data.index]["content"]
    # NOTE(review): a chatbot index is used to slice past_messages, but
    # stream_from_agent grows the two lists at different rates - confirm
    # that this truncation keeps them aligned.
    past_messages = past_messages[: retry_data.index]
    async for update in stream_from_agent(previous_prompt, new_history, past_messages):
        yield update
def undo(chatbot, past_messages: list, undo_data: gr.UndoData):
    """Remove the undone message and everything after it.

    Args:
        chatbot: Current chat history in messages format.
        past_messages: Agent message history.
        undo_data: Undo event payload; ``index`` is the position in
            ``chatbot`` of the message being undone.

    Returns:
        A 3-tuple ``(prompt_text, chatbot, past_messages)``: the content of
        the undone message (so it can be put back in the input box) and both
        histories truncated at ``undo_data.index``.
    """
    new_history = chatbot[: undo_data.index]
    # NOTE(review): a chatbot index is used to slice past_messages; the two
    # lists are not guaranteed to have matching positions - confirm.
    past_messages = past_messages[: undo_data.index]
    return chatbot[undo_data.index]["content"], new_history, past_messages
def select_data(message: gr.SelectData) -> str:
    """Return the text of the example the user clicked.

    Args:
        message: Selection event payload whose ``value`` is a mapping with a
            ``"text"`` entry.
    """
    return message.value["text"]
# UI layout and event wiring. Every handler receives and returns the same
# three components: the prompt textbox, the chatbot and the agent history.
with gr.Blocks() as demo:
    gr.HTML(
        """
<div style="display: flex; justify-content: center; align-items: center; gap: 2rem; padding: 1rem; width: 100%">
    <img src="https://ai.pydantic.dev/img/logo-white.svg" style="max-width: 200px; height: auto">
    <div>
        <h1 style="margin: 0 0 1rem 0">Vacation Packing Assistant</h1>
        <h3 style="margin: 0 0 0.5rem 0">
            This assistant will help you pack for your vacation. Enter your destination and it will provide you with a concise packing list based on the weather forecast.
        </h3>
        <h3 style="margin: 0">
            Feel free to ask for help with any other questions you have about your trip!
        </h3>
    </div>
</div>
"""
    )
    # Per-session agent message history, separate from what the chatbot shows.
    past_messages = gr.State([])
    chatbot = gr.Chatbot(
        label="Packing Assistant",
        type="messages",
        avatar_images=(None, "https://ai.pydantic.dev/img/logo-white.svg"),
        examples=[
            {"text": "I am going to Paris for the holidays, what should I pack?"},
            {"text": "I am going to Tokyo this week."},
        ],
    )
    with gr.Row():
        prompt = gr.Textbox(
            lines=1,
            show_label=False,
            placeholder="I am planning a trip to Miami, what should I pack?",
        )
    generation = prompt.submit(
        stream_from_agent,
        inputs=[prompt, chatbot, past_messages],
        outputs=[prompt, chatbot, past_messages],
    )
    # Clicking an example copies its text into the prompt box.
    chatbot.example_select(select_data, None, [prompt])
    chatbot.retry(
        handle_retry, [chatbot, past_messages], [prompt, chatbot, past_messages]
    )
    chatbot.undo(undo, [chatbot, past_messages], [prompt, chatbot, past_messages])
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()