import re
import spacy
from nltk.tokenize import sent_tokenize, word_tokenize
import nltk
nltk.download('punkt_tab', quiet=True)  # tokenizer data required by sent_tokenize
# import coreferee  # disabled: conflicts with the installed spaCy version
import copy
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_distances
from collections import defaultdict
import numpy as np
#from mtdna_classifier import infer_fromQAModel
# 1. SENTENCE-BERT MODEL
# Step 1: Preprocess the text
def normalize_text(text):
    # Normalize various range separators (dashes, arrows, the word "to") to "-"
    text = re.sub(r'\s*(–+|—+|--+>|–>|->|-->|\bto\b|→|➝|➔|➡)\s*', '-', text, flags=re.IGNORECASE)
    # Fix GEN10GEN30 → GEN10-GEN30
    text = re.sub(r'\b([a-zA-Z]+)(\d+)(\1)(\d+)\b', r'\1\2-\1\4', text)
    # Fix GEN10-30 → GEN10-GEN30
    text = re.sub(r'\b([a-zA-Z]+)(\d+)-(\d+)\b', r'\1\2-\1\3', text)
    return text
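
# Example (hypothetical IDs): normalize_text("samples GEN10 to GEN30")
# -> "samples GEN10-GEN30"; normalize_text("GEN10-30") -> "GEN10-GEN30".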

def preprocess_text(text):
    normalized = normalize_text(text)
    sentences = sent_tokenize(normalized)
    return [re.sub(r"[^a-zA-Z0-9\s\-]", "", s).strip() for s in sentences]

# Before step 2, cache the spaCy model so it is not loaded multiple times:
# Global model cache
_spacy_models = {}

def get_spacy_model(model_name, add_coreferee=False):
    global _spacy_models
    if model_name not in _spacy_models:
        nlp = spacy.load(model_name)
        if add_coreferee and "coreferee" not in nlp.pipe_names:
            nlp.add_pipe("coreferee")
        _spacy_models[model_name] = nlp
    return _spacy_models[model_name]

# Step 2: NER to Extract Locations and Sample Names
def extract_entities(text, sample_id=None):
    nlp = get_spacy_model("en_core_web_sm")
    doc = nlp(text)
    
    # Filter entities by GPE, but exclude things that match sample ID format
    gpe_candidates = [ent.text for ent in doc.ents if ent.label_ == "GPE"]
    
    # Remove entries that match SAMPLE ID patterns like XXX123 or similar
    gpe_filtered = [gpe for gpe in gpe_candidates if not re.fullmatch(r'[A-Z]{2,5}\d{2,4}', gpe.strip())]
    
    # Optional: further filter known invalid patterns (e.g., things shorter than 3 chars, numeric only)
    gpe_filtered = [gpe for gpe in gpe_filtered if len(gpe) > 2 and not gpe.strip().isdigit()]
    
    if sample_id is None:
        return list(set(gpe_filtered)), []
    # Guard: sample IDs are expected to start with an alphabetic prefix
    prefix_match = re.match(r'[A-Z]+', sample_id)
    if not prefix_match:
        return list(set(gpe_filtered)), []
    samples = re.findall(rf'{prefix_match.group()}\d+', text)
    return list(set(gpe_filtered)), list(set(samples))
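
# Example (output depends on the spaCy NER model):
# extract_entities("BRU18 was sampled in Brunei.", "BRU18")
# is expected to return (["Brunei"], ["BRU18"]).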

# Step 3: Build a Soft Matching Layer
# Handle patterns like "BRU1–BRU20" and identify BRU18 as part of it.
def is_sample_in_range(sample_id, sentence):
    # Match prefix up to digits
    sample_prefix_match = re.match(r'^([A-Z0-9]+?)(?=\d+$)', sample_id)
    sample_number_match = re.search(r'(\d+)$', sample_id)

    if not sample_prefix_match or not sample_number_match:
        return False

    sample_prefix = sample_prefix_match.group(1)
    sample_number = int(sample_number_match.group(1))
    sentence = normalize_text(sentence)
    # Case 1: Full prefix on both sides
    pattern1 = rf'{sample_prefix}(\d+)\s*-\s*{sample_prefix}(\d+)'
    for match in re.findall(pattern1, sentence):
        start, end = int(match[0]), int(match[1])
        if start <= sample_number <= end:
            return True

    # Case 2: Prefix only on first number
    pattern2 = rf'{sample_prefix}(\d+)\s*-\s*(\d+)'
    for match in re.findall(pattern2, sentence):
        start, end = int(match[0]), int(match[1])
        if start <= sample_number <= end:
            return True

    return False
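
# Example (hypothetical IDs): is_sample_in_range("BRU18", "samples BRU1-BRU20")
# -> True; "BRU1 to BRU20" and "BRU1-20" also match after normalize_text.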

# Step 4: Use coreferee to merge sentences that share a coreference chain
# (currently disabled: the package conflicts with the installed spaCy version)
# ========== HEURISTIC GROUP → LOCATION MAPPERS ==========
# Generalized replacements for the old extract_sample_to_group_general and
# extract_group_to_location_general helpers
def extract_population_locations(text):
    text = normalize_text(text)
    # Expects flattened table rows: population name \n CODE \n region \n country
    pattern = r'([A-Za-z ,\-]+)\n([A-Z]+\d*)\n([A-Za-z ,\-]+)\n([A-Za-z ,\-]+)'
    pop_to_location = {}

    for match in re.finditer(pattern, text, flags=re.IGNORECASE):
        _, pop_code, region, country = match.groups()
        pop_to_location[pop_code.upper()] = f"{region.strip()}\n{country.strip()}"

    return pop_to_location
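
# Example (hypothetical flattened table): "Iban\nIBN\nSarawak\nMalaysia"
# yields {"IBN": "Sarawak\nMalaysia"}.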
 
def extract_sample_ranges(text):
    text = normalize_text(text)
    # Updated pattern to handle punctuation and line breaks
    pattern = r'\b([A-Z0-9]+\d+)[–\-]([A-Z0-9]+\d+)[,:\.\s]*([A-Z0-9]+\d+)\b'
    sample_to_pop = {}
    for match in re.finditer(pattern, text, flags=re.IGNORECASE):
        start_id, end_id, pop_code = match.groups()
        start_prefix = re.match(r'^([A-Z0-9]+?)(?=\d+$)', start_id, re.IGNORECASE).group(1).upper()
        end_prefix = re.match(r'^([A-Z0-9]+?)(?=\d+$)', end_id, re.IGNORECASE).group(1).upper()
        if start_prefix != end_prefix:
            continue
        start_num = int(re.search(r'(\d+)$', start_id).group())
        end_num = int(re.search(r'(\d+)$', end_id).group())
        for i in range(start_num, end_num + 1):
            # Store both zero-padded (e.g. BRU018) and plain (BRU18) forms so
            # lookups succeed regardless of how the query ID is written
            sample_to_pop[f"{start_prefix}{i:03d}"] = pop_code.upper()
            sample_to_pop[f"{start_prefix}{i}"] = pop_code.upper()

    return sample_to_pop
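
# Example (hypothetical row): extract_sample_ranges("HG001-HG003, POP1")
# maps HG001..HG003 (and the unpadded HG1..HG3) to "POP1".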

def filter_context_for_sample(sample_id, full_text, window_size=2):

    # Normalize and tokenize
    full_text = normalize_text(full_text)
    sentences = sent_tokenize(full_text)

    # Step 1: Find indices with direct mention or range match
    match_indices = [
        i for i, s in enumerate(sentences)
        if sample_id in s or is_sample_in_range(sample_id, s)
    ]

    # Step 2: Get sample → group mapping from full text
    sample_to_group = extract_sample_ranges(full_text)
    group_id = sample_to_group.get(sample_id)

    # Step 3: Find group-related sentences
    group_indices = []
    if group_id:
        for i, s in enumerate(sentences):
            if group_id in s:
                group_indices.append(i)

    # Step 4: Collect sentences within window
    selected_indices = set()
    if match_indices or group_indices:
        for i in match_indices + group_indices:
            start = max(0, i - window_size)
            end = min(len(sentences), i + window_size + 1)
            selected_indices.update(range(start, end))
        filtered_sentences = [sentences[i] for i in sorted(selected_indices)]
        return " ".join(filtered_sentences)
    # Fall back to the full text when the sample is never mentioned
    return full_text
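
# Example: filter_context_for_sample("BRU18", paper_text) keeps only sentences
# within two sentences of a direct "BRU18" mention, a covering range such as
# "BRU1-BRU20", or the sample's mapped population code.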
# Coreference merging via coreferee is currently disabled (package conflict),
# so this falls back to plain sentence preprocessing
def mergeCorefSen(text):
    sen = preprocess_text(text)
    return sen

# Before step 5 and below, check the transformer cache to avoid loading the model again
# Global SBERT model cache
_sbert_models = {}

def get_sbert_model(model_name="all-MiniLM-L6-v2"):
    global _sbert_models
    if model_name not in _sbert_models:
        _sbert_models[model_name] = SentenceTransformer(model_name)
    return _sbert_models[model_name]

# Step 5: Sentence-BERT retriever → find top sentences related to the sample.
'''Use sentence transformers to embed the sentence that mentions the sample and
compare it to sentences that mention locations.'''

def find_top_para(sample_id, text, top_k=5):
    sentences = mergeCorefSen(text)
    model = get_sbert_model("all-mpnet-base-v2")
    embeddings = model.encode(sentences, convert_to_tensor=True)

    # Find the sentence that best matches the sample_id
    sample_matches = [s for s in sentences if sample_id in s or is_sample_in_range(sample_id, s)]
    if not sample_matches:
        return [], []  # no sentence mentions the sample directly or via a range

    sample_embedding = model.encode(sample_matches[0], convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(sample_embedding, embeddings)[0]

    # Get top-k most similar sentence indices
    top_indices = cos_scores.argsort(descending=True)[:top_k]
    return top_indices, sentences

# Step 6: DBSCAN to cluster the group of similar paragraphs.
def clusterPara(tokens):
    # Load Sentence-BERT model
    sbert_model = get_sbert_model("all-mpnet-base-v2")
    sentence_embeddings = sbert_model.encode(tokens)

    # Compute cosine distance matrix
    distance_matrix = cosine_distances(sentence_embeddings)

    # DBSCAN clustering
    clustering_model = DBSCAN(eps=0.3, min_samples=1, metric="precomputed")
    cluster_labels = clustering_model.fit_predict(distance_matrix)

    # Group sentences by cluster
    clusters = defaultdict(list)
    cluster_embeddings = defaultdict(list)
    sentence_to_cluster = {}
    for i, label in enumerate(cluster_labels):
        clusters[label].append(tokens[i])
        cluster_embeddings[label].append(sentence_embeddings[i])
        sentence_to_cluster[tokens[i]] = label

    # Compute cluster centroids
    centroids = {
        label: np.mean(embs, axis=0)
        for label, embs in cluster_embeddings.items()
    }
    return clusters, sentence_to_cluster, centroids
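
# Note: with min_samples=1 every sentence is a core point, so DBSCAN assigns
# no noise label and sentence_to_cluster covers every input sentence.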

def rankSenFromCluster(clusters, sentence_to_cluster, centroids, target_sentence):
    target_cluster = sentence_to_cluster[target_sentence]
    target_centroid = centroids[target_cluster]
    sen_rank = []
    # Dict keys preserve insertion order, so this recovers each sentence's
    # original index (assumes the input sentences are unique)
    sen_order = list(sentence_to_cluster.keys())
    # Compute distances from the target cluster to every cluster centroid
    dists = []
    for label, centroid in centroids.items():
        dist = cosine_distances([target_centroid], [centroid])[0][0]
        dists.append((label, dist))
    dists.sort(key=lambda x: x[1])  # sort by proximity
    for d in dists:
        cluster = clusters[d[0]]
        for sen in cluster:
            if sen != target_sentence:
                sen_rank.append(sen_order.index(sen))
    return sen_rank
# Step 7: Final Inference Wrapper
def infer_location_for_sample(sample_id, context_text):
    # Go through each of the top-ranked sentences in order
    top_indices, sentences = find_top_para(sample_id, context_text, top_k=5)
    if len(top_indices) == 0:
        return "No clear location found in top matches"
    clusters, sentence_to_cluster, centroids = clusterPara(sentences)
    topRankSen_DBSCAN = []
    mostTopSen = ""
    i = 0
    while i < len(top_indices):
        # First, try the i-th best Sentence-BERT result
        idx = top_indices[i]
        best_sentence = sentences[idx]
        if i == 0:
            mostTopSen = best_sentence
        locations, _ = extract_entities(best_sentence, sample_id)
        if locations:
            return locations
        # If no location, look for sentences in the DBSCAN clusters closest
        # to the top-ranked sentence's cluster
        if len(topRankSen_DBSCAN) == 0 and mostTopSen:
            topRankSen_DBSCAN = rankSenFromCluster(clusters, sentence_to_cluster, centroids, mostTopSen)
        if i >= len(topRankSen_DBSCAN):
            break
        idx_DBSCAN = topRankSen_DBSCAN[i]
        best_sentence_DBSCAN = sentences[idx_DBSCAN]
        locations, _ = extract_entities(best_sentence_DBSCAN, sample_id)
        if locations:
            return locations
        # Otherwise backtrack to the next-best Sentence-BERT sentence and repeat
        i += 1
    # Last resort would be an LLM (e.g. ChatGPT, DeepSeek, etc.)
    return "No clear location found in top matches"