ValadisCERTH commited on
Commit
2ae3a4d
·
1 Parent(s): 9186715

Create helper.py

Browse files
Files changed (1) hide show
  1. helper.py +295 -0
helper.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+
3
+ from geopy.geocoders import Nominatim
4
+ import geonamescache
5
+ import pycountry
6
+
7
+ from geotext import GeoText
8
+
9
+ import re
10
+
11
+ from transformers import BertTokenizer, BertModel
12
+ import torch
13
+
14
+
15
+ # initial loads
16
+
17
+ # load the spacy model
18
+ spacy.cli.download("en_core_web_lg")
19
+ nlp = spacy.load("en_core_web_lg")
20
+
21
+ # load the pre-trained BERT tokenizer and model
22
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
23
+ model = BertModel.from_pretrained('bert-base-uncased')
24
+
25
+ # Load valid city names from geonamescache
26
+ gc = geonamescache.GeonamesCache()
27
+ city_names = set([city['name'] for city in gc.get_cities().values()])
28
+
29
+
30
+ def flatten(lst):
31
+ """
32
+ Define a helper function to flatten the list recursively
33
+ """
34
+
35
+ for item in lst:
36
+ if isinstance(item, list):
37
+ yield from flatten(item)
38
+ else:
39
+ yield item
40
+
41
+
42
+ def is_country(reference):
43
+ """
44
+ Check if a given reference is a valid country name
45
+ """
46
+
47
+ try:
48
+ # use the pycountry library to verify if an input is a country
49
+ country = pycountry.countries.search_fuzzy(reference)[0]
50
+ return True
51
+ except LookupError:
52
+ return False
53
+
54
+
55
+ def is_city(reference):
56
+ """
57
+ Check if the given reference is a valid city name
58
+ """
59
+
60
+ # Check if the reference is a valid city name
61
+ if reference in city_names:
62
+ return True
63
+
64
+ # Load the Nomatim (open street maps) api
65
+ geolocator = Nominatim(user_agent="certh_serco_validate_city_app")
66
+ location = geolocator.geocode(reference, language="en")
67
+
68
+ # If a reference is identified as a 'city', 'town', or 'village', then it is indeed a city
69
+ if location.raw['type'] in ['city', 'town', 'village']:
70
+ return True
71
+
72
+ # If a reference is identified as 'administrative' (e.g. administrative area),
73
+ # then we further examine if the retrieved info is a single token (meaning a country) or a series of tokens (meaning a city)
74
+ # that condition takes place to separate some cases where small cities were identified as administrative areas
75
+ elif location.raw['type'] == 'administrative':
76
+ if len(location.raw['display_name'].split(",")) > 1:
77
+ return True
78
+
79
+ return False
80
+
81
+
82
+ def validate_locations(locations):
83
+ """
84
+ Validate that the identified references are indeed a Country and a City
85
+ """
86
+
87
+ validated_loc = []
88
+
89
+ for location in locations:
90
+ if is_city(location):
91
+ validated_loc.append((location, 'city'))
92
+ elif is_country(location):
93
+ validated_loc.append((location, 'country'))
94
+ else:
95
+ # Check if the location is a multi-word name
96
+ words = location.split()
97
+ if len(words) > 1:
98
+ # Try to find the country or city name among the words
99
+ for i in range(len(words)):
100
+ name = ' '.join(words[i:])
101
+ if is_country(name):
102
+ validated_loc.append((name, 'country'))
103
+ break
104
+ elif is_city(name):
105
+ validated_loc.append((name, 'city'))
106
+ break
107
+
108
+ return validated_loc
109
+
110
+
111
+ def identify_loc_ner(sentence):
112
+ """
113
+ Identify all the geopolitical and location entities with the spacy tool
114
+ """
115
+
116
+ doc = nlp(sentence)
117
+
118
+ ner_locations = []
119
+
120
+ # GPE and LOC are the labels for location entities in spaCy
121
+ for ent in doc.ents:
122
+ if ent.label_ in ['GPE', 'LOC']:
123
+ if len(ent.text.split()) > 1:
124
+ ner_locations.append(ent.text)
125
+ else:
126
+ for token in ent:
127
+ if token.ent_type_ == 'GPE':
128
+ ner_locations.append(ent.text)
129
+ break
130
+
131
+ return ner_locations
132
+
133
+
134
+ def identify_loc_geoparselibs(sentence):
135
+ """
136
+ Identify cities and countries with 3 different geoparsing libraries
137
+ """
138
+
139
+ geoparse_locations = []
140
+
141
+ # Geoparsing library 1
142
+
143
+ # Load geonames cache to check if a city name is valid
144
+ gc = geonamescache.GeonamesCache()
145
+
146
+ # Get a list of many countries/cities
147
+ countries = gc.get_countries()
148
+ cities = gc.get_cities()
149
+
150
+ city_names = [city['name'] for city in cities.values()]
151
+ country_names = [country['name'] for country in countries.values()]
152
+
153
+ # if any word sequence in our sentence is one of those countries/cities identify it
154
+ words = sentence.split()
155
+ for i in range(len(words)):
156
+ for j in range(i+1, len(words)+1):
157
+ word_seq = ' '.join(words[i:j])
158
+ if word_seq in city_names or word_seq in country_names:
159
+ geoparse_locations.append(word_seq)
160
+
161
+ # Geoparsing library 2
162
+
163
+ # similarly with the pycountry library
164
+ for country in pycountry.countries:
165
+ if country.name in sentence:
166
+ geoparse_locations.append(country.name)
167
+
168
+ # Geoparsing library 3
169
+
170
+ # similarly with the geotext library
171
+ places = GeoText(sentence)
172
+ cities = list(places.cities)
173
+ countries = list(places.countries)
174
+
175
+ if cities:
176
+ geoparse_locations += cities
177
+ if countries:
178
+ geoparse_locations += countries
179
+
180
+ return (geoparse_locations, countries, cities)
181
+
182
+
183
+ def identify_loc_regex(sentence):
184
+ """
185
+ Identify cities and countries with regular expression matching
186
+ """
187
+
188
+ regex_locations = []
189
+
190
+ # Country references can be preceded by 'in', 'from' or 'of'
191
+ pattern = r"\b(in|from|of)\b\s([\w\s]+)"
192
+ additional_refs = re.findall(pattern, sentence)
193
+
194
+ for match in additional_refs:
195
+ regex_locations.append(match[1])
196
+
197
+ return regex_locations
198
+
199
+
200
+ def identify_loc_embeddings(sentence, countries, cities):
201
+ """
202
+ Identify cities and countries with the BERT pre-trained embeddings matching
203
+ """
204
+
205
+ embd_locations = []
206
+
207
+ # Define a list of country and city names (those are given by the geonamescache library before)
208
+ countries_cities = countries + cities
209
+
210
+ # Concatenate multi-word countries and cities into a single string
211
+ multiword_countries = [c.replace(' ', '_') for c in countries if ' ' in c]
212
+ multiword_cities = [c.replace(' ', '_') for c in cities if ' ' in c]
213
+ countries_cities += multiword_countries + multiword_cities
214
+
215
+ # Preprocess the input sentence
216
+ tokens = tokenizer.tokenize(sentence)
217
+ input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
218
+
219
+ # Get the BERT embeddings for the input sentence
220
+ with torch.no_grad():
221
+ embeddings = model(input_ids)[0][0]
222
+
223
+ # Find the country and city names in the input sentence
224
+ for i in range(len(tokens)):
225
+ token = tokens[i]
226
+ if token in countries_cities:
227
+ embd_locations.append(token)
228
+ else:
229
+ word_vector = embeddings[i]
230
+ similarity_scores = torch.nn.functional.cosine_similarity(word_vector.unsqueeze(0), embeddings)
231
+ similar_tokens = [tokens[j] for j in similarity_scores.argsort(descending=True)[1:6]]
232
+ for word in similar_tokens:
233
+ if word in countries_cities and similarity_scores[tokens.index(word)] > 0.5:
234
+ embd_locations.append(word)
235
+
236
+ # Convert back multi-word country and city names to original form
237
+ embd_locations = [loc.replace('_', ' ') if '_' in loc else loc for loc in embd_locations]
238
+
239
+ return embd_locations
240
+
241
+
242
+ def identify_locations(sentence):
243
+ """
244
+ Identify all the possible Country and City references in the given sentence, using different approaches in a hybrid manner
245
+ """
246
+
247
+ locations = []
248
+
249
+ # add all the identified country/cities results in a list
250
+
251
+ try:
252
+
253
+ # ner
254
+ locations.append(identify_loc_ner(sentence))
255
+
256
+ # geoparse libs
257
+ geoparse_list, countries, cities = identify_loc_geoparselibs(sentence)
258
+ locations.append(geoparse_list)
259
+
260
+ # flatten the geoparse list
261
+ locations_flat_1 = list(flatten(locations))
262
+
263
+ # regex
264
+ locations_flat_1.append(identify_loc_regex(sentence))
265
+
266
+ # flatten the regex list
267
+ locations_flat_2 = list(flatten(locations))
268
+
269
+ # embeddings
270
+ locations_flat_2.append(identify_loc_embeddings(sentence, countries, cities))
271
+
272
+ # flatten the embeddings list
273
+ locations_flat_3 = list(flatten(locations))
274
+
275
+ # acquire the unique country/city names (because it is possible that many different approaches will capture the same countries/cities)
276
+ flat_loc_list = set(locations_flat_3)
277
+
278
+ # validate that indeed each one of the countries/cities are indeed countries/cities
279
+ validated_locations = validate_locations(flat_loc_list)
280
+
281
+ # create a proper dictionary with country/city tags and the relevant entries as a result
282
+ locations_dict = {}
283
+
284
+ for location, loc_type in validated_locations:
285
+ if loc_type not in locations_dict:
286
+ locations_dict[loc_type] = []
287
+ locations_dict[loc_type].append(location)
288
+
289
+ return locations_dict
290
+
291
+ except:
292
+
293
+ # handle the exception if any errors occur while identifying a country/city
294
+ print(f"An error occurred while checking if a city or country exists")
295
+ return ""