Spaces:
Runtime error
Runtime error
Commit
·
7cd3150
1
Parent(s):
b6076c4
Update countriesIdentification.py
Browse files- countriesIdentification.py +3 -58
countriesIdentification.py
CHANGED
@@ -8,18 +8,11 @@ from geotext import GeoText
|
|
8 |
|
9 |
import re
|
10 |
|
11 |
-
from transformers import BertTokenizer, BertModel
|
12 |
-
import torch
|
13 |
-
|
14 |
spacy.cli.download("en_core_web_lg")
|
15 |
|
16 |
# Load the spacy model with GloVe embeddings
|
17 |
nlp = spacy.load("en_core_web_lg")
|
18 |
|
19 |
-
# load the pre-trained BERT tokenizer and model
|
20 |
-
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
21 |
-
model = BertModel.from_pretrained('bert-base-cased')
|
22 |
-
|
23 |
# Load valid city names from geonamescache
|
24 |
gc = geonamescache.GeonamesCache()
|
25 |
|
@@ -267,48 +260,6 @@ def identify_loc_regex(sentence):
|
|
267 |
return regex_locations
|
268 |
|
269 |
|
270 |
-
def identify_loc_embeddings(sentence, countries, cities):
|
271 |
-
"""
|
272 |
-
Identify cities and countries with the BERT pre-trained embeddings matching
|
273 |
-
"""
|
274 |
-
|
275 |
-
embd_locations = []
|
276 |
-
|
277 |
-
# Define a list of country and city names (those are given by the geonamescache library before)
|
278 |
-
countries_cities = countries + cities
|
279 |
-
|
280 |
-
# Concatenate multi-word countries and cities into a single string
|
281 |
-
multiword_countries = [c.replace(' ', '_') for c in countries if ' ' in c]
|
282 |
-
multiword_cities = [c.replace(' ', '_') for c in cities if ' ' in c]
|
283 |
-
countries_cities += multiword_countries + multiword_cities
|
284 |
-
|
285 |
-
# Preprocess the input sentence
|
286 |
-
tokens = tokenizer.tokenize(sentence)
|
287 |
-
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
|
288 |
-
|
289 |
-
# Get the BERT embeddings for the input sentence
|
290 |
-
with torch.no_grad():
|
291 |
-
embeddings = model(input_ids)[0][0]
|
292 |
-
|
293 |
-
# Find the country and city names in the input sentence
|
294 |
-
for i in range(len(tokens)):
|
295 |
-
token = tokens[i]
|
296 |
-
if token in countries_cities:
|
297 |
-
embd_locations.append(token)
|
298 |
-
else:
|
299 |
-
word_vector = embeddings[i]
|
300 |
-
similarity_scores = torch.nn.functional.cosine_similarity(word_vector.unsqueeze(0), embeddings)
|
301 |
-
similar_tokens = [tokens[j] for j in similarity_scores.argsort(descending=True)[1:6]]
|
302 |
-
for word in similar_tokens:
|
303 |
-
if word in countries_cities and similarity_scores[tokens.index(word)] > 0.5:
|
304 |
-
embd_locations.append(word)
|
305 |
-
|
306 |
-
# Convert back multi-word country and city names to original form
|
307 |
-
embd_locations = [loc.replace('_', ' ') if '_' in loc else loc for loc in embd_locations]
|
308 |
-
|
309 |
-
return embd_locations
|
310 |
-
|
311 |
-
|
312 |
|
313 |
def multiple_country_city_identifications_solve(country_city_dict):
|
314 |
"""
|
@@ -580,19 +531,13 @@ def identify_locations(sentence):
|
|
580 |
# flatten the regex list
|
581 |
locations_flat_2 = list(flatten(locations))
|
582 |
|
583 |
-
# embeddings
|
584 |
-
locations_flat_2.append(identify_loc_embeddings(sentence, countries, cities))
|
585 |
-
|
586 |
-
# flatten the embeddings list
|
587 |
-
locations_flat_3 = list(flatten(locations))
|
588 |
-
|
589 |
# remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
|
590 |
# Lowercase the words and get their unique references using set()
|
591 |
-
loc_unique = set([loc.lower() for loc in
|
592 |
|
593 |
# Create a new list of locations with initial capitalization, removing duplicates
|
594 |
loc_capitalization = list(
|
595 |
-
set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in
|
596 |
|
597 |
# That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
|
598 |
if extra_serco_countries:
|
@@ -705,5 +650,5 @@ def identify_locations(sentence):
|
|
705 |
return (0, "LOCATION", "no_country")
|
706 |
|
707 |
except:
|
708 |
-
# handle the exception if any errors occur while
|
709 |
return (0, "LOCATION", "unknown_error")
|
|
|
8 |
|
9 |
import re
|
10 |
|
|
|
|
|
|
|
11 |
spacy.cli.download("en_core_web_lg")
|
12 |
|
13 |
# Load the spacy model with GloVe embeddings
|
14 |
nlp = spacy.load("en_core_web_lg")
|
15 |
|
|
|
|
|
|
|
|
|
16 |
# Load valid city names from geonamescache
|
17 |
gc = geonamescache.GeonamesCache()
|
18 |
|
|
|
260 |
return regex_locations
|
261 |
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
def multiple_country_city_identifications_solve(country_city_dict):
|
265 |
"""
|
|
|
531 |
# flatten the regex list
|
532 |
locations_flat_2 = list(flatten(locations))
|
533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
# remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
|
535 |
# Lowercase the words and get their unique references using set()
|
536 |
+
loc_unique = set([loc.lower() for loc in locations_flat_2])
|
537 |
|
538 |
# Create a new list of locations with initial capitalization, removing duplicates
|
539 |
loc_capitalization = list(
|
540 |
+
set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in locations_flat_2]))
|
541 |
|
542 |
# That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
|
543 |
if extra_serco_countries:
|
|
|
650 |
return (0, "LOCATION", "no_country")
|
651 |
|
652 |
except:
|
653 |
+
# handle the exception if any errors occur while identifying a country/city
|
654 |
return (0, "LOCATION", "unknown_error")
|