Ashmi Banerjee commited on
Commit
ac20456
·
1 Parent(s): f4d1603

updates to the s-fairness calculation and refactoring code duplication

Browse files
README.md CHANGED
@@ -15,8 +15,18 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
15
 
16
  ### TODOs
17
 
18
- [x] Refactor the vectordb.py - remove code duplication
19
- [x] Sustainability - database paths - move to HF
20
- [x] Fix it for the new models e.g. Llama and others
21
- [x] Add the space secrets to have it running online
22
- [x] Make the space public
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  ### TODOs
17
 
18
+ - [ ] Refactor the vectordb.py - remove code duplication
19
+
20
+ - [x] Sustainability - database paths - move to HF
21
+
22
+ - [ ] Fix it for the new models e.g. Llama and others
23
+
24
+ - [ ] Add the space secrets to have it running online
25
+
26
+ - [ ] Fix the google application json file
27
+
28
+ - [ ] Make the space public
29
+
30
+ - [x] Add emissions calculation and starting point
31
+ - [x] Add more cities to starting point
32
+ - [ ] Experiment with the sustainability & without sustainability prompt
app.py CHANGED
@@ -1,78 +1,105 @@
1
  from typing import Optional
2
  import gradio as gr
3
- import os, sys
 
4
  sys.path.append("./src")
5
- print(os.getcwd())
6
  from src.pipeline import pipeline
 
7
 
8
 
9
  def clear():
10
  return None, None, None
11
 
12
 
 
 
 
 
 
 
13
  def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
14
- temp: Optional[float] = 0.49):
15
- if is_sustainable:
16
- sustainability = 1
17
- else:
18
- sustainability = 0
19
  pipeline_response = pipeline(
20
  query=query_text,
21
  model_name=model_name,
22
- sustainability= sustainability
 
23
  )
24
  return pipeline_response
25
 
26
 
27
- examples = [["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
28
- "local cuisines to try?", "GPT-4"],
29
- ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
30
- ["Suggest some cities that can be visited from London and are very rich in history and culture.",
31
- "Gemini-1.0-pro"],
32
- ]
33
-
34
- with gr.Blocks() as demo:
35
- gr.HTML("""<center><h1 style='font-size:xx-large;'>🇪🇺 Euro City Recommender using Gemini & Gemma 🇪🇺</h1><br><h3>Gemini
36
- & Gemma Sprints 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the compatibility of
37
- Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 Pro</b>
38
- models through HuggingFace and VertexAI, respectively, to generate travel recommendations. This early version (read
39
- quick and dirty implementation) aims to see if functionalities work smoothly. It relies on Wikipedia abstracts
40
- from 160 European cities to provide answers to your questions. Please be kind with it, as it's a work in progress!
41
- </p> <br>Google Cloud credits are provided for this project. </p>
42
- """)
43
-
44
- with gr.Group():
45
- query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
46
- sustainable = gr.Checkbox(label="Sustainable", info="If you want sustainable recommendations for "
47
- "hidden gems?")
48
- model = gr.Dropdown(
49
- ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
50
- "models "
51
- "later!",
52
- )
53
- output = gr.Textbox(label="Generated Results", lines=4)
54
-
55
- with gr.Accordion("Settings", open=False):
56
- max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
57
- interactive=True,
58
- visible=True, info="The maximum number of output tokens")
59
- temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
60
- interactive=True,
61
- visible=True, info="The value used to module the logits distribution")
62
- with gr.Group():
63
- with gr.Row():
64
- submit_btn = gr.Button("Submit", variant="primary")
65
- clear_btn = gr.Button("Clear", variant="secondary")
66
- cancel_btn = gr.Button("Cancel", variant="stop")
67
- submit_btn.click(generate_text, inputs=[query, model, sustainable], outputs=[output])
68
- clear_btn.click(clear, inputs=[], outputs=[query, model, output])
69
- cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
70
-
71
- gr.Markdown("## Examples")
72
- gr.Examples(
73
- examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
74
- cache_examples=True,
75
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  if __name__ == "__main__":
78
- demo.launch(show_api=False)
 
 
1
  from typing import Optional
2
  import gradio as gr
3
+ import sys
4
+
5
  sys.path.append("./src")
 
6
  from src.pipeline import pipeline
7
+ from src.helpers.data_loaders import load_places
8
 
9
 
10
  def clear():
11
  return None, None, None
12
 
13
 
14
+ # Function to update the list of cities based on the selected country
15
+ def update_cities(selected_country, df):
16
+ filtered_cities = df[df['country'] == selected_country]['city'].tolist()
17
+ return gr.Dropdown(choices=filtered_cities, interactive=True) # Make it interactive as it is not by default
18
+
19
+
20
  def generate_text(query_text, model_name: Optional[str], is_sustainable: Optional[bool], tokens: Optional[int] = 1024,
21
+ temp: Optional[float] = 0.49, starting_point: Optional[str] = "Munich"):
 
 
 
 
22
  pipeline_response = pipeline(
23
  query=query_text,
24
  model_name=model_name,
25
+ sustainability=is_sustainable,
26
+ starting_point=starting_point,
27
  )
28
  return pipeline_response
29
 
30
 
31
+ def create_ui():
32
+ data_file = "cities/eu_200_cities.csv"
33
+ df = load_places(data_file)
34
+ df = df.sort_values(by=['country', 'city'])
35
+
36
+ examples = [
37
+ ["I'm planning a vacation to France. Can you suggest a one-week itinerary including must-visit places and "
38
+ "local cuisines to try?", "GPT-4"],
39
+ ["I want to explore off-the-beaten-path destinations in Europe, any suggestions?", "Gemini-1.0-pro"],
40
+ ["Suggest some cities that can be visited from London and are very rich in history and culture.",
41
+ "Gemini-1.0-pro"],
42
+ ]
43
+
44
+ with gr.Blocks() as app:
45
+ gr.HTML(
46
+ "<center><h1 style='font-size:xx-large; font-color: green'>🍀 Green City Finder 🍀</h1><h3>AI Sprint 2024 submissions by Ashmi Banerjee. </h3></center> <br><p>We're testing the "
47
+ "compatibility of"
48
+ "Retrieval Augmented Generation (RAG) implementations with Google's <b>Gemma-2b-it</b> & <b>Gemini 1.0 "
49
+ "Pro</b> \n "
50
+ "models through HuggingFace and VertexAI, respectively, to generate sustainable travel recommendations.\n "
51
+ "We use the Wikivoyage dataset to provide city recommendations based on user queries. The vector "
52
+ "embeddings are stored in a VectorDB (LanceDB) hosted in Google Cloud.\n "
53
+ "<p>Sustainability is calculated based on the work by <a href=https://arxiv.org/abs/2403.18604>Banerjee "
54
+ "et al.</a></p>\n "
55
+ " </p> <br>Google Cloud credits are provided for this project. </p>\n"
56
+ " ")
57
+
58
+ with gr.Group():
59
+ countries = gr.Dropdown(choices=list(df.country), multiselect=False, label="Countries")
60
+ starting_point = gr.Dropdown(choices=[], multiselect=False,
61
+ label="Select your starting point for the trip!")
62
+
63
+ countries.select(fn=lambda selected_country:
64
+ update_cities(selected_country, df),
65
+ inputs=countries, outputs=starting_point)
66
+
67
+ query = gr.Textbox(label="Query", placeholder="Ask for your city recommendation here!")
68
+ sustainable = gr.Checkbox(label="Sustainable", info="Do you want your recommendations to be sustainable "
69
+ "with regards to the environment, your starting "
70
+ "location and month of travel?")
71
+ # TODO: Add model options, month and starting point
72
+ model = gr.Dropdown(
73
+ ["GPT-4", "Gemini-1.0-pro"], label="Model", info="Select your model. Will add more "
74
+ "models "
75
+ "later!",
76
+ )
77
+ output = gr.Textbox(label="Generated Results", lines=4)
78
+
79
+ with gr.Accordion("Settings", open=False):
80
+ max_new_tokens = gr.Slider(label="Max new tokens", value=1024, minimum=0, maximum=8192, step=64,
81
+ interactive=True,
82
+ visible=True, info="The maximum number of output tokens")
83
+ temperature = gr.Slider(label="Temperature", step=0.01, minimum=0.01, maximum=1.0, value=0.49,
84
+ interactive=True,
85
+ visible=True, info="The value used to module the logits distribution")
86
+ with gr.Group():
87
+ with gr.Row():
88
+ submit_btn = gr.Button("Submit", variant="primary")
89
+ clear_btn = gr.Button("Clear", variant="secondary")
90
+ cancel_btn = gr.Button("Cancel", variant="stop")
91
+ submit_btn.click(generate_text, inputs=[query, model, sustainable, starting_point], outputs=[output])
92
+ clear_btn.click(clear, inputs=[], outputs=[query, model, output])
93
+ cancel_btn.click(clear, inputs=[], outputs=[query, model, output])
94
+
95
+ gr.Markdown("## Examples")
96
+ # gr.Examples(
97
+ # examples, inputs=[query, model], label="Examples", fn=generate_text, outputs=[output],
98
+ # cache_examples=True,
99
+ # )
100
+ return app
101
+
102
 
103
  if __name__ == "__main__":
104
+ app = create_ui()
105
+ app.launch(show_api=False)
src/augmentation/prompt_generation.py CHANGED
@@ -158,7 +158,6 @@ def test():
158
 
159
  # without sustainability
160
  context = ir.get_context(query, **context_params)
161
- # formatted_context = format_context(context)
162
 
163
  without_sfairness = augment_prompt(
164
  query=query,
 
158
 
159
  # without sustainability
160
  context = ir.get_context(query, **context_params)
 
161
 
162
  without_sfairness = augment_prompt(
163
  query=query,
src/helpers/__init__.py ADDED
File without changes
src/helpers/creds_loader.py ADDED
File without changes
src/helpers/data_loaders.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from dotenv import load_dotenv
3
+ from datasets import DatasetDict
4
+ import os
5
+ import pandas as pd
6
+ from typing import Optional
7
+ load_dotenv()
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
+
13
+ HF_TOKEN = os.environ["HF_TOKEN"]
14
+
15
+
16
+ def load_data_hf(repo_name: str, data_files: str, is_public: bool) -> DatasetDict:
17
+ if is_public:
18
+ dataset = load_dataset(repo_name, split="train")
19
+ else:
20
+ dataset = load_dataset(repo_name, token=True, data_files=data_files)
21
+ return dataset
22
+
23
+
24
+ def load_scores(category: str) -> pd.DataFrame | None:
25
+ repository = os.environ.get("DATA_REPO")
26
+ data_file = None
27
+ match category:
28
+ case "popularity":
29
+ data_file = "computed/popularity/popularity_scores.csv"
30
+ case "seasonality":
31
+ data_file = "computed/seasonality/seasonality_scores.csv"
32
+ case "emissions":
33
+ data_file = "computed/emissions/emissions_merged.csv"
34
+ case _:
35
+ logger.info(f"Invalid category: {category}")
36
+ if data_file: # only for valid categories
37
+ data = load_data_hf(repository, data_file, is_public=False)
38
+ df = pd.DataFrame(data["train"][:])
39
+ return df
40
+ return None
41
+
42
+
43
+ def load_places(data_file: str) -> pd.DataFrame | None:
44
+ repository = os.environ.get("DATA_REPO")
45
+
46
+ if data_file:
47
+ data = load_data_hf(repository, data_file, is_public=False)
48
+ df = pd.DataFrame(data["train"][:])
49
+ return df
50
+
51
+ return None
52
+
src/information_retrieval/info_retrieval.py CHANGED
@@ -10,6 +10,7 @@ import logging
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
 
 
13
 
14
  def get_travel_months(query):
15
  """
@@ -91,7 +92,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
91
  return results
92
 
93
 
94
- def get_sustainability_scores(query, destinations):
95
  """
96
 
97
  Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
@@ -109,15 +110,20 @@ def get_sustainability_scores(query, destinations):
109
  months = get_travel_months(query)
110
  logger.info("Finished parsing query for months.")
111
 
 
 
 
 
 
112
  for city in destinations:
113
  if city not in city_scores:
114
  city_scores[city] = []
115
 
116
  if not months: # no month(s) or seasons provided by the user
117
- city_scores[city].append(s_fairness.compute_sfairness_score(city))
118
  else:
119
  for month in months:
120
- city_scores[city].append(s_fairness.compute_sfairness_score(city, month))
121
 
122
  logger.info("Finished getting s-fairness scores.")
123
 
@@ -130,7 +136,8 @@ def get_sustainability_scores(query, destinations):
130
  result.append({
131
  'city': city,
132
  'month': 'No data available',
133
- 's-fairness': 'No data available'
 
134
  })
135
  break
136
 
@@ -139,14 +146,15 @@ def get_sustainability_scores(query, destinations):
139
  result.append({
140
  'city': city,
141
  'month': min_score['month'],
142
- 's-fairness': min_score['s-fairness']
 
143
  })
144
 
145
  logger.info("Returning s-fairness results.")
146
  return result
147
 
148
 
149
- def get_cities(context):
150
  """
151
  Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the retrieved context
152
 
@@ -184,9 +192,8 @@ def get_cities(context):
184
  return recommended_cities
185
 
186
 
187
- def get_context(query, **params):
188
  """
189
-
190
  Function that returns all the context: from the database, as well as the respective s-fairness scores for the
191
  destinations. The default does not consider S-Fairness scores, i.e. to append sustainability scores, a non-zero
192
  parameter "sustainability" needs to be explicitly passed to params.
@@ -210,12 +217,13 @@ def get_context(query, **params):
210
  recommended_cities = wikivoyage_context.keys()
211
 
212
  if 'sustainability' in params and params['sustainability']:
213
- s_fairness_scores = get_sustainability_scores(query, recommended_cities)
214
 
215
  for score in s_fairness_scores:
216
  wikivoyage_context[score['city']]['sustainability'] = {
217
  'month': score['month'],
218
- 's-fairness': score['s-fairness']
 
219
  }
220
 
221
  return wikivoyage_context
@@ -225,11 +233,11 @@ def test():
225
  queries = []
226
  query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
227
  "in winter. "
228
-
229
  context = None
230
 
231
  try:
232
- context = get_context(query, sustainability=1)
233
  # cities = get_cities(context)
234
  # print(cities)
235
  except FileNotFoundError as e:
 
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
 
13
+ from src.helpers.data_loaders import load_scores
14
 
15
  def get_travel_months(query):
16
  """
 
92
  return results
93
 
94
 
95
+ def get_sustainability_scores(starting_point: str , query: str, destinations: list):
96
  """
97
 
98
  Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
 
110
  months = get_travel_months(query)
111
  logger.info("Finished parsing query for months.")
112
 
113
+ popularity_data = load_scores("popularity")
114
+ seasonality_data = load_scores("seasonality")
115
+ emissions_data = load_scores("emissions")
116
+ data = [popularity_data, seasonality_data, emissions_data]
117
+
118
  for city in destinations:
119
  if city not in city_scores:
120
  city_scores[city] = []
121
 
122
  if not months: # no month(s) or seasons provided by the user
123
+ city_scores[city].append(s_fairness.compute_sfairness_score(data, starting_point, city))
124
  else:
125
  for month in months:
126
+ city_scores[city].append(s_fairness.compute_sfairness_score(data, city, month))
127
 
128
  logger.info("Finished getting s-fairness scores.")
129
 
 
136
  result.append({
137
  'city': city,
138
  'month': 'No data available',
139
+ 's-fairness': 'No data available',
140
+ 'mode': 'No data available'
141
  })
142
  break
143
 
 
146
  result.append({
147
  'city': city,
148
  'month': min_score['month'],
149
+ 's-fairness': min_score['s-fairness'],
150
+ 'mode': min_score['mode'],
151
  })
152
 
153
  logger.info("Returning s-fairness results.")
154
  return result
155
 
156
 
157
+ def get_cities(context: dict):
158
  """
159
  Only to be used for testing! Function that returns a list of cities with their s-fairness scores, provided the retrieved context
160
 
 
192
  return recommended_cities
193
 
194
 
195
+ def get_context(starting_point: str, query: str, **params):
196
  """
 
197
  Function that returns all the context: from the database, as well as the respective s-fairness scores for the
198
  destinations. The default does not consider S-Fairness scores, i.e. to append sustainability scores, a non-zero
199
  parameter "sustainability" needs to be explicitly passed to params.
 
217
  recommended_cities = wikivoyage_context.keys()
218
 
219
  if 'sustainability' in params and params['sustainability']:
220
+ s_fairness_scores = get_sustainability_scores(starting_point, query, recommended_cities)
221
 
222
  for score in s_fairness_scores:
223
  wikivoyage_context[score['city']]['sustainability'] = {
224
  'month': score['month'],
225
+ 's-fairness': score['s-fairness'],
226
+ 'transport': score['mode']
227
  }
228
 
229
  return wikivoyage_context
 
233
  queries = []
234
  query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
235
  "in winter. "
236
+ starting_point = "Munich"
237
  context = None
238
 
239
  try:
240
+ context = get_context(starting_point, query, sustainability=1)
241
  # cities = get_cities(context)
242
  # print(cities)
243
  except FileNotFoundError as e:
src/pipeline.py CHANGED
@@ -37,7 +37,7 @@ MODELS = {
37
  }
38
 
39
 
40
- def pipeline(query: str, model_name: str, test: int = 0, **params):
41
  """
42
 
43
  Executes the entire RAG pipeline, provided the query and model class name.
@@ -73,7 +73,7 @@ def pipeline(query: str, model_name: str, test: int = 0, **params):
73
 
74
  logger.info("Retrieving context..")
75
  try:
76
- context = ir.get_context(query=query, **context_params)
77
  if test:
78
  retrieved_cities = ir.get_cities(context)
79
  else:
 
37
  }
38
 
39
 
40
+ def pipeline(starting_point: str, query: str, model_name: str, test: int = 0, **params):
41
  """
42
 
43
  Executes the entire RAG pipeline, provided the query and model class name.
 
73
 
74
  logger.info("Retrieving context..")
75
  try:
76
+ context = ir.get_context(starting_point=starting_point, query=query, **context_params)
77
  if test:
78
  retrieved_cities = ir.get_cities(context)
79
  else:
src/sustainability/s_fairness.py CHANGED
@@ -1,100 +1,126 @@
1
  import sys
2
  import os
 
 
3
  import pandas as pd
4
- import numpy as np
5
  import logging
 
6
 
 
7
  logger = logging.getLogger(__name__)
8
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
9
 
10
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
11
  sys.path.append(os.path.dirname(SCRIPT_DIR))
12
 
13
- from data_directories import *
14
-
15
-
16
- def get_popularity(destination):
17
  """
18
-
19
- Returns the popularity score for a particular destination.
20
 
21
- Args:
22
- - destination: str
23
-
 
 
24
  """
25
-
26
- parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
27
-
28
- if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
29
- popularity_path = popularity_dir.replace("../../", "../")
 
 
 
 
 
 
30
  else:
31
- popularity_path = popularity_dir
32
-
33
- popularity_df = pd.read_csv(popularity_path + "popularity_scores.csv")
34
 
35
- if not len(popularity_df[popularity_df['city'] == destination]):
36
- print(f"{destination} does not have popularity data")
37
- return None
38
 
39
- return popularity_df[popularity_df['city'] == destination]['weighted_pop_score'].item()
 
 
 
 
 
 
 
 
 
40
 
41
 
42
- def get_seasonality(destination, month=None):
 
43
  """
44
 
45
- Returns the seasonality score for a particular destination for a particular month. If no month is provided then
 
 
46
  the best month, i.e. month of lowest seasonality is returned.
47
 
48
  Args:
49
  - destination: str
50
  - month: str (default: None)
 
51
 
52
  """
53
- parent_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
54
 
55
- if "src" in os.getcwd() and os.path.exists(os.path.join(parent_path, "european-city-data")):
56
- seasonality_path = seasonality_dir.replace("../../", "../")
57
- else:
58
- seasonality_path = seasonality_dir
59
- seasonality_df = pd.read_csv(seasonality_path + "seasonality_scores.csv")
60
-
61
- # Check if city is present in dataframe
62
- if not len(seasonality_df[seasonality_df['city'] == destination]):
63
- logger.info(f"{destination} does not have seasonality data for {month}")
64
  return None, None
65
 
66
- if month:
67
- m = month.capitalize()[:3]
68
- else:
69
- seasonality_df['lowest_col'] = seasonality_df.loc[:, seasonality_df.columns != 'city'].idxmin(axis="columns")
70
- m = seasonality_df[seasonality_df['city'] == destination]['lowest_col'].item()
71
-
72
- # print(destination, m, seasonality_df[seasonality_df['city'] == destination][m])
73
-
74
- return m, seasonality_df[seasonality_df['city'] == destination][m].item()
75
-
76
-
77
- def compute_sfairness_score(destination, month=None):
 
 
 
 
 
 
 
78
  """
79
 
80
  Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
81
  have popularity or seasonality scores, then the function returns None.
82
 
83
  Args:
 
 
84
  - destination: str
85
  - month: str (default: None)
86
 
87
  """
88
- seasonality = get_seasonality(destination, month)
89
- month = seasonality[0]
90
- popularity = get_popularity(destination)
91
- emissions = 0
 
 
 
 
 
 
 
92
 
93
  # RECHECK
94
- if seasonality[1] is not None and popularity is not None:
95
- s_fairness = round(0.281 * emissions + 0.334 * popularity + 0.385 * seasonality[1], 3)
96
  return {
97
  'month': month,
 
98
  's-fairness': s_fairness
99
  }
100
  # elif popularity is not None: # => seasonality is None
@@ -106,9 +132,19 @@ def compute_sfairness_score(destination, month=None):
106
  else:
107
  return {
108
  'month': None,
 
109
  's-fairness': None
110
  }
111
 
112
 
 
 
 
 
 
 
 
 
 
113
  if __name__ == "__main__":
114
- print(compute_sfairness_score("Paris", "Oct"))
 
1
  import sys
2
  import os
3
+ from typing import Optional, Dict, Any
4
+
5
  import pandas as pd
 
6
  import logging
7
+ from dotenv import load_dotenv
8
 
9
+ load_dotenv()
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
12
 
13
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
14
  sys.path.append(os.path.dirname(SCRIPT_DIR))
15
 
16
+ def get_emission_scores(emissions_df: pd.DataFrame, starting_point: str, destination: str, ):
 
 
 
17
  """
 
 
18
 
19
+ Returns the emission score for the connection with least co2e between two cities.
20
+ :param emissions_df:
21
+ :param starting_point:
22
+ :param destination:
23
+ :return:
24
  """
25
+ df = emissions_df.loc[(emissions_df["city_1"] == starting_point) & (emissions_df["city_2"] == destination)]
26
+ if len(df) == 0:
27
+ logger.info(f"Connection not found between {starting_point} and {destination}")
28
+ return 0, None
29
+ df.loc[:, 'min_co2e'] = df[['fly_co2e_kg', 'drive_co2e_kg', 'train_co2e_kg']].min(axis=1)
30
+ df.loc[:, 'min_co2e_colname'] = df[['fly_co2e_kg', 'drive_co2e_kg', 'train_co2e_kg']].idxmin(axis=1)
31
+ min_co2e = df.min_co2e.values[0]
32
+ mode_prefix = (df.min_co2e_colname.values[0]).split("_")[0]
33
+ min_cost = df[mode_prefix + "_cost_EUR"].values[0]
34
+ if mode_prefix == "train":
35
+ min_travel_time = df[mode_prefix + "_time_mins"].values[0] / 60
36
  else:
37
+ min_travel_time = df[mode_prefix + "_time_hrs"].values[0]
38
+ emission_score = 0.352 * min_travel_time + 0.218 * min_co2e + 0.431 * min_cost
39
+ return emission_score, mode_prefix
40
 
 
 
 
41
 
42
+ def _check_city_present(df: pd.DataFrame, starting_point: Optional[str] = None, destination: str = "",
43
+ category: str = "popularity"):
44
+ if category == "emissions":
45
+ if not ((df['city_1'] == starting_point) & (df['city_2'] == destination)).any():
46
+ return False
47
+ else:
48
+ return True
49
+ if not len(df[df['city'] == destination]):
50
+ return False
51
+ return True
52
 
53
 
54
+ def get_scores(df: pd.DataFrame, starting_point: Optional[str] = None, destination="",
55
+ month: Optional[str] = None, category: str = "popularity"):
56
  """
57
 
58
+ Returns the seasonality/popularity score for a particular destination.
59
+ Seasonality is calculated for a particular month, while popularity is year-round.
60
+ If no month is provided then
61
  the best month, i.e. month of lowest seasonality is returned.
62
 
63
  Args:
64
  - destination: str
65
  - month: str (default: None)
66
+ - category: str (default: "popularity")
67
 
68
  """
 
69
 
70
+ # Check if city is present in dataframe
71
+ if not _check_city_present(df, starting_point, destination, category):
72
+ logger.info(f"{destination} does not have {category} data")
 
 
 
 
 
 
73
  return None, None
74
 
75
+ match category:
76
+ case "popularity":
77
+ return df[df['city'] == destination]['weighted_pop_score'].item()
78
+ case "seasonality":
79
+ dest_df = df.loc[df['city'] == destination]
80
+ if month:
81
+ m = month.capitalize()[:3]
82
+ else:
83
+ dest_df['lowest_col'] = dest_df.loc[:, dest_df.columns != 'city'].idxmin(axis="columns")
84
+ m = dest_df[dest_df['city'] == destination]['lowest_col'].item()
85
+ return m, dest_df[dest_df['city'] == destination][m].item()
86
+ case "emissions":
87
+ emissions = get_emission_scores(df, starting_point, destination)
88
+ return emissions
89
+
90
+
91
+ def compute_sfairness_score(data: list[pd.DataFrame],
92
+ starting_point: str, destination: str,
93
+ month: Optional[str] = None) -> dict[str, Any] | dict[str, None]:
94
  """
95
 
96
  Returns the s-fairness score for a particular destination city and (optional) month. If the destination doesn't
97
  have popularity or seasonality scores, then the function returns None.
98
 
99
  Args:
100
+ - data: list[pd.DataFrame]
101
+ - starting_point: str
102
  - destination: str
103
  - month: str (default: None)
104
 
105
  """
106
+ popularity_score = get_scores(df=data[0],
107
+ starting_point=None,
108
+ destination=destination, month=None, category="popularity")
109
+ month, seasonality_score = get_scores(df=data[1],
110
+ starting_point=None, destination=destination,
111
+ month=month, category="seasonality")
112
+
113
+ emission_score, mode = get_scores(df=data[2],
114
+ starting_point=starting_point, destination=destination, category="emissions")
115
+ if emission_score is None:
116
+ emission_score = 0
117
 
118
  # RECHECK
119
+ if seasonality_score is not None and popularity_score is not None:
120
+ s_fairness = round(0.281 * emission_score + 0.334 * popularity_score + 0.385 * seasonality_score, 3)
121
  return {
122
  'month': month,
123
+ 'mode': mode, # 'fly', 'drive', 'train'
124
  's-fairness': s_fairness
125
  }
126
  # elif popularity is not None: # => seasonality is None
 
132
  else:
133
  return {
134
  'month': None,
135
+ 'mode': None, # 'fly', 'drive', 'train'
136
  's-fairness': None
137
  }
138
 
139
 
140
+ def test():
141
+ popularity_data = load_data("popularity")
142
+ seasonality_data = load_data("seasonality")
143
+ emissions_data = load_data("emissions")
144
+ data = [popularity_data, seasonality_data, emissions_data]
145
+ print(compute_sfairness_score(data=data, starting_point="Munich", destination="Dijon"))
146
+ print(compute_sfairness_score(data=data, starting_point="Munich", destination="Strasbourg", month="Dec"))
147
+
148
+
149
  if __name__ == "__main__":
150
+ test()
src/text_generation/vertexai_setup.py CHANGED
@@ -21,7 +21,7 @@ def decode_service_key():
21
 
22
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
23
 
24
- creds_file_name = os.getcwd() + "/.config/application_default_credentials.json"
25
  print(creds_file_name)
26
  if not os.path.exists(os.path.dirname(creds_file_name)):
27
  credentials = decode_service_key()
 
21
 
22
  def initialize_vertexai_params(location: Optional[str] = "us-central1"):
23
 
24
+ creds_file_name = os.getcwd() + "/.config/gcp_default_credentials.json"
25
  print(creds_file_name)
26
  if not os.path.exists(os.path.dirname(creds_file_name)):
27
  credentials = decode_service_key()