Spaces:
Running
Running
from typing import Tuple | |
import gradio as gr | |
from src.helpers.data_loaders import load_places | |
from src.text_generation.mapper import MODEL_MAPPER | |
def get_places(data_file: str = "cities/eu_200_cities.csv"):
    """Load the places table and sort it by country, then by city.

    Args:
        data_file: Path handed to ``load_places``. Defaults to the EU
            200-cities CSV that was previously hard-coded here.
            NOTE(review): a relative path is presumably resolved against
            the current working directory by ``load_places`` — confirm the
            app is always launched from the project root.

    Returns:
        The table produced by ``load_places`` (presumably a pandas
        DataFrame, given the ``sort_values`` call), sorted by its
        ``country`` and ``city`` columns. Sorting here means later
        per-country filters and ``unique()`` calls come out in
        alphabetical order.
    """
    places = load_places(data_file)
    return places.sort_values(by=["country", "city"])
def update_cities(selected_country, df):
    """Build the starting-point dropdown update for ``selected_country``.

    Args:
        selected_country: Country picked in the country dropdown.
        df: Places table with ``country`` and ``city`` columns.

    Returns:
        A ``gr.Dropdown`` whose choices are the cities of
        ``selected_country`` (in ``df``'s row order) and whose current
        value is cleared. Clearing it prevents a city chosen for a
        previously selected country from staying selected while no longer
        being one of the offered choices.
    """
    filtered_cities = df.loc[df["country"] == selected_country, "city"].tolist()
    # The starting-point dropdown is not interactive by default, so enable
    # it explicitly; value=None resets any stale selection.
    return gr.Dropdown(choices=filtered_cities, value=None, interactive=True)
def main_component(
    max_models: int = 5,
) -> Tuple[gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Checkbox, gr.Dropdown]:
    """
    Creates the main Gradio interface components and returns them.

    Args:
        max_models: How many entries of ``MODEL_MAPPER`` (its first keys,
            in mapping order) are offered in the model dropdown. Defaults
            to 5, the limit that used to be hard-coded.

    Returns:
        Tuple containing:
            - countries: Dropdown for selecting the country.
            - starting_point: Dropdown for selecting the starting point.
            - query: Textbox for entering the user query.
            - sustainable: Checkbox for sustainable travel.
            - model: Dropdown for selecting the model.
    """
    df = get_places()
    # get_places() sorts by country, and unique() keeps first-seen order,
    # so the country choices come out alphabetically.
    country_names = list(df.country.unique())

    with gr.Group():
        # Country selection dropdown
        countries = gr.Dropdown(choices=country_names, multiselect=False, label="Country")

        # Starting point selection dropdown; starts empty and is filled once
        # a country is chosen.
        starting_point = gr.Dropdown(
            choices=[], multiselect=False, label="Select your starting point for the trip!"
        )

        # When a country is selected, refresh the starting point options.
        # The lambda closes over df so the table is loaded only once.
        countries.select(
            fn=lambda selected_country: update_cities(selected_country, df),
            inputs=countries,
            outputs=starting_point,
        )

        # User query input
        query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")

        # Checkbox for sustainable travel option
        sustainable = gr.Checkbox(
            label="Sustainable",
            info="Do you want your recommendations to be sustainable with regards to the environment, "
            "your starting location, and month of travel?",
        )

        # Model selection dropdown, limited to the first `max_models` keys.
        models = list(MODEL_MAPPER.keys())[:max_models]
        model = gr.Dropdown(
            choices=models,
            label="Model",
            info="Select your model. The model will generate the recommendations based on your query.",
        )

    # Return all the components individually
    return countries, starting_point, query, sustainable, model