ValadisCERTH committed on
Commit
7cd3150
·
1 Parent(s): b6076c4

Update countriesIdentification.py

Browse files
Files changed (1) hide show
  1. countriesIdentification.py +3 -58
countriesIdentification.py CHANGED
@@ -8,18 +8,11 @@ from geotext import GeoText
8
 
9
  import re
10
 
11
- from transformers import BertTokenizer, BertModel
12
- import torch
13
-
14
  spacy.cli.download("en_core_web_lg")
15
 
16
  # Load the spacy model with GloVe embeddings
17
  nlp = spacy.load("en_core_web_lg")
18
 
19
- # load the pre-trained BERT tokenizer and model
20
- tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
21
- model = BertModel.from_pretrained('bert-base-cased')
22
-
23
  # Load valid city names from geonamescache
24
  gc = geonamescache.GeonamesCache()
25
 
@@ -267,48 +260,6 @@ def identify_loc_regex(sentence):
267
  return regex_locations
268
 
269
 
270
- def identify_loc_embeddings(sentence, countries, cities):
271
- """
272
- Identify cities and countries with the BERT pre-trained embeddings matching
273
- """
274
-
275
- embd_locations = []
276
-
277
- # Define a list of country and city names (those are given by the geonamescache library before)
278
- countries_cities = countries + cities
279
-
280
- # Concatenate multi-word countries and cities into a single string
281
- multiword_countries = [c.replace(' ', '_') for c in countries if ' ' in c]
282
- multiword_cities = [c.replace(' ', '_') for c in cities if ' ' in c]
283
- countries_cities += multiword_countries + multiword_cities
284
-
285
- # Preprocess the input sentence
286
- tokens = tokenizer.tokenize(sentence)
287
- input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
288
-
289
- # Get the BERT embeddings for the input sentence
290
- with torch.no_grad():
291
- embeddings = model(input_ids)[0][0]
292
-
293
- # Find the country and city names in the input sentence
294
- for i in range(len(tokens)):
295
- token = tokens[i]
296
- if token in countries_cities:
297
- embd_locations.append(token)
298
- else:
299
- word_vector = embeddings[i]
300
- similarity_scores = torch.nn.functional.cosine_similarity(word_vector.unsqueeze(0), embeddings)
301
- similar_tokens = [tokens[j] for j in similarity_scores.argsort(descending=True)[1:6]]
302
- for word in similar_tokens:
303
- if word in countries_cities and similarity_scores[tokens.index(word)] > 0.5:
304
- embd_locations.append(word)
305
-
306
- # Convert back multi-word country and city names to original form
307
- embd_locations = [loc.replace('_', ' ') if '_' in loc else loc for loc in embd_locations]
308
-
309
- return embd_locations
310
-
311
-
312
 
313
  def multiple_country_city_identifications_solve(country_city_dict):
314
  """
@@ -580,19 +531,13 @@ def identify_locations(sentence):
580
  # flatten the regex list
581
  locations_flat_2 = list(flatten(locations))
582
 
583
- # embeddings
584
- locations_flat_2.append(identify_loc_embeddings(sentence, countries, cities))
585
-
586
- # flatten the embeddings list
587
- locations_flat_3 = list(flatten(locations))
588
-
589
  # remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
590
  # Lowercase the words and get their unique references using set()
591
- loc_unique = set([loc.lower() for loc in locations_flat_3])
592
 
593
  # Create a new list of locations with initial capitalization, removing duplicates
594
  loc_capitalization = list(
595
- set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in locations_flat_3]))
596
 
597
  # That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
598
  if extra_serco_countries:
@@ -705,5 +650,5 @@ def identify_locations(sentence):
705
  return (0, "LOCATION", "no_country")
706
 
707
  except:
708
- # handle the exception if any errors occur while identifying a country/city
709
  return (0, "LOCATION", "unknown_error")
 
8
 
9
  import re
10
 
 
 
 
11
  spacy.cli.download("en_core_web_lg")
12
 
13
  # Load the spacy model with GloVe embeddings
14
  nlp = spacy.load("en_core_web_lg")
15
 
 
 
 
 
16
  # Load valid city names from geonamescache
17
  gc = geonamescache.GeonamesCache()
18
 
 
260
  return regex_locations
261
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def multiple_country_city_identifications_solve(country_city_dict):
265
  """
 
531
  # flatten the regex list
532
  locations_flat_2 = list(flatten(locations))
533
 
 
 
 
 
 
 
534
  # remove duplicates while also taking under consideration capitalization (e.g. a reference of italy should be valid, while also a reference of Italy and italy)
535
  # Lowercase the words and get their unique references using set()
536
+ loc_unique = set([loc.lower() for loc in locations_flat_2])
537
 
538
  # Create a new list of locations with initial capitalization, removing duplicates
539
  loc_capitalization = list(
540
+ set([loc.capitalize() if loc.lower() in loc_unique else loc.lower() for loc in locations_flat_2]))
541
 
542
  # That calculation checks whether there are substrings contained in another string. E.g. for the case of [timor leste, timor], it should remove "timor"
543
  if extra_serco_countries:
 
650
  return (0, "LOCATION", "no_country")
651
 
652
  except:
653
+ # handle the exception if any errors occur while identifying a country/city
654
  return (0, "LOCATION", "unknown_error")