Spaces:
Running
Running
File size: 4,908 Bytes
27f8cfc 35b631a 27f8cfc 89845e5 27f8cfc a80cc10 27f8cfc a80cc10 27f8cfc 35b631a 27f8cfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from datetime import datetime, timedelta
import json
import requests
import streamlit as st
from any_agent import AgentFramework
from any_agent.tracing.trace import _is_tracing_supported
from any_agent.evaluation import EvaluationCase
from any_agent.evaluation.schemas import CheckpointCriteria
import pandas as pd
from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
import copy
from pydantic import BaseModel, ConfigDict
class UserInputs(BaseModel):
    """Validated bundle of the choices a user made in the Streamlit form.

    Extra, undeclared fields are rejected at construction time because the
    model is configured with ``extra="forbid"``.
    """

    model_config = ConfigDict(extra="forbid")

    # Identifier of the model the agent should run with.
    model_id: str
    # Free-form place name typed by the user.
    location: str
    # Upper bound on driving time, in hours.
    max_driving_hours: int
    # Date and time of day the user selected.
    date: datetime
    # Agent framework to use.
    # NOTE(review): declared as ``str``; presumably callers pass a str-based
    # enum member or its value — confirm against the caller.
    framework: str
    # Evaluation case (judge model and checkpoints) for the run.
    evaluation_case: EvaluationCase
    # Whether the evaluation step should be executed.
    run_evaluation: bool
@st.cache_resource
def get_area(area_name: str) -> list[dict]:
    """Look up an area by name with the Nominatim search API.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
    The function is wrapped in ``st.cache_resource``, so Streamlit memoizes
    the result per ``area_name``.

    Args:
        area_name (str): The free-form name of the area to search for.

    Returns:
        list[dict]: The decoded JSON body of the response. The search
            endpoint answers with a JSON array of matching places, which is
            empty when nothing matched.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.RequestException: On connection problems or when the
            5 second timeout expires.
    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Let requests build the query string so that characters with a
        # special meaning in URLs ("&", "#", "+", "%", ...) inside the
        # user-supplied text are encoded instead of corrupting the query.
        params={"q": area_name, "format": "json"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()
def _collect_checkpoints(
    checkpoints_df: pd.DataFrame,
    max_checkpoints: int = 20,
    max_criteria_words: int = 100,
) -> list[CheckpointCriteria]:
    """Turn the rows of the edited checkpoint table into criteria objects.

    Problems are reported to the user with ``st.error`` and the offending
    row is dropped; the function itself does not raise for bad rows.

    Args:
        checkpoints_df: Table with a ``criteria`` and a ``points`` column.
        max_checkpoints: Rows beyond this count are ignored (with an error
            message shown to the user).
        max_criteria_words: Criteria with more whitespace-separated words
            than this are rejected.

    Returns:
        The successfully created ``CheckpointCriteria`` objects, in row order.
    """
    # don't let a user add more than the allowed number of checkpoints
    if len(checkpoints_df) > max_checkpoints:
        st.error(
            f"You can only add up to {max_checkpoints} checkpoints "
            "for the purpose of this demo."
        )
        checkpoints_df = checkpoints_df[:max_checkpoints]

    collected = []
    for _, row in checkpoints_df.iterrows():
        criteria = row["criteria"]
        # A row that was added in the editor but not filled in may carry a
        # missing value (None/NaN) rather than "". Treat it, and
        # whitespace-only text, as "no checkpoint" instead of failing on it.
        if not isinstance(criteria, str) or not criteria.strip():
            continue
        # Don't let people write essays for criteria in this demo.
        # split() counts words across any whitespace, so newlines or tabs
        # cannot be used to slip past the limit.
        if len(criteria.split()) > max_criteria_words:
            st.error("Error creating checkpoint: Criteria is too long")
            continue
        try:
            collected.append(
                CheckpointCriteria(criteria=criteria, points=row["points"])
            )
        except Exception as e:
            # UI boundary: surface any validation problem to the user and
            # carry on with the remaining rows.
            st.error(f"Error creating checkpoint: {e}")
    return collected


def get_user_inputs() -> UserInputs:
    """Render the input widgets and return the values the user selected.

    Widgets are created in this order: location, maximum driving hours,
    date and time, agent framework, model, the "Custom Evaluation" expander
    (judge model and editable checkpoints) and the "Run Evaluation" checkbox.

    Returns:
        UserInputs: The collected selections.
    """
    default_val = "Los Angeles California, US"
    location = st.text_input("Enter a location", value=default_val)
    if location:
        # An empty search result means Nominatim could not find the place.
        if not get_area(location):
            # NOTE(review): the leading character of this message looks like
            # a mis-encoded symbol — confirm the intended glyph.
            st.error("β Invalid location")

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    col_date, col_time = st.columns([2, 1])
    with col_date:
        selected_date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # One option per full hour; index 9 makes 09:00 the default.
        hourly_times = [
            datetime.strptime(f"{hour:02d}:00", "%H:%M").time() for hour in range(24)
        ]
        selected_time = st.selectbox("Select a time", hourly_times, index=9)
    date = datetime.combine(selected_date, selected_time)

    # Only offer frameworks for which tracing is supported.
    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]
    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=2,
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        # Show only the last three "/"-separated segments of the model id.
        format_func=lambda x: "/".join(x.split("/")[-3:]),
    )

    # Add evaluation case section
    with st.expander("Custom Evaluation"):
        evaluation_model_id = st.selectbox(
            "Select the model to use for LLM-as-a-Judge evaluation",
            MODEL_OPTIONS,
            index=2,
            format_func=lambda x: "/".join(x.split("/")[-3:]),
        )
        # Work on a copy so the shared default case is never mutated.
        evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
        evaluation_case.llm_judge = evaluation_model_id

        # Present the checkpoints as a data frame so that they can be edited.
        checkpoints_df = pd.DataFrame(
            [checkpoint.model_dump() for checkpoint in evaluation_case.checkpoints]
        )
        checkpoints_df = st.data_editor(
            checkpoints_df,
            column_config={
                "points": st.column_config.NumberColumn(label="Points"),
                "criteria": st.column_config.TextColumn(label="Criteria"),
            },
            hide_index=True,
            num_rows="dynamic",
        )
        # Convert the edited rows back into CheckpointCriteria objects.
        evaluation_case.checkpoints = _collect_checkpoints(checkpoints_df)

    run_evaluation = st.checkbox("Run Evaluation", value=True)
    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=run_evaluation,
    )
|