Spaces:
Running
Running
Ashmi Banerjee
commited on
Commit
·
ac20456
1
Parent(s):
f4d1603
updates to the s-fairness calculation and refactoring code duplication
Browse files- README.md +15 -5
- app.py +85 -58
- src/augmentation/prompt_generation.py +0 -1
- src/helpers/__init__.py +0 -0
- src/helpers/creds_loader.py +0 -0
- src/helpers/data_loaders.py +52 -0
- src/information_retrieval/info_retrieval.py +20 -12
- src/pipeline.py +2 -2
- src/sustainability/s_fairness.py +89 -53
- src/text_generation/vertexai_setup.py +1 -1
README.md
CHANGED
@@ -15,8 +15,18 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
|
|
15 |
|
16 |
### TODOs
|
17 |
|
18 |
-
[
|
19 |
-
|
20 |
-
[x]
|
21 |
-
|
22 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
### TODOs
|
17 |
|
18 |
+
- [ ] Refactor the vectordb.py - remove code duplication
|
19 |
+
|
20 |
+
- [x] Sustainability - database paths - move to HF
|
21 |
+
|
22 |
+
- [ ] Fix it for the new models e.g. Llama and others
|
23 |
+
|
24 |
+
- [ ] Add the space secrets to have it running online
|
25 |
+
|
26 |
+
- [ ] Fix the google application json file
|
27 |
+
|
28 |
+
- [ ] Make the space public
|
29 |
+
|
30 |
+
- [x] Add emissions calculation and starting point
|
31 |
+
- [x] Add more cities to starting point
|
32 |
+
- [ ] Experiment with the sustainability & without sustainability prompt
|
app.py
CHANGED
@@ -1,78 +1,105 @@
|
|
1 |
from typing import Optional
|
2 |
import gradio as gr
|
3 |
-
import
|
|
|
4 |
sys.path.append("./src")
|
5 |
-
print(os.getcwd())
|
6 |
from src.pipeline import pipeline
|
|
|
7 |
|
8 |
|
9 |
def clear():
|
10 |
return None, None, None
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
|
14 |
-
temp: Optional[float] = 0.49):
|
15 |
-
if is_sustainable:
|
16 |
-
sustainability = 1
|
17 |
-
else:
|
18 |
-
sustainability = 0
|
19 |
pipeline_response = pipeline(
|
20 |
query=query_text,
|
21 |
model_name=model_name,
|
22 |
-
sustainability=
|
|
|
23 |
)
|
24 |
return pipeline_response
|
25 |
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
)
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
if __name__ == "__main__":
|
78 |
-
|
|
|
|
1 |
from typing import Optional
|
2 |
import gradio as gr
|
3 |
+
import sys
|
4 |
+
|
5 |
sys.path.append("./src")
|
|
|
6 |
from src.pipeline import pipeline
|
7 |
+
from src.helpers.data_loaders import load_places
|
8 |
|
9 |
|
10 |
def clear():
|
11 |
return None, None, None
|
12 |
|
13 |
|
14 |
+
# Function to update the list of cities based on the selected country
|
15 |
+
def update_cities(selected_country, df):
|
16 |
+
filtered_cities = df[df['country'] == selected_country]['city'].tolist()
|
17 |
+
return gr.Dropdown(choices=filtered_cities, interactive=True) # Make it interactive as it is not by default
|
18 |
+
|
19 |
+
|
20 |
def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
|
21 |
+
temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
|
|
|
|
|
|
|
|
|
22 |
pipeline_response = pipeline(
|
23 |
query=query_text,
|
24 |
model_name=model_name,
|
25 |
+
sustainability=is_sustainable,
|
26 |
+
starting_point=starting_point,
|
27 |
)
|
28 |
return pipeline_response
|
29 |
|
30 |
|
31 |
+
def create_ui():
|
32 |
+
data_file = "cities/eu_200_cities.csv"
|
33 |
+
df = load_places(data_file)
|
34 |
+
df = df.sort_values(by=['country', 'city'])
|
35 |
+
|
36 |
+
examples = [
|
37 |
+
["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
|
38 |
+
"local cuisines to try?", "GPT-4"],
|
39 |
+
["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
|
40 |
+
["Suggest some cities that can be visited from London and are very rich in history and culture.",
|
41 |
+
"Gemini-1.0-pro"],
|
42 |
+
]
|
43 |
+
|
44 |
+
with gr.Blocks() as app:
|
45 |
+
gr.HTML(
|
46 |
+
"<center><h1 style='font-size:xx-large; font-color: green'>🍀 Green City Finder 🍀</h1><h3>AI Sprint 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the "
|
47 |
+
"compatibility of"
|
48 |
+
"Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 "
|
49 |
+
"Pro</b> \n "
|
50 |
+
"models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.\n "
|
51 |
+
"We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector "
|
52 |
+
"embeddings are stored in a VectorDB (LanceDB) hosted in Google Cloud.\n "
|
53 |
+
"<p>Sustainability is calculated based on the work by <a href=https://arxiv.org/abs/2403.18604>Banerjee "
|
54 |
+
"et al.</a></p>\n "
|
55 |
+
" </p> <br>Google Cloud credits are provided for this project. </p>\n"
|
56 |
+
" ")
|
57 |
+
|
58 |
+
with gr.Group():
|
59 |
+
countries = gr.Dropdown(choices=list(df.country), multiselect=False, label="Countries")
|
60 |
+
starting_point = gr.Dropdown(choices=[], multiselect=False,
|
61 |
+
label="Select your starting point for the trip!")
|
62 |
+
|
63 |
+
countries.select(fn=lambda selected_country:
|
64 |
+
update_cities(selected_country, df),
|
65 |
+
inputs=countries, outputs=starting_point)
|
66 |
+
|
67 |
+
query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
|
68 |
+
sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
|
69 |
+
"with regards to the environment, your starting "
|
70 |
+
"location and month of travel?")
|
71 |
+
# TODO: Add model options, month and starting point
|
72 |
+
model = gr.Dropdown(
|
73 |
+
["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
|
74 |
+
"models "
|
75 |
+
"later!",
|
76 |
+
)
|
77 |
+
output = gr.Textbox(label="Generated Results", lines=4)
|
78 |
+
|
79 |
+
with gr.Accordion("Settings", open=False):
|
80 |
+
max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
|
81 |
+
interactive=True,
|
82 |
+
visible=True, info="The maximum number of output tokens")
|
83 |
+
temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
|
84 |
+
interactive=True,
|
85 |
+
visible=True, info="The value used to module the logits distribution")
|
86 |
+
with gr.Group():
|
87 |
+
with gr.Row():
|
88 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
89 |
+
clear_btn = gr.Button("Clear", variant="secondary")
|
90 |
+
cancel_btn = gr.Button("Cancel", variant="stop")
|
91 |
+
submit_btn.click(generate_text, inputs=[query, model, sustainable, starting_point], outputs=[output])
|
92 |
+
clear_btn.click(clear, inputs=[], outputs=[query, model, output])
|
93 |
+
cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
|
94 |
+
|
95 |
+
gr.Markdown("## Examples")
|
96 |
+
# gr.Examples(
|
97 |
+
# examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
|
98 |
+
# cache_examples=True,
|
99 |
+
# )
|
100 |
+
return app
|
101 |
+
|
102 |
|
103 |
if __name__ == "__main__":
|
104 |
+
app = create_ui()
|
105 |
+
app.launch(show_api=False)
|
src/augmentation/prompt_generation.py
CHANGED
@@ -158,7 +158,6 @@ def test():
|
|
158 |
|
159 |
# without sustainability
|
160 |
context = ir.get_context(query, **context_params)
|
161 |
-
# formatted_context = format_context(context)
|
162 |
|
163 |
without_sfairness = augment_prompt(
|
164 |
query=query,
|
|
|
158 |
|
159 |
# without sustainability
|
160 |
context = ir.get_context(query, **context_params)
|
|
|
161 |
|
162 |
without_sfairness = augment_prompt(
|
163 |
query=query,
|
src/helpers/__init__.py
ADDED
File without changes
|
src/helpers/creds_loader.py
ADDED
File without changes
|
src/helpers/data_loaders.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
from datasets import DatasetDict
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
from typing import Optional
|
7 |
+
load_dotenv()
|
8 |
+
import logging
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
12 |
+
|
13 |
+
HF_TOKEN = os.environ["HF_TOKEN"]
|
14 |
+
|
15 |
+
|
16 |
+
def load_data_hf(repo_name: str, data_files: str, is_public: bool) -> DatasetDict:
|
17 |
+
if is_public:
|
18 |
+
dataset = load_dataset(repo_name, split="train")
|
19 |
+
else:
|
20 |
+
dataset = load_dataset(repo_name, token=True, data_files=data_files)
|
21 |
+
return dataset
|
22 |
+
|
23 |
+
|
24 |
+
def load_scores(category: str) -> pd.DataFrame | None:
|
25 |
+
repository = os.environ.get("DATA_REPO")
|
26 |
+
data_file = None
|
27 |
+
match category:
|
28 |
+
case "popularity":
|
29 |
+
data_file = "computed/popularity/popularity_scores.csv"
|
30 |
+
case "seasonality":
|
31 |
+
data_file = "computed/seasonality/seasonality_scores.csv"
|
32 |
+
case "emissions":
|
33 |
+
data_file = "computed/emissions/emissions_merged.csv"
|
34 |
+
case _:
|
35 |
+
logger.info(f"Invalid category: {category}")
|
36 |
+
if data_file: # only for valid categories
|
37 |
+
data = load_data_hf(repository, data_file, is_public=False)
|
38 |
+
df = pd.DataFrame(data["train"][:])
|
39 |
+
return df
|
40 |
+
return None
|
41 |
+
|
42 |
+
|
43 |
+
def load_places(data_file: str) -> pd.DataFrame | None:
|
44 |
+
repository = os.environ.get("DATA_REPO")
|
45 |
+
|
46 |
+
if data_file:
|
47 |
+
data = load_data_hf(repository, data_file, is_public=False)
|
48 |
+
df = pd.DataFrame(data["train"][:])
|
49 |
+
return df
|
50 |
+
|
51 |
+
return None
|
52 |
+
|
src/information_retrieval/info_retrieval.py
CHANGED
@@ -10,6 +10,7 @@ import logging
|
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
12 |
|
|
|
13 |
|
14 |
def get_travel_months(query):
|
15 |
"""
|
@@ -91,7 +92,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
|
|
91 |
return results
|
92 |
|
93 |
|
94 |
-
def get_sustainability_scores(query, destinations):
|
95 |
"""
|
96 |
|
97 |
Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
|
@@ -109,15 +110,20 @@ def get_sustainability_scores(query, destinations):
|
|
109 |
months = get_travel_months(query)
|
110 |
logger.info("Finished parsing query for months.")
|
111 |
|
|
|
|
|
|
|
|
|
|
|
112 |
for city in destinations:
|
113 |
if city not in city_scores:
|
114 |
city_scores[city] = []
|
115 |
|
116 |
if not months: # no month(s) or seasons provided by the user
|
117 |
-
city_scores[city].append(s_fairness.compute_sfairness_score(city))
|
118 |
else:
|
119 |
for month in months:
|
120 |
-
city_scores[city].append(s_fairness.compute_sfairness_score(city, month))
|
121 |
|
122 |
logger.info("Finished getting s-fairness scores.")
|
123 |
|
@@ -130,7 +136,8 @@ def get_sustainability_scores(query, destinations):
|
|
130 |
result.append({
|
131 |
'city': city,
|
132 |
'month': 'No data available',
|
133 |
-
's-fairness': 'No data available'
|
|
|
134 |
})
|
135 |
break
|
136 |
|
@@ -139,14 +146,15 @@ def get_sustainability_scores(query, destinations):
|
|
139 |
result.append({
|
140 |
'city': city,
|
141 |
'month': min_score['month'],
|
142 |
-
's-fairness': min_score['s-fairness']
|
|
|
143 |
})
|
144 |
|
145 |
logger.info("Returning s-fairness results.")
|
146 |
return result
|
147 |
|
148 |
|
149 |
-
def get_cities(context):
|
150 |
"""
|
151 |
Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the retrieved context
|
152 |
|
@@ -184,9 +192,8 @@ def get_cities(context):
|
|
184 |
return recommended_cities
|
185 |
|
186 |
|
187 |
-
def get_context(query, **params):
|
188 |
"""
|
189 |
-
|
190 |
Function that returns all the context: from the database, as well as the respective s-fairness scores for the
|
191 |
destinations. The default does not consider S-Fairness scores, i.e. to append sustainability scores, a non-zero
|
192 |
parameter "sustainability" needs to be explicitly passed to params.
|
@@ -210,12 +217,13 @@ def get_context(query, **params):
|
|
210 |
recommended_cities = wikivoyage_context.keys()
|
211 |
|
212 |
if 'sustainability' in params and params['sustainability']:
|
213 |
-
s_fairness_scores = get_sustainability_scores(query, recommended_cities)
|
214 |
|
215 |
for score in s_fairness_scores:
|
216 |
wikivoyage_context[score['city']]['sustainability'] = {
|
217 |
'month': score['month'],
|
218 |
-
's-fairness': score['s-fairness']
|
|
|
219 |
}
|
220 |
|
221 |
return wikivoyage_context
|
@@ -225,11 +233,11 @@ def test():
|
|
225 |
queries = []
|
226 |
query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
|
227 |
"in winter. "
|
228 |
-
|
229 |
context = None
|
230 |
|
231 |
try:
|
232 |
-
context = get_context(query, sustainability=1)
|
233 |
# cities = get_cities(context)
|
234 |
# print(cities)
|
235 |
except FileNotFoundError as e:
|
|
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
12 |
|
13 |
+
from src.helpers.data_loaders import load_scores
|
14 |
|
15 |
def get_travel_months(query):
|
16 |
"""
|
|
|
92 |
return results
|
93 |
|
94 |
|
95 |
+
def get_sustainability_scores(starting_point: str , query: str, destinations: list):
|
96 |
"""
|
97 |
|
98 |
Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
|
|
|
110 |
months = get_travel_months(query)
|
111 |
logger.info("Finished parsing query for months.")
|
112 |
|
113 |
+
popularity_data = load_scores("popularity")
|
114 |
+
seasonality_data = load_scores("seasonality")
|
115 |
+
emissions_data = load_scores("emissions")
|
116 |
+
data = [popularity_data, seasonality_data, emissions_data]
|
117 |
+
|
118 |
for city in destinations:
|
119 |
if city not in city_scores:
|
120 |
city_scores[city] = []
|
121 |
|
122 |
if not months: # no month(s) or seasons provided by the user
|
123 |
+
city_scores[city].append(s_fairness.compute_sfairness_score(data, starting_point, city))
|
124 |
else:
|
125 |
for month in months:
|
126 |
+
city_scores[city].append(s_fairness.compute_sfairness_score(data, city, month))
|
127 |
|
128 |
logger.info("Finished getting s-fairness scores.")
|
129 |
|
|
|
136 |
result.append({
|
137 |
'city': city,
|
138 |
'month': 'No data available',
|
139 |
+
's-fairness': 'No data available',
|
140 |
+
'mode': 'No data available'
|
141 |
})
|
142 |
break
|
143 |
|
|
|
146 |
result.append({
|
147 |
'city': city,
|
148 |
'month': min_score['month'],
|
149 |
+
's-fairness': min_score['s-fairness'],
|
150 |
+
'mode': min_score['mode'],
|
151 |
})
|
152 |
|
153 |
logger.info("Returning s-fairness results.")
|
154 |
return result
|
155 |
|
156 |
|
157 |
+
def get_cities(context: dict):
|
158 |
"""
|
159 |
Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the retrieved context
|
160 |
|
|
|
192 |
return recommended_cities
|
193 |
|
194 |
|
195 |
+
def get_context(starting_point: str, query: str, **params):
|
196 |
"""
|
|
|
197 |
Function that returns all the context: from the database, as well as the respective s-fairness scores for the
|
198 |
destinations. The default does not consider S-Fairness scores, i.e. to append sustainability scores, a non-zero
|
199 |
parameter "sustainability" needs to be explicitly passed to params.
|
|
|
217 |
recommended_cities = wikivoyage_context.keys()
|
218 |
|
219 |
if 'sustainability' in params and params['sustainability']:
|
220 |
+
s_fairness_scores = get_sustainability_scores(starting_point, query, recommended_cities)
|
221 |
|
222 |
for score in s_fairness_scores:
|
223 |
wikivoyage_context[score['city']]['sustainability'] = {
|
224 |
'month': score['month'],
|
225 |
+
's-fairness': score['s-fairness'],
|
226 |
+
'transport': score['mode']
|
227 |
}
|
228 |
|
229 |
return wikivoyage_context
|
|
|
233 |
queries = []
|
234 |
query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
|
235 |
"in winter. "
|
236 |
+
starting_point = "Munich"
|
237 |
context = None
|
238 |
|
239 |
try:
|
240 |
+
context = get_context(starting_point, query, sustainability=1)
|
241 |
# cities = get_cities(context)
|
242 |
# print(cities)
|
243 |
except FileNotFoundError as e:
|
src/pipeline.py
CHANGED
@@ -37,7 +37,7 @@ MODELS = {
|
|
37 |
}
|
38 |
|
39 |
|
40 |
-
def pipeline(query: str, model_name: str, test: int = 0, **params):
|
41 |
"""
|
42 |
|
43 |
Executes the entire RAG pipeline, provided the query and model class name.
|
@@ -73,7 +73,7 @@ def pipeline(query: str, model_name: str, test: int = 0, **params):
|
|
73 |
|
74 |
logger.info("Retrieving context..")
|
75 |
try:
|
76 |
-
context = ir.get_context(query=query, **context_params)
|
77 |
if test:
|
78 |
retrieved_cities = ir.get_cities(context)
|
79 |
else:
|
|
|
37 |
}
|
38 |
|
39 |
|
40 |
+
def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
|
41 |
"""
|
42 |
|
43 |
Executes the entire RAG pipeline, provided the query and model class name.
|
|
|
73 |
|
74 |
logger.info("Retrieving context..")
|
75 |
try:
|
76 |
+
context = ir.get_context(starting_point=starting_point, query=query, **context_params)
|
77 |
if test:
|
78 |
retrieved_cities = ir.get_cities(context)
|
79 |
else:
|
src/sustainability/s_fairness.py
CHANGED
@@ -1,100 +1,126 @@
|
|
1 |
import sys
|
2 |
import os
|
|
|
|
|
3 |
import pandas as pd
|
4 |
-
import numpy as np
|
5 |
import logging
|
|
|
6 |
|
|
|
7 |
logger = logging.getLogger(__name__)
|
8 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
9 |
|
10 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
11 |
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
def get_popularity(destination):
|
17 |
"""
|
18 |
-
|
19 |
-
Returns the popularity score for a particular destination.
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
24 |
"""
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
else:
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
|
35 |
-
if not len(popularity_df[popularity_df['city'] == destination]):
|
36 |
-
print(f"{destination} does not have popularity data")
|
37 |
-
return None
|
38 |
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
|
42 |
-
def
|
|
|
43 |
"""
|
44 |
|
45 |
-
Returns the seasonality score for a particular destination
|
|
|
|
|
46 |
the best month, i.e. month of lowest seasonality is returned.
|
47 |
|
48 |
Args:
|
49 |
- destination: str
|
50 |
- month: str (default: None)
|
|
|
51 |
|
52 |
"""
|
53 |
-
parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
|
54 |
|
55 |
-
if
|
56 |
-
|
57 |
-
|
58 |
-
seasonality_path = seasonality_dir
|
59 |
-
seasonality_df = pd.read_csv(seasonality_path + "seasonality_scores.csv")
|
60 |
-
|
61 |
-
# Check if city is present in dataframe
|
62 |
-
if not len(seasonality_df[seasonality_df['city'] == destination]):
|
63 |
-
logger.info(f"{destination} does not have seasonality data for {month}")
|
64 |
return None, None
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
"""
|
79 |
|
80 |
Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
|
81 |
have popularity or seasonality scores, then the function returns None.
|
82 |
|
83 |
Args:
|
|
|
|
|
84 |
- destination: str
|
85 |
- month: str (default: None)
|
86 |
|
87 |
"""
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
# RECHECK
|
94 |
-
if
|
95 |
-
s_fairness = round(0.281 *
|
96 |
return {
|
97 |
'month': month,
|
|
|
98 |
's-fairness': s_fairness
|
99 |
}
|
100 |
# elif popularity is not None: # => seasonality is None
|
@@ -106,9 +132,19 @@ def compute_sfairness_score(destination, month=None):
|
|
106 |
else:
|
107 |
return {
|
108 |
'month': None,
|
|
|
109 |
's-fairness': None
|
110 |
}
|
111 |
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
if __name__ == "__main__":
|
114 |
-
|
|
|
1 |
import sys
|
2 |
import os
|
3 |
+
from typing import Optional, Dict, Any
|
4 |
+
|
5 |
import pandas as pd
|
|
|
6 |
import logging
|
7 |
+
from dotenv import load_dotenv
|
8 |
|
9 |
+
load_dotenv()
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
12 |
|
13 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
14 |
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
15 |
|
16 |
+
def get_emission_scores(emissions_df: pd.DataFrame, starting_point: str, destination: str, ):
|
|
|
|
|
|
|
17 |
"""
|
|
|
|
|
18 |
|
19 |
+
Returns the emission score for the connection with least co2e between two cities.
|
20 |
+
:param emissions_df:
|
21 |
+
:param starting_point:
|
22 |
+
:param destination:
|
23 |
+
:return:
|
24 |
"""
|
25 |
+
df = emissions_df.loc[(emissions_df["city_1"] == starting_point) & (emissions_df["city_2"] == destination)]
|
26 |
+
if len(df) == 0:
|
27 |
+
logger.info(f"Connection not found between {starting_point} and {destination}")
|
28 |
+
return 0, None
|
29 |
+
df.loc[:, 'min_co2e'] = df[['fly_co2e_kg', 'drive_co2e_kg', 'train_co2e_kg']].min(axis=1)
|
30 |
+
df.loc[:, 'min_co2e_colname'] = df[['fly_co2e_kg', 'drive_co2e_kg', 'train_co2e_kg']].idxmin(axis=1)
|
31 |
+
min_co2e = df.min_co2e.values[0]
|
32 |
+
mode_prefix = (df.min_co2e_colname.values[0]).split("_")[0]
|
33 |
+
min_cost = df[mode_prefix + "_cost_EUR"].values[0]
|
34 |
+
if mode_prefix == "train":
|
35 |
+
min_travel_time = df[mode_prefix + "_time_mins"].values[0] / 60
|
36 |
else:
|
37 |
+
min_travel_time = df[mode_prefix + "_time_hrs"].values[0]
|
38 |
+
emission_score = 0.352 * min_travel_time + 0.218 * min_co2e + 0.431 * min_cost
|
39 |
+
return emission_score, mode_prefix
|
40 |
|
|
|
|
|
|
|
41 |
|
42 |
+
def _check_city_present(df: pd.DataFrame, starting_point: Optional[str] = None, destination: str = "",
|
43 |
+
category: str = "popularity"):
|
44 |
+
if category == "emissions":
|
45 |
+
if not ((df['city_1'] == starting_point) & (df['city_2'] == destination)).any():
|
46 |
+
return False
|
47 |
+
else:
|
48 |
+
return True
|
49 |
+
if not len(df[df['city'] == destination]):
|
50 |
+
return False
|
51 |
+
return True
|
52 |
|
53 |
|
54 |
+
def get_scores(df: pd.DataFrame, starting_point: Optional[str] = None, destination="",
|
55 |
+
month: Optional[str] = None, category: str = "popularity"):
|
56 |
"""
|
57 |
|
58 |
+
Returns the seasonality/popularity score for a particular destination.
|
59 |
+
Seasonality is calculated for a particular month, while popularity is year-round.
|
60 |
+
If no month is provided then
|
61 |
the best month, i.e. month of lowest seasonality is returned.
|
62 |
|
63 |
Args:
|
64 |
- destination: str
|
65 |
- month: str (default: None)
|
66 |
+
- category: str (default: "popularity")
|
67 |
|
68 |
"""
|
|
|
69 |
|
70 |
+
# Check if city is present in dataframe
|
71 |
+
if not _check_city_present(df, starting_point, destination, category):
|
72 |
+
logger.info(f"{destination} does not have {category} data")
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
return None, None
|
74 |
|
75 |
+
match category:
|
76 |
+
case "popularity":
|
77 |
+
return df[df['city'] == destination]['weighted_pop_score'].item()
|
78 |
+
case "seasonality":
|
79 |
+
dest_df = df.loc[df['city'] == destination]
|
80 |
+
if month:
|
81 |
+
m = month.capitalize()[:3]
|
82 |
+
else:
|
83 |
+
dest_df['lowest_col'] = dest_df.loc[:, dest_df.columns != 'city'].idxmin(axis="columns")
|
84 |
+
m = dest_df[dest_df['city'] == destination]['lowest_col'].item()
|
85 |
+
return m, dest_df[dest_df['city'] == destination][m].item()
|
86 |
+
case "emissions":
|
87 |
+
emissions = get_emission_scores(df, starting_point, destination)
|
88 |
+
return emissions
|
89 |
+
|
90 |
+
|
91 |
+
def compute_sfairness_score(data: list[pd.DataFrame],
|
92 |
+
starting_point: str, destination: str,
|
93 |
+
month: Optional[str] = None) -> dict[str, Any] | dict[str, None]:
|
94 |
"""
|
95 |
|
96 |
Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
|
97 |
have popularity or seasonality scores, then the function returns None.
|
98 |
|
99 |
Args:
|
100 |
+
- data: list[pd.DataFrame]
|
101 |
+
- starting_point: str
|
102 |
- destination: str
|
103 |
- month: str (default: None)
|
104 |
|
105 |
"""
|
106 |
+
popularity_score = get_scores(df=data[0],
|
107 |
+
starting_point=None,
|
108 |
+
destination=destination, month=None, category="popularity")
|
109 |
+
month, seasonality_score = get_scores(df=data[1],
|
110 |
+
starting_point=None, destination=destination,
|
111 |
+
month=month, category="seasonality")
|
112 |
+
|
113 |
+
emission_score, mode = get_scores(df=data[2],
|
114 |
+
starting_point=starting_point, destination=destination, category="emissions")
|
115 |
+
if emission_score is None:
|
116 |
+
emission_score = 0
|
117 |
|
118 |
# RECHECK
|
119 |
+
if seasonality_score is not None and popularity_score is not None:
|
120 |
+
s_fairness = round(0.281 * emission_score + 0.334 * popularity_score + 0.385 * seasonality_score, 3)
|
121 |
return {
|
122 |
'month': month,
|
123 |
+
'mode': mode, # 'fly', 'drive', 'train'
|
124 |
's-fairness': s_fairness
|
125 |
}
|
126 |
# elif popularity is not None: # => seasonality is None
|
|
|
132 |
else:
|
133 |
return {
|
134 |
'month': None,
|
135 |
+
'mode': None, # 'fly', 'drive', 'train'
|
136 |
's-fairness': None
|
137 |
}
|
138 |
|
139 |
|
140 |
+
def test():
|
141 |
+
popularity_data = load_data("popularity")
|
142 |
+
seasonality_data = load_data("seasonality")
|
143 |
+
emissions_data = load_data("emissions")
|
144 |
+
data = [popularity_data, seasonality_data, emissions_data]
|
145 |
+
print(compute_sfairness_score(data=data, starting_point="Munich", destination="Dijon"))
|
146 |
+
print(compute_sfairness_score(data=data, starting_point="Munich", destination="Strasbourg", month="Dec"))
|
147 |
+
|
148 |
+
|
149 |
if __name__ == "__main__":
|
150 |
+
test()
|
src/text_generation/vertexai_setup.py
CHANGED
@@ -21,7 +21,7 @@ def decode_service_key():
|
|
21 |
|
22 |
def initialize_vertexai_params(location: Optional[str] = "us-central1"):
|
23 |
|
24 |
-
creds_file_name = os.getcwd() + "/.config/
|
25 |
print(creds_file_name)
|
26 |
if not os.path.exists(os.path.dirname(creds_file_name)):
|
27 |
credentials = decode_service_key()
|
|
|
21 |
|
22 |
def initialize_vertexai_params(location: Optional[str] = "us-central1"):
|
23 |
|
24 |
+
creds_file_name = os.getcwd() + "/.config/gcp_default_credentials.json"
|
25 |
print(creds_file_name)
|
26 |
if not os.path.exists(os.path.dirname(creds_file_name)):
|
27 |
credentials = decode_service_key()
|