File size: 4,908 Bytes
27f8cfc
 
 
 
 
 
 
 
 
 
35b631a
27f8cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89845e5
 
 
 
 
27f8cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a80cc10
27f8cfc
 
 
 
 
 
a80cc10
27f8cfc
 
 
 
 
 
 
 
 
 
 
35b631a
27f8cfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from datetime import datetime, timedelta
import json
import requests
import streamlit as st
from any_agent import AgentFramework
from any_agent.tracing.trace import _is_tracing_supported
from any_agent.evaluation import EvaluationCase
from any_agent.evaluation.schemas import CheckpointCriteria
import pandas as pd
from constants import DEFAULT_EVALUATION_CASE, MODEL_OPTIONS
import copy

from pydantic import BaseModel, ConfigDict


class UserInputs(BaseModel):
    """Validated values collected from the Streamlit input form (built by ``get_user_inputs``)."""

    # Reject any keyword that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")
    # Agent model id, chosen from MODEL_OPTIONS in the form.
    model_id: str
    # Location text exactly as typed; the form only warns on a failed lookup,
    # so this is not guaranteed to be a place Nominatim recognises.
    location: str
    # Upper bound on driving time in hours; the form widget enforces >= 1.
    max_driving_hours: int
    # Selected date combined with a whole-hour time (naive, no tzinfo attached).
    date: datetime
    # NOTE(review): annotated as str, but the form passes an AgentFramework member;
    # confirm AgentFramework subclasses str, otherwise validation would reject it.
    framework: str
    # Copy of the default evaluation case with the user's judge model and checkpoints.
    evaluation_case: EvaluationCase
    # Whether the user ticked "Run Evaluation".
    run_evaluation: bool


@st.cache_resource
def get_area(area_name: str) -> list[dict]:
    """Search for an area by name with the Nominatim search API.

    Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
    Streamlit caches the result per ``area_name``.

    Args:
        area_name (str): Free-form name of the area to search for.

    Returns:
        list[dict]: The matching places decoded from the JSON response
        (Nominatim's ``format=json`` output is a JSON array). An empty list
        means nothing matched.

    Raises:
        requests.RequestException: On connection failure, timeout (5 s),
            or an HTTP error status.
    """
    response = requests.get(
        "https://nominatim.openstreetmap.org/search",
        # Let requests URL-encode the user's text so characters such as "&",
        # "#" or "+" stay inside the q= value instead of breaking the query.
        params={"q": area_name, "format": "json"},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=5,
    )
    response.raise_for_status()
    return response.json()


def _short_model_name(model_id: str) -> str:
    """Format a model id for display by keeping only its last three "/"-separated parts."""
    return "/".join(model_id.split("/")[-3:])


def _check_location(location: str) -> None:
    """Look up *location* with ``get_area``; show an ``st.error`` if it is unknown or the lookup fails."""
    try:
        matches = get_area(location)
    except (requests.RequestException, ValueError) as e:
        # Without this, a network/HTTP failure or an undecodable response would
        # propagate and abort the whole script run. ValueError covers JSON decode errors.
        st.error(f"❌ Could not look up location: {e}")
        return
    if not matches:
        st.error("❌ Invalid location")


def _select_datetime() -> datetime:
    """Render the date and hour pickers side by side and combine them into one datetime."""
    col_date, col_time = st.columns([2, 1])
    with col_date:
        date = st.date_input(
            "Select a date in the future", value=datetime.now() + timedelta(days=1)
        )
    with col_time:
        # Whole-hour choices 00:00..23:00; index 9 pre-selects 09:00.
        time = st.selectbox(
            "Select a time",
            [datetime.strptime(f"{i:02d}:00", "%H:%M").time() for i in range(24)],
            index=9,
        )
    return datetime.combine(date, time)


def _parse_checkpoints(checkpoints_df: pd.DataFrame) -> list[CheckpointCriteria]:
    """Convert edited checkpoint rows back into CheckpointCriteria objects, reporting bad rows via st.error."""
    # Don't let a user add more than 20 checkpoints.
    if len(checkpoints_df) > 20:
        st.error(
            "You can only add up to 20 checkpoints for the purpose of this demo."
        )
        checkpoints_df = checkpoints_df.iloc[:20]

    new_ckpts = []
    for _, row in checkpoints_df.iterrows():
        criteria = row["criteria"]
        # Rows added in the editor may have empty cells that arrive as None/NaN
        # rather than "", so skip any missing or blank criteria.
        if pd.isna(criteria) or not str(criteria).strip():
            continue
        try:
            # Don't let people write essays for criteria in this demo.
            if len(str(criteria).split()) > 100:
                raise ValueError("Criteria is too long")
            new_ckpts.append(
                CheckpointCriteria(criteria=criteria, points=row["points"])
            )
        except ValueError as e:
            # pydantic's ValidationError subclasses ValueError, so invalid
            # points/criteria values are reported here as well.
            st.error(f"Error creating checkpoint: {e}")
    return new_ckpts


def _build_evaluation_case() -> EvaluationCase:
    """Render the judge-model picker and checkpoint editor; return a customised copy of the default case."""
    evaluation_model_id = st.selectbox(
        "Select the model to use for LLM-as-a-Judge evaluation",
        MODEL_OPTIONS,
        index=2,
        format_func=_short_model_name,
    )
    # Deep-copy so the user's edits never leak into the shared module-level default.
    evaluation_case = copy.deepcopy(DEFAULT_EVALUATION_CASE)
    evaluation_case.llm_judge = evaluation_model_id

    # Show the checkpoints as an editable table; users may add and delete rows.
    checkpoints_df = pd.DataFrame(
        [checkpoint.model_dump() for checkpoint in evaluation_case.checkpoints]
    )
    checkpoints_df = st.data_editor(
        checkpoints_df,
        column_config={
            "points": st.column_config.NumberColumn(label="Points"),
            "criteria": st.column_config.TextColumn(label="Criteria"),
        },
        hide_index=True,
        num_rows="dynamic",
    )
    evaluation_case.checkpoints = _parse_checkpoints(checkpoints_df)
    return evaluation_case


def get_user_inputs() -> UserInputs:
    """Render the input form and collect the user's choices.

    Widgets are drawn in this order: location, maximum driving hours,
    date/time, agent framework, model, a "Custom Evaluation" expander (judge
    model and editable checkpoint table), and the "Run Evaluation" checkbox.

    A location that cannot be found, or whose lookup fails, is reported with
    ``st.error``. It does not stop the form, and the text is returned as typed.

    Returns:
        UserInputs: The collected values. The evaluation case is a deep copy of
        ``DEFAULT_EVALUATION_CASE`` with the judge model and checkpoints
        replaced by what the user entered.
    """
    location = st.text_input("Enter a location", value="Los Angeles California, US")
    if location:
        _check_location(location)

    max_driving_hours = st.number_input(
        "Enter the maximum driving hours", min_value=1, value=2
    )

    date = _select_datetime()

    # Only frameworks with tracing support are offered.
    supported_frameworks = [
        framework for framework in AgentFramework if _is_tracing_supported(framework)
    ]
    framework = st.selectbox(
        "Select the agent framework to use",
        supported_frameworks,
        index=2,
        format_func=lambda x: x.name,
    )

    model_id = st.selectbox(
        "Select the model to use",
        MODEL_OPTIONS,
        index=1,
        format_func=_short_model_name,
    )

    with st.expander("Custom Evaluation"):
        evaluation_case = _build_evaluation_case()

    return UserInputs(
        model_id=model_id,
        location=location,
        max_driving_hours=max_driving_hours,
        date=date,
        framework=framework,
        evaluation_case=evaluation_case,
        run_evaluation=st.checkbox("Run Evaluation", value=True),
    )