Spaces:
Sleeping
Sleeping
File size: 9,053 Bytes
4b722ec 89cd5d5 4b722ec 89cd5d5 4b722ec ac20456 4b722ec 89cd5d5 4b722ec 89cd5d5 4b722ec 89cd5d5 4b722ec 89cd5d5 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec 89cd5d5 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec ac20456 4b722ec 89cd5d5 4b722ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
import sys
import re
import os
import json
from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
sys.path.append("../")
from src.vectordb.search import search_wikivoyage_listings, search_wikivoyage_docs
from src.sustainability import s_fairness
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
from src.helpers.data_loaders import load_scores
def get_travel_months(query):
    """
    Parse the user's query and collect the months of travel it mentions.

    Full English month names are matched directly; the season words "spring",
    "summer", "fall", "autumn" and "winter" are expanded to their three months.
    Matching is case-insensitive and restricted to whole words.

    Args:
        - query: str

    Returns:
        - list of month names without duplicates: explicitly named months first
          (in calendar order), then the months contributed by seasons. The list
          is empty (not None) when the query names neither a month nor a season.
    """
    months = [
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December"
    ]
    seasons = {
        "spring": ["March", "April", "May"],
        "summer": ["June", "July", "August"],
        "fall": ["September", "October", "November"],
        "autumn": ["September", "October", "November"],
        "winter": ["December", "January", "February"]
    }

    def mentioned(word):
        # Whole-word, case-insensitive lookup of a month or season name.
        # NOTE: because case is ignored, the verbs "may", "march" and "fall"
        # are indistinguishable from the month/season of the same spelling.
        return re.search(r'\b' + word + r'\b', query, re.IGNORECASE) is not None

    months_in_query = [month for month in months if mentioned(month)]

    # Check for seasons in the query
    for season, season_months in seasons.items():
        if mentioned(season):
            months_in_query += season_months

    # A month can be reached twice (e.g. "December" plus "winter", or "fall"
    # plus "autumn"); dict.fromkeys drops repeats while keeping first-seen order.
    return list(dict.fromkeys(months_in_query))
def get_wikivoyage_context(query, limit=10, reranking=0):
    """
    Retrieve the relevant documents and listings from the wikivoyage database in two steps:
    (i) search_wikivoyage_docs returns the matching city documents, and (ii) the cities found
    are handed to search_wikivoyage_listings to fetch the listings for those cities.

    Args:
        - query: str
        - limit: int; forwarded unchanged to both searches
        - reranking: bool (or 0/1); forwarded unchanged to both searches

    Returns:
        - dict keyed by city. Each value holds every field of the city's document except
          'city', plus a 'listings' list of {'type', 'name', 'description'} dicts.
    """
    city_docs = search_wikivoyage_docs(query, limit, reranking)
    logger.info("Finished getting chunked wikivoyage docs.")

    # One entry per city; a later document for the same city replaces the earlier one.
    results = {}
    for doc in city_docs:
        entry = {field: value for field, value in doc.items() if field != 'city'}
        entry['listings'] = []
        results[doc['city']] = entry

    cities = [doc['city'] for doc in city_docs]
    listings = search_wikivoyage_listings(query, cities, limit, reranking)
    logger.info("Finished getting wikivoyage listings.")

    # Attach each listing to its city; the listing's 'title' is exposed as 'name'.
    for listing in listings:
        city_listings = results[listing['city']]['listings']
        city_listings.append({
            'type': listing['type'],
            'name': listing['title'],
            'description': listing['description']
        })

    logger.info("Returning retrieval results.")
    return results
def get_sustainability_scores(starting_point: str, query: str, destinations: list):
    """
    Get the s-fairness score of each destination for the month(s) of travel named in the query.

    When the query names no month or season, the score is computed once per city without a
    month. When several months (or a season) are named, one score is computed per month and
    the entry with the lowest 's-fairness' value is kept for the city.

    Args:
        - starting_point: str; forwarded to s_fairness.compute_sfairness_score
        - query: str; parsed with get_travel_months
        - destinations: iterable of city names

    Returns:
        - list of dicts {'city', 'month', 's-fairness', 'mode'}, one per distinct city. All three
          value fields are 'No data available' if any score computed for the city has an empty 'month'.
    """
    months = get_travel_months(query)
    logger.info("Finished parsing query for months.")

    popularity_data = load_scores("popularity")
    seasonality_data = load_scores("seasonality")
    emissions_data = load_scores("emissions")
    data = [popularity_data, seasonality_data, emissions_data]

    city_scores = {}
    for city in destinations:
        scores = city_scores.setdefault(city, [])
        if not months:  # no month(s) or seasons provided by the user
            scores.append(s_fairness.compute_sfairness_score(data, starting_point, city))
        else:
            for month in months:
                # NOTE(review): this call used to pass (data, city, month), which put the city in
                # the slot the no-month call above uses for starting_point and dropped the starting
                # point entirely. Assumes the helper accepts (data, starting_point, destination,
                # month) -- confirm against src/sustainability/s_fairness.py.
                scores.append(s_fairness.compute_sfairness_score(data, starting_point, city, month))
    logger.info("Finished getting s-fairness scores.")

    result = []  # list of dicts of the format {city, month, s-fairness, mode}
    for city, scores in city_scores.items():
        # A score with an empty 'month' marks missing data for the whole city.
        if any(not score['month'] for score in scores):
            result.append({
                'city': city,
                'month': 'No data available',
                's-fairness': 'No data available',
                'mode': 'No data available'
            })
            continue

        min_score = min(scores, key=lambda x: x['s-fairness'])
        result.append({
            'city': city,
            'month': min_score['month'],
            's-fairness': min_score['s-fairness'],
            'mode': min_score['mode'],
        })

    logger.info("Returning s-fairness results.")
    return result
def get_cities(context: dict):
    """
    Only to be used for testing! Return the cities of the retrieved context together with their
    s-fairness scores, when the context carries them.

    Args:
        - context: dict keyed by city; each value must contain 'country' and may contain a
          'sustainability' dict with 'month' and 's-fairness'

    Returns:
        - list of dicts {'city', 'country'} (plus 'month' and 's-fairness' when available).
          An empty context yields an empty list. If at least one city has sustainability data
          the list is sorted by ascending s-fairness, with cities whose score is missing or
          'No data available' placed last; otherwise the context order is kept.
    """
    no_data = 'No data available'
    recommended_cities = []
    has_scores = False

    for city, info in context.items():
        city_info = {
            'city': city,
            'country': info['country']
        }
        if "sustainability" in info:
            has_scores = True
            city_info['month'] = info['sustainability']['month']
            city_info['s-fairness'] = info['sustainability']['s-fairness']
        recommended_cities.append(city_info)

    if not has_scores:
        return recommended_cities

    def sort_key(item):
        # Entries without a usable score sort after every real score.
        score = item.get('s-fairness', no_data)
        return float('inf') if score == no_data else score

    return sorted(recommended_cities, key=sort_key)
def get_context(starting_point: str, query: str, **params):
    """
    Build the full context for a query: the wikivoyage retrieval results and, on request, the
    s-fairness scores of the retrieved destinations. Sustainability scores are only attached
    when a truthy "sustainability" parameter is passed explicitly.

    Args:
        - starting_point: str; forwarded to get_sustainability_scores
        - query: str
        - params: dict; optional 'limit' (default 3), 'reranking' (default 1) and 'sustainability'

    Returns:
        - dict as produced by get_wikivoyage_context; with sustainability enabled, each scored
          city additionally gets a 'sustainability' dict with 'month', 's-fairness' and 'transport'.
    """
    limit = params.get('limit', 3)
    reranking = params.get('reranking', 1)

    wikivoyage_context = get_wikivoyage_context(query, limit, reranking)

    # Default path: retrieval results only.
    if not params.get('sustainability'):
        return wikivoyage_context

    s_fairness_scores = get_sustainability_scores(starting_point, query, wikivoyage_context.keys())
    for score in s_fairness_scores:
        # The score's 'mode' field is published under the key 'transport'.
        wikivoyage_context[score['city']]['sustainability'] = {
            'month': score['month'],
            's-fairness': score['s-fairness'],
            'transport': score['mode']
        }
    return wikivoyage_context
def test():
    """
    Smoke test: fetch the context (with sustainability scores) for a fixed winter query starting
    from Munich and dump it to test_results/test_result.json under the current working directory.

    A FileNotFoundError from the first attempt triggers creation of both wikivoyage databases,
    followed by one retry. Every other failure is logged and leaves the context as None.

    Returns:
        - the retrieved context dict, or None if it could not be obtained
    """
    query = "Suggest some places to visit during winter. I like hiking, nature and the mountains and I enjoy skiing " \
            "in winter. "
    starting_point = "Munich"
    context = None

    def log_context_error(error):
        # Must be called from inside an `except` block: sys.exc_info() reports the exception
        # currently being handled, including the file and line where it surfaced.
        exc_type, _, exc_tb = sys.exc_info()
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        logger.error(f"Error while getting context: {error}, {(exc_type, fname, exc_tb.tb_lineno)}")

    try:
        context = get_context(starting_point, query, sustainability=1)
    except FileNotFoundError:
        try:
            create_wikivoyage_docs_db_and_add_data()
            create_wikivoyage_listings_db_and_add_data()
        except Exception as e:
            logger.error(f"Error while creating DB: {e}")
        else:
            try:
                # The retry previously omitted starting_point, so it always raised a TypeError.
                context = get_context(starting_point, query, sustainability=1)
            except Exception as e:
                log_context_error(e)
    except Exception as e:
        log_context_error(e)

    # Make sure the output directory exists before writing the result.
    output_dir = os.path.join(os.getcwd(), "test_results")
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, "test_result.json")
    with open(file_path, 'w') as file:
        json.dump(context, file)
    return context
if __name__ == "__main__":
    # Script entry point: run the smoke test and echo the context it produced.
    context = test()
    print(context)
|