Ashmi Banerjee
updates
32cae45
from typing import Tuple, Optional
import gradio as gr
from src.helpers.data_loaders import load_places
from src.text_generation.mapper import MODEL_MAPPER
def get_places():
    """
    Load the European cities dataset, ordered for display.

    Returns:
        The DataFrame produced by ``load_places`` for
        ``cities/eu_200_cities.csv``, sorted by country and then by city.
    """
    places = load_places("cities/eu_200_cities.csv")
    return places.sort_values(by=["country", "city"])
def update_cities(selected_country, df):
    """
    Build the update for the city dropdown after a country is chosen.

    Args:
        selected_country: The value of the country dropdown; a falsy value
            (``None`` or ``""``) means no country is selected.
        df: DataFrame with at least ``country`` and ``city`` columns.

    Returns:
        A Gradio update for the city dropdown. With no country selected, the
        dropdown is emptied and disabled. Otherwise its choices become the
        cities of that country and it is made interactive. In both cases the
        current selection is cleared, so a city belonging to a previously
        selected country cannot remain selected.
    """
    if not selected_country:
        return gr.update(choices=[], value=None, interactive=False)

    # Cities of the selected country, each listed once (first-seen order kept).
    filtered_cities = (
        df.loc[df['country'] == selected_country, 'city']
        .drop_duplicates()
        .tolist()
    )
    # Reset the value explicitly: an update that only changes `choices` keeps
    # the old selection, which may not be one of the new choices.
    return gr.update(choices=filtered_cities, value=None, interactive=True)
def main_component() -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Dropdown]:
    """
    Create the main Gradio input components and return them.

    Selecting a country narrows the city dropdown to that country's cities
    (through ``update_cities``).

    Returns:
        Tuple containing:
            - country: Dropdown for selecting the country.
            - starting_point: Dropdown for selecting the starting city.
            - query: Textbox for entering the user query.
            - model: Dropdown for selecting the model.
    """
    df = get_places()
    country_names = list(df.country.unique())
    # Before any country is chosen, every city in the dataset is offered.
    cities = list(df.city.unique())
    with gr.Group():
        # Country selection dropdown
        country = gr.Dropdown(choices=country_names, multiselect=False, label="Country")
        # Starting point selection dropdown
        starting_point = gr.Dropdown(choices=cities, multiselect=False,
                                     label="City",
                                     info="Select a city as your starting point.")
        # When a country is selected, update the starting point options.
        # The lambda closes over the already-loaded `df`, so the CSV is not
        # re-read on every selection.
        country.select(
            fn=lambda selected_country: update_cities(selected_country, df),
            inputs=country,
            outputs=starting_point
        )
        # User query input
        query = gr.Textbox(label="Enter your preferences e.g. beaches, night life etc. and ask for your "
                                 "recommendation for European cities!",
                           placeholder="Ask for your city recommendation"
                                       " here!")
        # NOTE: the sustainable-travel checkbox below is currently disabled and
        # is not part of the returned tuple.
        # sustainable = gr.Checkbox(
        #     label="Sustainable",
        #     info="Do you want your recommendations to be sustainable with regards to the environment, "
        #          "your starting location, and month of travel?"
        # )
        # Only the first two entries of MODEL_MAPPER (in dict insertion order)
        # are offered to the user.
        models = list(MODEL_MAPPER.keys())[:2]
        # Model selection dropdown
        model = gr.Dropdown(
            choices=models,
            label="Model",
            info="Select your model. The model will generate sustainable recommendations based on your query."
        )
    # Return all the components individually
    return country, starting_point, query, model