Spaces:
Sleeping
Sleeping
from typing import Optional | |
import gradio as gr | |
import sys | |
sys.path.append("./src") | |
from src.pipeline import pipeline | |
from src.helpers.data_loaders import load_places | |
def clear():
    """Reset the query, model and output components.

    Returns:
        A 3-tuple of ``None`` values, one per component listed in the
        Clear/Cancel buttons' ``outputs`` (query, model, output).
    """
    return (None,) * 3
def update_cities(selected_country, df):
    """Return a refreshed starting-point dropdown for ``selected_country``.

    Args:
        selected_country: The value chosen in the "Country" dropdown.
        df: Places table with at least ``country`` and ``city`` columns.

    Returns:
        gr.Dropdown: The choices are the cities of that country, in ``df`` row
        order. ``interactive=True`` is set explicitly because the dropdown is
        not interactive by default.
    """
    in_country = df['country'] == selected_country
    cities = df.loc[in_country, 'city'].tolist()
    return gr.Dropdown(choices=cities, interactive=True)
def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
                  temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
    """Run the recommendation pipeline for a single user query.

    Args:
        query_text: The free-text question typed by the user.
        model_name: The model label chosen in the UI. It is forwarded to
            ``pipeline`` as ``model_name``.
        is_sustainable: Whether sustainable recommendations were requested.
            It is forwarded as ``sustainability``.
        tokens: Max new tokens from the settings slider.
            NOTE(review): accepted but not currently forwarded to ``pipeline``.
        temp: Sampling temperature from the settings slider.
            NOTE(review): accepted but not currently forwarded to ``pipeline``.
        starting_point: The trip's origin city. It is forwarded as
            ``starting_point``.

    Returns:
        Whatever ``pipeline`` returns, unchanged.
    """
    request = dict(
        query=query_text,
        model_name=model_name,
        sustainability=is_sustainable,
        starting_point=starting_point,
    )
    return pipeline(**request)
def create_ui():
    """Build the Gradio Blocks interface for the Green City Finder demo.

    The function loads the places table from ``cities/eu_200_cities.csv``
    via ``load_places`` and sorts it by country, then city. It then lays out
    the following UI:

    - a country dropdown that drives the starting-point dropdown;
    - query, sustainability and model inputs;
    - a results textbox;
    - generation settings;
    - Submit / Clear / Cancel buttons;
    - clickable example queries.

    Returns:
        gr.Blocks: The assembled app; call ``.launch()`` to serve it.
    """
    data_file = "cities/eu_200_cities.csv"
    df = load_places(data_file)
    df = df.sort_values(by=['country', 'city'])

    # Each example row fills [query, model] when clicked.
    examples = [
        ["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
         "local cuisines to try?", "GPT-4"],
        ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
        ["Suggest some cities that can be visited from London and are very rich in history and culture.",
         "Gemini-1.0-pro"],
    ]

    with gr.Blocks() as app:
        gr.HTML(
            "<center><h1 style='font-size:xx-large; color: green'>π Green City Finder π</h1>"
            "<h3>AI Sprint 2024 submissions by Ashmi Banerjee. </h3></center> <br>"
            "<p>We're testing the compatibility of "
            "Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & "
            "<b>Gemini 1.0 Pro</b> models through HuggingFace and VertexAI, respectively, to generate "
            "sustainable travel recommendations. "
            "We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector "
            "embeddings are stored in a VectorDB (LanceDB) hosted in Google Cloud.</p>"
            "<p>Sustainability is calculated based on the work by "
            "<a href='https://arxiv.org/abs/2403.18604'>Banerjee et al.</a></p>"
            "<p>Google Cloud credits are provided for this project.</p>"
        )

        with gr.Group():
            countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
            starting_point = gr.Dropdown(choices=[], multiselect=False,
                                         label="Select your starting point for the trip!")
            # Repopulate the city list whenever the user picks a country.
            countries.select(fn=lambda selected_country: update_cities(selected_country, df),
                             inputs=countries, outputs=starting_point)
            query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
            sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
                                                                "with regards to the environment, your starting "
                                                                "location and month of travel?")
            # TODO: Add a month-of-travel option.
            model = gr.Dropdown(
                ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
                                                                 "models "
                                                                 "later!",
            )

        output = gr.Textbox(label="Generated Results", lines=4)

        with gr.Accordion("Settings", open=False):
            max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
                                       interactive=True,
                                       visible=True, info="The maximum number of output tokens")
            temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
                                    interactive=True,
                                    visible=True, info="The value used to modulate the logits distribution")

        with gr.Group():
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
                cancel_btn = gr.Button("Cancel", variant="stop")

        # Gradio maps `inputs` positionally onto generate_text's parameters.
        # The order is (query_text, model_name, is_sustainable, tokens, temp,
        # starting_point), so the sliders must sit between the checkbox and the
        # starting point. Otherwise the selected city lands in `tokens` and the
        # default "Munich" is always used.
        # NOTE(review): generate_text accepts tokens/temp but, as written, does
        # not pass them on to pipeline.
        submit_event = submit_btn.click(
            generate_text,
            inputs=[query, model, sustainable, max_new_tokens, temperature, starting_point],
            outputs=[output],
        )
        clear_btn.click(clear, inputs=[], outputs=[query, model, output])
        # Cancel also aborts an in-flight generation, not just clears the fields.
        cancel_btn.click(clear, inputs=[], outputs=[query, model, output], cancels=[submit_event])

        # No fn/caching: clicking an example only fills the inputs, so the
        # pipeline is not run at startup.
        gr.Examples(examples=examples, inputs=[query, model], label="Examples")

    return app
if __name__ == "__main__":
    # Build the Blocks app and serve it with the auto-generated API page hidden.
    app = create_ui()
    app.launch(show_api=False)