VyLala committed on
Commit ee2f501 · verified · 1 Parent(s): 1cb3515

Update mtdna_classifier.py

Files changed (1)
  1. mtdna_classifier.py +769 -713
mtdna_classifier.py CHANGED
@@ -1,714 +1,770 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- #import streamlit as st
5
- import subprocess
6
- import re
7
- from Bio import Entrez
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
-
20
- # Set your email (required by NCBI Entrez)
21
- #Entrez.email = "[email protected]"
22
- import nltk
23
-
24
- nltk.download("stopwords")
25
- nltk.download("punkt")
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
- from Bio import Entrez, Medline
29
- import re
30
-
31
- Entrez.email = "[email protected]"
32
-
33
- # --- Helper Functions (Re-organized and Upgraded) ---
34
-
35
- def fetch_ncbi_metadata(accession_number):
36
- """
37
- Fetches metadata directly from NCBI GenBank using Entrez.
38
- Includes robust error handling and improved field extraction.
39
- Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
- Also attempts to extract ethnicity and sample_type (ancient/modern).
41
-
42
- Args:
43
- accession_number (str): The NCBI accession number (e.g., "ON792208").
44
-
45
- Returns:
46
- dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
- 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
- """
49
- Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
-
51
- country = "unknown"
52
- specific_location = "unknown"
53
- ethnicity = "unknown"
54
- sample_type = "unknown"
55
- collection_date = "unknown"
56
- isolate = "unknown"
57
- title = "unknown"
58
- doi = "unknown"
59
- pubmed_id = None
60
- all_feature = "unknown"
61
-
62
- KNOWN_COUNTRIES = [
63
- "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
- "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
- "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
- "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
- "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
- "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
- "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
- "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
- "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
- "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
- "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
- "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
- "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
- "Yemen", "Zambia", "Zimbabwe"
77
- ]
78
- COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
-
80
- try:
81
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
- record = Entrez.read(handle)
83
- handle.close()
84
-
85
- gb_seq = None
86
- # Validate record structure: It should be a list with at least one element (a dict)
87
- if isinstance(record, list) and len(record) > 0:
88
- if isinstance(record[0], dict):
89
- gb_seq = record[0]
90
- else:
91
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
- else:
93
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
-
95
- # If gb_seq is still None, return defaults
96
- if gb_seq is None:
97
- return {"country": "unknown",
98
- "specific_location": "unknown",
99
- "ethnicity": "unknown",
100
- "sample_type": "unknown",
101
- "collection_date": "unknown",
102
- "isolate": "unknown",
103
- "title": "unknown",
104
- "doi": "unknown",
105
- "pubmed_id": None,
106
- "all_features": "unknown"}
107
-
108
-
109
- # If gb_seq is valid, proceed with extraction
110
- collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
-
112
- references = gb_seq.get("GBSeq_references", [])
113
- for ref in references:
114
- if not pubmed_id:
115
- pubmed_id = ref.get("GBReference_pubmed",None)
116
- if title == "unknown":
117
- title = ref.get("GBReference_title","unknown")
118
- for xref in ref.get("GBReference_xref", []):
119
- if xref.get("GBXref_dbname") == "doi":
120
- doi = xref.get("GBXref_id")
121
- break
122
-
123
- features = gb_seq.get("GBSeq_feature-table", [])
124
-
125
- context_for_flagging = "" # Accumulate text for ancient/modern detection
126
- features_context = ""
127
- for feature in features:
128
- if feature.get("GBFeature_key") == "source":
129
- feature_context = ""
130
- qualifiers = feature.get("GBFeature_quals", [])
131
- found_country = "unknown"
132
- found_specific_location = "unknown"
133
- found_ethnicity = "unknown"
134
-
135
- temp_geo_loc_name = "unknown"
136
- temp_note_origin_locality = "unknown"
137
- temp_country_qual = "unknown"
138
- temp_locality_qual = "unknown"
139
- temp_collection_location_qual = "unknown"
140
- temp_isolation_source_qual = "unknown"
141
- temp_env_sample_qual = "unknown"
142
- temp_pop_qual = "unknown"
143
- temp_organism_qual = "unknown"
144
- temp_specimen_qual = "unknown"
145
- temp_strain_qual = "unknown"
146
-
147
- for qual in qualifiers:
148
- qual_name = qual.get("GBQualifier_name")
149
- qual_value = qual.get("GBQualifier_value")
150
- feature_context += qual_name + ": " + qual_value +"\n"
151
- if qual_name == "collection_date":
152
- collection_date = qual_value
153
- elif qual_name == "isolate":
154
- isolate = qual_value
155
- elif qual_name == "population":
156
- temp_pop_qual = qual_value
157
- elif qual_name == "organism":
158
- temp_organism_qual = qual_value
159
- elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
- temp_specimen_qual = qual_value
161
- elif qual_name == "strain":
162
- temp_strain_qual = qual_value
163
- elif qual_name == "isolation_source":
164
- temp_isolation_source_qual = qual_value
165
- elif qual_name == "environmental_sample":
166
- temp_env_sample_qual = qual_value
167
-
168
- if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
- elif qual_name == "note":
170
- if qual_value.startswith("origin_locality:"):
171
- temp_note_origin_locality = qual_value
172
- context_for_flagging += qual_value + " " # Capture all notes for flagging
173
- elif qual_name == "country": temp_country_qual = qual_value
174
- elif qual_name == "locality": temp_locality_qual = qual_value
175
- elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
-
177
-
178
- # --- Aggregate all relevant info into context_for_flagging ---
179
- context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
- context_for_flagging = context_for_flagging.strip()
181
-
182
- # --- Determine final country and specific_location based on priority ---
183
- if temp_geo_loc_name != "unknown":
184
- parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
- if len(parts) > 1:
186
- found_specific_location = parts[-1]; found_country = parts[0]
187
- else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
- elif temp_note_origin_locality != "unknown":
189
- match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
- if match:
191
- location_string = match.group(1).strip()
192
- parts = [p.strip() for p in location_string.split(':')]
193
- if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
194
- else: found_country = location_string; found_specific_location = "unknown"
195
- elif temp_locality_qual != "unknown":
196
- found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
197
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
198
- else: found_specific_location = temp_locality_qual; found_country = "unknown"
199
- elif temp_collection_location_qual != "unknown":
200
- found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
201
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
202
- else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
203
- elif temp_isolation_source_qual != "unknown":
204
- found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
205
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
206
- else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
207
- elif temp_env_sample_qual != "unknown":
208
- found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
209
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
210
- else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
211
- if found_country == "unknown" and temp_country_qual != "unknown":
212
- found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
213
- if found_country_match: found_country = found_country_match.group(1)
214
-
215
- country = found_country
216
- specific_location = found_specific_location
217
- # --- Determine final ethnicity ---
218
- if temp_pop_qual != "unknown":
219
- found_ethnicity = temp_pop_qual
220
- elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
221
- found_ethnicity = isolate
222
- elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
223
- eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
224
- if eth_match:
225
- found_ethnicity = eth_match.group(1).strip()
226
-
227
- ethnicity = found_ethnicity
228
-
229
- # --- Determine sample_type (ancient/modern) ---
230
- if context_for_flagging:
231
- sample_type, explain = detect_ancient_flag(context_for_flagging)
232
- features_context += feature_context + "\n"
233
- break
234
-
235
- if specific_location != "unknown" and specific_location.lower() == country.lower():
236
- specific_location = "unknown"
237
- if not features_context: features_context = "unknown"
238
- return {"country": country.lower(),
239
- "specific_location": specific_location.lower(),
240
- "ethnicity": ethnicity.lower(),
241
- "sample_type": sample_type.lower(),
242
- "collection_date": collection_date,
243
- "isolate": isolate,
244
- "title": title,
245
- "doi": doi,
246
- "pubmed_id": pubmed_id,
247
- "all_features": features_context}
248
-
249
- except:
250
- print(f"Error fetching NCBI data for {accession_number}")
251
- return {"country": "unknown",
252
- "specific_location": "unknown",
253
- "ethnicity": "unknown",
254
- "sample_type": "unknown",
255
- "collection_date": "unknown",
256
- "isolate": "unknown",
257
- "title": "unknown",
258
- "doi": "unknown",
259
- "pubmed_id": None,
260
- "all_features": "unknown"}
261
-
262
- # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
- _country_keywords = {
264
- "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
265
- "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
266
- "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
267
- "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
268
- "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
269
- "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
270
- "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
271
- "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
272
- "central india": "India", "east india": "India", "northeast india": "India",
273
- "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
274
- "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
275
- }
276
-
277
- def get_country_from_text(text):
278
- text_lower = text.lower()
279
- for keyword, country in _country_keywords.items():
280
- if keyword in text_lower:
281
- return country
282
- return "unknown"
283
- # The result will be seen as manualLink for the function get_paper_text
284
- def search_google_custom(query, max_results=3):
285
- # query should be the title from ncbi or paper/source title
286
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
- endpoint = os.environ["SEARCH_ENDPOINT"]
289
- params = {
290
- "key": GOOGLE_CSE_API_KEY,
291
- "cx": GOOGLE_CSE_CX,
292
- "q": query,
293
- "num": max_results
294
- }
295
- try:
296
- response = requests.get(endpoint, params=params)
297
- if response.status_code == 429:
298
- print("Rate limit hit. Try again later.")
299
- return []
300
- response.raise_for_status()
301
- data = response.json().get("items", [])
302
- return [item.get("link") for item in data if item.get("link")]
303
- except Exception as e:
304
- print("Google CSE error:", e)
305
- return []
306
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
307
- # Step 3.1: Extract Text
308
- # sub: download excel file
309
- def download_excel_file(url, save_path="temp.xlsx"):
310
- if "view.officeapps.live.com" in url:
311
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
312
- real_url = urllib.parse.unquote(parsed_url["src"][0])
313
- response = requests.get(real_url)
314
- with open(save_path, "wb") as f:
315
- f.write(response.content)
316
- return save_path
317
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
318
- response = requests.get(url)
319
- response.raise_for_status() # Raises error if download fails
320
- with open(save_path, "wb") as f:
321
- f.write(response.content)
322
- return save_path
323
- else:
324
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
325
- return url
326
- def get_paper_text(doi,id,manualLinks=None):
327
- # create the temporary folder to contain the texts
328
- folder_path = Path("data/"+str(id))
329
- if not folder_path.exists():
330
- cmd = f'mkdir data/{id}'
331
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
332
- print("data/"+str(id) +" created.")
333
- else:
334
- print("data/"+str(id) +" already exists.")
335
- saveLinkFolder = "data/"+id
336
-
337
- link = 'https://doi.org/' + doi
338
- '''textsToExtract = { "doiLink":"paperText"
339
- "file1.pdf":"text1",
340
- "file2.doc":"text2",
341
- "file3.xlsx":excelText3'''
342
- textsToExtract = {}
343
- # get the file to create listOfFile for each id
344
- html = extractHTML.HTML("",link)
345
- jsonSM = html.getSupMaterial()
346
- text = ""
347
- links = [link] + sum((jsonSM[key] for key in jsonSM),[])
348
- if manualLinks != None:
349
- links += manualLinks
350
- for l in links:
351
- # get the main paper
352
- name = l.split("/")[-1]
353
- file_path = folder_path / name
354
- if l == link:
355
- text = html.getListSection()
356
- textsToExtract[link] = text
357
- elif l.endswith(".pdf"):
358
- if file_path.is_file():
359
- l = saveLinkFolder + "/" + name
360
- print("File exists.")
361
- p = pdf.PDF(l,saveLinkFolder,doi)
362
- f = p.openPDFFile()
363
- pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
364
- doc = fitz.open(pdf_path)
365
- text = "\n".join([page.get_text() for page in doc])
366
- textsToExtract[l] = text
367
- elif l.endswith(".doc") or l.endswith(".docx"):
368
- d = wordDoc.wordDoc(l,saveLinkFolder)
369
- text = d.extractTextByPage()
370
- textsToExtract[l] = text
371
- elif l.split(".")[-1].lower() in "xlsx":
372
- wc = word2vec.word2Vec()
373
- # download excel file if it not downloaded yet
374
- savePath = saveLinkFolder +"/"+ l.split("/")[-1]
375
- excelPath = download_excel_file(l, savePath)
376
- corpus = wc.tableTransformToCorpusText([],excelPath)
377
- text = ''
378
- for c in corpus:
379
- para = corpus[c]
380
- for words in para:
381
- text += " ".join(words)
382
- textsToExtract[l] = text
383
- # delete folder after finishing getting text
384
- #cmd = f'rm -r data/{id}'
385
- #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
386
- return textsToExtract
387
- # Step 3.2: Extract context
388
- def extract_context(text, keyword, window=500):
389
- # firstly try accession number
390
- idx = text.find(keyword)
391
- if idx == -1:
392
- return "Sample ID not found."
393
- return text[max(0, idx-window): idx+window]
394
- def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
395
- if keep_if is None:
396
- keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
397
-
398
- outputs = ""
399
- text = text.lower()
400
-
401
- # If isolate is provided, prioritize paragraphs that mention it
402
- # If isolate is provided, prioritize paragraphs that mention it
403
- if accession and accession.lower() in text:
404
- if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
405
- outputs += extract_context(text, accession.lower(), window=700)
406
- if isolate and isolate.lower() in text:
407
- if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
408
- outputs += extract_context(text, isolate.lower(), window=700)
409
- for keyword in keep_if:
410
- para = extract_context(text, keyword)
411
- if para and para not in outputs:
412
- outputs += para + "\n"
413
- return outputs
414
- # Step 4: Classification for now (demo purposes)
415
- # 4.1: Using a HuggingFace model (question-answering)
416
- def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
417
- try:
418
- qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
419
- result = qa({"context": context, "question": question})
420
- return result.get("answer", "Unknown")
421
- except Exception as e:
422
- return f"Error: {str(e)}"
423
-
424
- # 4.2: Infer from haplogroup
425
- # Load pre-trained spaCy model for NER
426
- try:
427
- nlp = spacy.load("en_core_web_sm")
428
- except OSError:
429
- download("en_core_web_sm")
430
- nlp = spacy.load("en_core_web_sm")
431
-
432
- # Define the haplogroup-to-region mapping (simple rule-based)
433
- import csv
434
-
435
- def load_haplogroup_mapping(csv_path):
436
- mapping = {}
437
- with open(csv_path) as f:
438
- reader = csv.DictReader(f)
439
- for row in reader:
440
- mapping[row["haplogroup"]] = [row["region"],row["source"]]
441
- return mapping
442
-
443
- # Function to extract haplogroup from the text
444
- def extract_haplogroup(text):
445
- match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
446
- if match:
447
- submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
448
- if submatch:
449
- return submatch.group(0)
450
- else:
451
- return match.group(1) # fallback
452
- fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
453
- if fallback:
454
- return fallback.group(1)
455
- return None
456
-
457
-
458
- # Function to extract location based on NER
459
- def extract_location(text):
460
- doc = nlp(text)
461
- locations = []
462
- for ent in doc.ents:
463
- if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
464
- locations.append(ent.text)
465
- return locations
466
-
467
- # Function to infer location from haplogroup
468
- def infer_location_from_haplogroup(haplogroup):
469
- haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
470
- return haplo_map.get(haplogroup, ["Unknown","Unknown"])
471
-
472
- # Function to classify the mtDNA sample
473
- def classify_mtDNA_sample_from_haplo(text):
474
- # Extract haplogroup
475
- haplogroup = extract_haplogroup(text)
476
- # Extract location based on NER
477
- locations = extract_location(text)
478
- # Infer location based on haplogroup
479
- inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
480
- return {
481
- "source":sourceHaplo,
482
- "locations_found_in_context": locations,
483
- "haplogroup": haplogroup,
484
- "inferred_location": inferred_location
485
-
486
- }
487
- # 4.3 Get from available NCBI
488
- def infer_location_fromNCBI(accession):
489
- try:
490
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
491
- text = handle.read()
492
- handle.close()
493
- match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
494
- if match:
495
- return match.group(2), match.group(0) # This is the value like "Brunei"
496
- return "Not found", "Not found"
497
-
498
- except Exception as e:
499
- print("❌ Entrez error:", e)
500
- return "Not found", "Not found"
501
-
502
- ### ANCIENT/MODERN FLAG
503
- from Bio import Entrez
504
- import re
505
-
506
- def flag_ancient_modern(accession, textsToExtract, isolate=None):
507
- """
508
- Try to classify a sample as Ancient or Modern using:
509
- 1. NCBI accession (if available)
510
- 2. Supplementary text or context fallback
511
- """
512
- context = ""
513
- label, explain = "", ""
514
-
515
- try:
516
- # Check if we can fetch metadata from NCBI using the accession
517
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
518
- text = handle.read()
519
- handle.close()
520
-
521
- isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
522
- if isolate_source:
523
- context += isolate_source.group(0) + " "
524
-
525
- specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
526
- if specimen:
527
- context += specimen.group(0) + " "
528
-
529
- if context.strip():
530
- label, explain = detect_ancient_flag(context)
531
- if label!="Unknown":
532
- return label, explain + " from NCBI\n(" + context + ")"
533
-
534
- # If no useful NCBI metadata, check supplementary texts
535
- if textsToExtract:
536
- labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
537
-
538
- for source in textsToExtract:
539
- text_block = textsToExtract[source]
540
- context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
541
- label, explain = detect_ancient_flag(context)
542
-
543
- if label == "Ancient":
544
- labels["ancient"][0] += 1
545
- labels["ancient"][1] += f"{source}:\n{explain}\n\n"
546
- elif label == "Modern":
547
- labels["modern"][0] += 1
548
- labels["modern"][1] += f"{source}:\n{explain}\n\n"
549
- else:
550
- labels["unknown"] += 1
551
-
552
- if max(labels["modern"][0],labels["ancient"][0]) > 0:
553
- if labels["modern"][0] > labels["ancient"][0]:
554
- return "Modern", labels["modern"][1]
555
- else:
556
- return "Ancient", labels["ancient"][1]
557
- else:
558
- return "Unknown", "No strong keywords detected"
559
- else:
560
- print("No DOI or PubMed ID available for inference.")
561
- return "", ""
562
-
563
- except Exception as e:
564
- print("Error:", e)
565
- return "", ""
566
-
567
-
568
- def detect_ancient_flag(context_snippet):
569
- context = context_snippet.lower()
570
-
571
- ancient_keywords = [
572
- "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
573
- "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
574
- "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
575
- "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
576
- ]
577
-
578
- modern_keywords = [
579
- "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
580
- "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
581
- "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
582
- "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
583
- "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
584
- ]
585
-
586
- ancient_hits = [k for k in ancient_keywords if k in context]
587
- modern_hits = [k for k in modern_keywords if k in context]
588
-
589
- if ancient_hits and not modern_hits:
590
- return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
591
- elif modern_hits and not ancient_hits:
592
- return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
593
- elif ancient_hits and modern_hits:
594
- if len(ancient_hits) >= len(modern_hits):
595
- return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
596
- else:
597
- return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
598
-
599
- # Fallback to QA
600
- answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
601
- if answer.startswith("Error"):
602
- return "Unknown", answer
603
- if "ancient" in answer.lower():
604
- return "Ancient", f"Leaning ancient based on QA: {answer}"
605
- elif "modern" in answer.lower():
606
- return "Modern", f"Leaning modern based on QA: {answer}"
607
- else:
608
- return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
609
-
610
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
611
- def classify_sample_location(accession):
612
- outputs = {}
613
- keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
614
- # Step 1: get pubmed id and isolate
615
- pubmedID, isolate = get_info_from_accession(accession)
616
- '''if not pubmedID:
617
- return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
618
- if not isolate:
619
- isolate = "UNKNOWN_ISOLATE"
620
- # Step 2: get doi
621
- doi = get_doi_from_pubmed_id(pubmedID)
622
- '''if not doi:
623
- return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
624
- # Step 3: get text
625
- '''textsToExtract = { "doiLink":"paperText"
626
- "file1.pdf":"text1",
627
- "file2.doc":"text2",
628
- "file3.xlsx":excelText3'''
629
- if doi and pubmedID:
630
- textsToExtract = get_paper_text(doi,pubmedID)
631
- else: textsToExtract = {}
632
- '''if not textsToExtract:
633
- return {"error": f"No texts extracted for DOI {doi}"}'''
634
- if isolate not in [None, "UNKNOWN_ISOLATE"]:
635
- label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
636
- else:
637
- label, explain = flag_ancient_modern(accession,textsToExtract)
638
- # Step 4: prediction
639
- outputs[accession] = {}
640
- outputs[isolate] = {}
641
- # 4.0 Infer from NCBI
642
- location, outputNCBI = infer_location_fromNCBI(accession)
643
- NCBI_result = {
644
- "source": "NCBI",
645
- "sample_id": accession,
646
- "predicted_location": location,
647
- "context_snippet": outputNCBI}
648
- outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
649
- if textsToExtract:
650
- long_text = ""
651
- for key in textsToExtract:
652
- text = textsToExtract[key]
653
- # try accession number first
654
- outputs[accession][key] = {}
655
- keyword = accession
656
- context = extract_context(text, keyword, window=500)
657
- # 4.1: Using a HuggingFace model (question-answering)
658
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
659
- qa_result = {
660
- "source": key,
661
- "sample_id": keyword,
662
- "predicted_location": location,
663
- "context_snippet": context
664
- }
665
- outputs[keyword][key]["QAModel"] = qa_result
666
- # 4.2: Infer from haplogroup
667
- haplo_result = classify_mtDNA_sample_from_haplo(context)
668
- outputs[keyword][key]["haplogroup"] = haplo_result
669
- # try isolate
670
- keyword = isolate
671
- outputs[isolate][key] = {}
672
- context = extract_context(text, keyword, window=500)
673
- # 4.1.1: Using a HuggingFace model (question-answering)
674
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
675
- qa_result = {
676
- "source": key,
677
- "sample_id": keyword,
678
- "predicted_location": location,
679
- "context_snippet": context
680
- }
681
- outputs[keyword][key]["QAModel"] = qa_result
682
- # 4.2.1: Infer from haplogroup
683
- haplo_result = classify_mtDNA_sample_from_haplo(context)
684
- outputs[keyword][key]["haplogroup"] = haplo_result
685
- # add long text
686
- long_text += text + ". \n"
687
- # 4.3: UpgradeClassify
688
- # try sample_id as accession number
689
- sample_id = accession
690
- if sample_id:
691
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
692
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
693
- if locations!="No clear location found in top matches":
694
- outputs[sample_id]["upgradeClassifier"] = {}
695
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
696
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
697
- "sample_id": sample_id,
698
- "predicted_location": ", ".join(locations),
699
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
700
- }
701
- # try sample_id as isolate name
702
- sample_id = isolate
703
- if sample_id:
704
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
705
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
706
- if locations!="No clear location found in top matches":
707
- outputs[sample_id]["upgradeClassifier"] = {}
708
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
709
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
710
- "sample_id": sample_id,
711
- "predicted_location": ", ".join(locations),
712
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
713
- }
714
  return outputs, label, explain
 
1
+ # mtDNA Location Classifier MVP (Google Colab)
2
+ # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
+ import os
4
+ #import streamlit as st
5
+ import subprocess
6
+ import re
7
+ from Bio import Entrez
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
+
20
+ # Set your email (required by NCBI Entrez)
21
+ #Entrez.email = "[email protected]"
22
+ import nltk
23
+
24
+ nltk.download("stopwords")
25
+ nltk.download("punkt")
26
+ nltk.download('punkt_tab')
27
+ # Step 1: Get PubMed ID from Accession using EDirect
28
+ from Bio import Entrez, Medline
29
+ import re
30
+
31
+ Entrez.email = "[email protected]"
32
+
33
+ # --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
+ def fetch_ncbi_metadata(accession_number):
36
+ """
37
+ Fetches metadata directly from NCBI GenBank using Entrez.
38
+ Includes robust error handling and improved field extraction.
39
+ Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
+ Also attempts to extract ethnicity and sample_type (ancient/modern).
41
+
42
+ Args:
43
+ accession_number (str): The NCBI accession number (e.g., "ON792208").
44
+
45
+ Returns:
46
+ dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
+ 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
+ """
49
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
+
51
+ country = "unknown"
52
+ specific_location = "unknown"
53
+ ethnicity = "unknown"
54
+ sample_type = "unknown"
55
+ collection_date = "unknown"
56
+ isolate = "unknown"
57
+ title = "unknown"
58
+ doi = "unknown"
59
+ pubmed_id = None
60
+ all_feature = "unknown"
61
+
62
+ KNOWN_COUNTRIES = [
63
+ "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
+ "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
+ "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
+ "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
+ "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
+ "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
+ "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
+ "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
+ "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
+ "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
+ "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
+ "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
+ "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
+ "Yemen", "Zambia", "Zimbabwe"
77
+ ]
78
+ COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
+
80
+ try:
81
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
+ record = Entrez.read(handle)
83
+ handle.close()
84
+
85
+ gb_seq = None
86
+ # Validate record structure: It should be a list with at least one element (a dict)
87
+ if isinstance(record, list) and len(record) > 0:
88
+ if isinstance(record[0], dict):
89
+ gb_seq = record[0]
90
+ else:
91
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
+ else:
93
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
+
95
+ # If gb_seq is still None, return defaults
96
+ if gb_seq is None:
97
+ return {"country": "unknown",
98
+ "specific_location": "unknown",
99
+ "ethnicity": "unknown",
100
+ "sample_type": "unknown",
101
+ "collection_date": "unknown",
102
+ "isolate": "unknown",
103
+ "title": "unknown",
104
+ "doi": "unknown",
105
+ "pubmed_id": None,
106
+ "all_features": "unknown"}
107
+
108
+
109
+ # If gb_seq is valid, proceed with extraction
110
+ collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
+
112
+ references = gb_seq.get("GBSeq_references", [])
113
+ for ref in references:
114
+ if not pubmed_id:
115
+ pubmed_id = ref.get("GBReference_pubmed",None)
116
+ if title == "unknown":
117
+ title = ref.get("GBReference_title","unknown")
118
+ for xref in ref.get("GBReference_xref", []):
119
+ if xref.get("GBXref_dbname") == "doi":
120
+ doi = xref.get("GBXref_id")
121
+ break
122
+
123
+ features = gb_seq.get("GBSeq_feature-table", [])
124
+
125
+ context_for_flagging = "" # Accumulate text for ancient/modern detection
126
+ features_context = ""
127
+ for feature in features:
128
+ if feature.get("GBFeature_key") == "source":
129
+ feature_context = ""
130
+ qualifiers = feature.get("GBFeature_quals", [])
131
+ found_country = "unknown"
132
+ found_specific_location = "unknown"
133
+ found_ethnicity = "unknown"
134
+
135
+ temp_geo_loc_name = "unknown"
136
+ temp_note_origin_locality = "unknown"
137
+ temp_country_qual = "unknown"
138
+ temp_locality_qual = "unknown"
139
+ temp_collection_location_qual = "unknown"
140
+ temp_isolation_source_qual = "unknown"
141
+ temp_env_sample_qual = "unknown"
142
+ temp_pop_qual = "unknown"
143
+ temp_organism_qual = "unknown"
144
+ temp_specimen_qual = "unknown"
145
+ temp_strain_qual = "unknown"
146
+
147
+ for qual in qualifiers:
148
+ qual_name = qual.get("GBQualifier_name")
149
+ qual_value = qual.get("GBQualifier_value")
150
+ feature_context += qual_name + ": " + qual_value +"\n"
151
+ if qual_name == "collection_date":
152
+ collection_date = qual_value
153
+ elif qual_name == "isolate":
154
+ isolate = qual_value
155
+ elif qual_name == "population":
156
+ temp_pop_qual = qual_value
157
+ elif qual_name == "organism":
158
+ temp_organism_qual = qual_value
159
+ elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
+ temp_specimen_qual = qual_value
161
+ elif qual_name == "strain":
162
+ temp_strain_qual = qual_value
163
+ elif qual_name == "isolation_source":
164
+ temp_isolation_source_qual = qual_value
165
+ elif qual_name == "environmental_sample":
166
+ temp_env_sample_qual = qual_value
167
+
168
+ if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
+ elif qual_name == "note":
170
+ if qual_value.startswith("origin_locality:"):
171
+ temp_note_origin_locality = qual_value
172
+ context_for_flagging += qual_value + " " # Capture all notes for flagging
173
+ elif qual_name == "country": temp_country_qual = qual_value
174
+ elif qual_name == "locality": temp_locality_qual = qual_value
175
+ elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
+
177
+
178
+ # --- Aggregate all relevant info into context_for_flagging ---
179
+ context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
+ context_for_flagging = context_for_flagging.strip()
181
+
182
+ # --- Determine final country and specific_location based on priority ---
183
+ if temp_geo_loc_name != "unknown":
184
+ parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
+ if len(parts) > 1:
186
+ found_specific_location = parts[-1]; found_country = parts[0]
187
+ else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
+ elif temp_note_origin_locality != "unknown":
189
+ match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
+ if match:
191
+ location_string = match.group(1).strip()
192
+ parts = [p.strip() for p in location_string.split(':')]
193
+ if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
194
+ else: found_country = location_string; found_specific_location = "unknown"
195
+ elif temp_locality_qual != "unknown":
196
+ found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
197
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
198
+ else: found_specific_location = temp_locality_qual; found_country = "unknown"
199
+ elif temp_collection_location_qual != "unknown":
200
+ found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
201
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
202
+ else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
203
+ elif temp_isolation_source_qual != "unknown":
204
+ found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
205
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
206
+ else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
207
+ elif temp_env_sample_qual != "unknown":
208
+ found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
209
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
210
+ else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
211
+ if found_country == "unknown" and temp_country_qual != "unknown":
212
+ found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
213
+ if found_country_match: found_country = found_country_match.group(1)
214
+
215
+ country = found_country
216
+ specific_location = found_specific_location
217
+ # --- Determine final ethnicity ---
218
+ if temp_pop_qual != "unknown":
219
+ found_ethnicity = temp_pop_qual
220
+ elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
221
+ found_ethnicity = isolate
222
+ elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
223
+ eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
224
+ if eth_match:
225
+ found_ethnicity = eth_match.group(1).strip()
226
+
227
+ ethnicity = found_ethnicity
228
+
229
+ # --- Determine sample_type (ancient/modern) ---
230
+ if context_for_flagging:
231
+ sample_type, explain = detect_ancient_flag(context_for_flagging)
232
+ features_context += feature_context + "\n"
233
+ break
234
+
235
+ if specific_location != "unknown" and specific_location.lower() == country.lower():
236
+ specific_location = "unknown"
237
+ if not features_context: features_context = "unknown"
238
+ return {"country": country.lower(),
239
+ "specific_location": specific_location.lower(),
240
+ "ethnicity": ethnicity.lower(),
241
+ "sample_type": sample_type.lower(),
242
+ "collection_date": collection_date,
243
+ "isolate": isolate,
244
+ "title": title,
245
+ "doi": doi,
246
+ "pubmed_id": pubmed_id,
247
+ "all_features": features_context}
248
+
249
+ except:
250
+ print(f"Error fetching NCBI data for {accession_number}")
251
+ return {"country": "unknown",
252
+ "specific_location": "unknown",
253
+ "ethnicity": "unknown",
254
+ "sample_type": "unknown",
255
+ "collection_date": "unknown",
256
+ "isolate": "unknown",
257
+ "title": "unknown",
258
+ "doi": "unknown",
259
+ "pubmed_id": None,
260
+ "all_features": "unknown"}
261
+
262
+ # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
+ _country_keywords = {
264
+ "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
265
+ "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
266
+ "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
267
+ "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
268
+ "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
269
+ "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
270
+ "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
271
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
272
+ "central india": "India", "east india": "India", "northeast india": "India",
273
+ "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
274
+ "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
275
+ }
276
+
277
+ def get_country_from_text(text):
278
+ text_lower = text.lower()
279
+ for keyword, country in _country_keywords.items():
280
+ if keyword in text_lower:
281
+ return country
282
+ return "unknown"
283
+ # The result will be seen as manualLink for the function get_paper_text
284
+ # def search_google_custom(query, max_results=3):
285
+ # # query should be the title from ncbi or paper/source title
286
+ # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
+ # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
+ # endpoint = os.environ["SEARCH_ENDPOINT"]
289
+ # params = {
290
+ # "key": GOOGLE_CSE_API_KEY,
291
+ # "cx": GOOGLE_CSE_CX,
292
+ # "q": query,
293
+ # "num": max_results
294
+ # }
295
+ # try:
296
+ # response = requests.get(endpoint, params=params)
297
+ # if response.status_code == 429:
298
+ # print("Rate limit hit. Try again later.")
299
+ # return []
300
+ # response.raise_for_status()
301
+ # data = response.json().get("items", [])
302
+ # return [item.get("link") for item in data if item.get("link")]
303
+ # except Exception as e:
304
+ # print("Google CSE error:", e)
305
+ # return []
306
+
307
+ def search_google_custom(query, max_results=3):
308
+ # query should be the title from ncbi or paper/source title
309
+ # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
310
+ # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
311
+ # endpoint = os.environ["SEARCH_ENDPOINT"]
312
+ GOOGLE_CSE_API_KEY = os.getenv("GOOGLE_CSE_API_KEY", "AIzaSyAg_Hi5DPit2bvvwCs1PpUkAPRZun7yCRQ") # account: [email protected]
313
+ GOOGLE_CSE_CX = os.getenv("GOOGLE_CSE_CX", "25a51c433f148490c")
314
+ endpoint = "https://www.googleapis.com/customsearch/v1"
315
+ params = {
316
+ "key": GOOGLE_CSE_API_KEY,
317
+ "cx": GOOGLE_CSE_CX,
318
+ "q": query,
319
+ "num": max_results
320
+ }
321
+ try:
322
+ response = requests.get(endpoint, params=params)
323
+ if response.status_code == 429:
324
+ print("Rate limit hit. Try again later.")
325
+ print("try with back up account")
326
+ try:
327
+ return search_google_custom_backup(query, max_results)
328
+ except:
329
+ return []
330
+ response.raise_for_status()
331
+ data = response.json().get("items", [])
332
+ return [item.get("link") for item in data if item.get("link")]
333
+ except Exception as e:
334
+ print("Google CSE error:", e)
335
+ return []
336
+
337
+ def search_google_custom_backup(query, max_results=3):
338
+ # query should be the title from ncbi or paper/source title
339
+ # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
340
+ # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
341
+ # endpoint = os.environ["SEARCH_ENDPOINT"]
342
+ GOOGLE_CSE_API_KEY = os.getenv("GOOGLE_CSE_API_KEY", "AIzaSyBDkTo3QSAUHEPSBaWq5fX9Be4l-2EhAUM") # account: [email protected]
343
+ GOOGLE_CSE_CX = os.getenv("GOOGLE_CSE_CX", "00231c463e9464bdc")
344
+ endpoint = "https://www.googleapis.com/customsearch/v1"
345
+ params = {
346
+ "key": GOOGLE_CSE_API_KEY,
347
+ "cx": GOOGLE_CSE_CX,
348
+ "q": query,
349
+ "num": max_results
350
+ }
351
+ try:
352
+ response = requests.get(endpoint, params=params)
353
+ if response.status_code == 429:
354
+ print("Rate limit hit. Try again later.")
355
+ return []
356
+ response.raise_for_status()
357
+ data = response.json().get("items", [])
358
+ return [item.get("link") for item in data if item.get("link")]
359
+ except Exception as e:
360
+ print("Google CSE error:", e)
361
+ return []
362
+ # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
363
+ # Step 3.1: Extract Text
364
+ # sub: download excel file
365
+ def download_excel_file(url, save_path="temp.xlsx"):
366
+ if "view.officeapps.live.com" in url:
367
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
368
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
369
+ response = requests.get(real_url)
370
+ with open(save_path, "wb") as f:
371
+ f.write(response.content)
372
+ return save_path
373
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
374
+ response = requests.get(url)
375
+ response.raise_for_status() # Raises error if download fails
376
+ with open(save_path, "wb") as f:
377
+ f.write(response.content)
378
+ return save_path
379
+ else:
380
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
381
+ return url
382
+ def get_paper_text(doi,id,manualLinks=None):
383
+ # create the temporary folder to contain the texts
384
+ folder_path = Path("data/"+str(id))
385
+ if not folder_path.exists():
386
+ cmd = f'mkdir data/{id}'
387
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
388
+ print("data/"+str(id) +" created.")
389
+ else:
390
+ print("data/"+str(id) +" already exists.")
391
+ saveLinkFolder = "data/"+id
392
+
393
+ link = 'https://doi.org/' + doi
394
+ '''textsToExtract = { "doiLink":"paperText"
395
+ "file1.pdf":"text1",
396
+ "file2.doc":"text2",
397
+ "file3.xlsx":excelText3'''
398
+ textsToExtract = {}
399
+ # get the file to create listOfFile for each id
400
+ html = extractHTML.HTML("",link)
401
+ jsonSM = html.getSupMaterial()
402
+ text = ""
403
+ links = [link] + sum((jsonSM[key] for key in jsonSM),[])
404
+ if manualLinks != None:
405
+ links += manualLinks
406
+ for l in links:
407
+ # get the main paper
408
+ name = l.split("/")[-1]
409
+ file_path = folder_path / name
410
+ if l == link:
411
+ text = html.getListSection()
412
+ textsToExtract[link] = text
413
+ elif l.endswith(".pdf"):
414
+ if file_path.is_file():
415
+ l = saveLinkFolder + "/" + name
416
+ print("File exists.")
417
+ p = pdf.PDF(l,saveLinkFolder,doi)
418
+ f = p.openPDFFile()
419
+ pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
420
+ doc = fitz.open(pdf_path)
421
+ text = "\n".join([page.get_text() for page in doc])
422
+ textsToExtract[l] = text
423
+ elif l.endswith(".doc") or l.endswith(".docx"):
424
+ d = wordDoc.wordDoc(l,saveLinkFolder)
425
+ text = d.extractTextByPage()
426
+ textsToExtract[l] = text
427
+ elif l.split(".")[-1].lower() in "xlsx":
428
+ wc = word2vec.word2Vec()
429
+ # download excel file if it not downloaded yet
430
+ savePath = saveLinkFolder +"/"+ l.split("/")[-1]
431
+ excelPath = download_excel_file(l, savePath)
432
+ corpus = wc.tableTransformToCorpusText([],excelPath)
433
+ text = ''
434
+ for c in corpus:
435
+ para = corpus[c]
436
+ for words in para:
437
+ text += " ".join(words)
438
+ textsToExtract[l] = text
439
+ # delete folder after finishing getting text
440
+ #cmd = f'rm -r data/{id}'
441
+ #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
442
+ return textsToExtract
443
+ # Step 3.2: Extract context
444
+ def extract_context(text, keyword, window=500):
445
+ # firstly try accession number
446
+ idx = text.find(keyword)
447
+ if idx == -1:
448
+ return "Sample ID not found."
449
+ return text[max(0, idx-window): idx+window]
450
+ def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
451
+ if keep_if is None:
452
+ keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
453
+
454
+ outputs = ""
455
+ text = text.lower()
456
+
457
+ # If isolate is provided, prioritize paragraphs that mention it
458
+ # If isolate is provided, prioritize paragraphs that mention it
459
+ if accession and accession.lower() in text:
460
+ if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
461
+ outputs += extract_context(text, accession.lower(), window=700)
462
+ if isolate and isolate.lower() in text:
463
+ if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
464
+ outputs += extract_context(text, isolate.lower(), window=700)
465
+ for keyword in keep_if:
466
+ para = extract_context(text, keyword)
467
+ if para and para not in outputs:
468
+ outputs += para + "\n"
469
+ return outputs
470
+ # Step 4: Classification for now (demo purposes)
471
+ # 4.1: Using a HuggingFace model (question-answering)
472
+ def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
473
+ try:
474
+ qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
475
+ result = qa({"context": context, "question": question})
476
+ return result.get("answer", "Unknown")
477
+ except Exception as e:
478
+ return f"Error: {str(e)}"
479
+
480
+ # 4.2: Infer from haplogroup
481
+ # Load pre-trained spaCy model for NER
482
+ try:
483
+ nlp = spacy.load("en_core_web_sm")
484
+ except OSError:
485
+ download("en_core_web_sm")
486
+ nlp = spacy.load("en_core_web_sm")
487
+
488
+ # Define the haplogroup-to-region mapping (simple rule-based)
+ import csv
+
+ def load_haplogroup_mapping(csv_path):
+   mapping = {}
+   with open(csv_path) as f:
+     reader = csv.DictReader(f)
+     for row in reader:
+       mapping[row["haplogroup"]] = [row["region"], row["source"]]
+   return mapping
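+ # The CSV is expected to provide "haplogroup", "region" and "source" columns, e.g.
+ # (illustrative rows only):
+ #   haplogroup,region,source
+ #   B4,East Asia,curated literature
+ # Usage sketch:
+ #   mapping = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
+ #   mapping.get("B4", ["Unknown", "Unknown"])  # -> ["East Asia", "curated literature"]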
+
+ # Function to extract haplogroup from the text
+ def extract_haplogroup(text):
+   match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
+   if match:
+     submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
+     if submatch:
+       return submatch.group(0)
+     else:
+       return match.group(1)  # fallback
+   fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
+   if fallback:
+     return fallback.group(1)
+   return None
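+ # Behaviour sketch (illustrative): the sub-match truncates at the first lower-case letter,
+ # so detailed sub-clades collapse to the major clade; the loose fallback regex can over-match.
+ #   extract_haplogroup("the sample belongs to haplogroup B4a1a")  # -> "B4"
+ #   extract_haplogroup("No haplogroup was reported")              # -> "No" (fallback false positive)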
+
+ # Function to extract location based on NER
+ def extract_location(text):
+   doc = nlp(text)
+   locations = []
+   for ent in doc.ents:
+     if ent.label_ == "GPE":  # GPE = Geopolitical Entity (location)
+       locations.append(ent.text)
+   return locations
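+ # Illustrative example (entities depend on the spaCy model version):
+ #   extract_location("Samples were collected in Vientiane, Laos and stored in Bangkok.")
+ #   # -> typically ["Vientiane", "Laos", "Bangkok"]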
+
+ # Function to infer location from haplogroup
+ def infer_location_from_haplogroup(haplogroup):
+   haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
+   return haplo_map.get(haplogroup, ["Unknown", "Unknown"])
+
+ # Function to classify the mtDNA sample
+ def classify_mtDNA_sample_from_haplo(text):
+   # Extract haplogroup
+   haplogroup = extract_haplogroup(text)
+   # Extract location based on NER
+   locations = extract_location(text)
+   # Infer location based on haplogroup (single lookup; the mapping returns [region, source])
+   inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)
+   return {
+     "source": sourceHaplo,
+     "locations_found_in_context": locations,
+     "haplogroup": haplogroup,
+     "inferred_location": inferred_location
+   }
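+ # Usage sketch (illustrative; values depend on the NER model and the mapping CSV):
+ #   classify_mtDNA_sample_from_haplo("... haplogroup B4 sample collected in Laos ...")
+ #   # -> {"source": "<mapping source>", "locations_found_in_context": ["Laos"],
+ #   #     "haplogroup": "B4", "inferred_location": "<region or 'Unknown'>"}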
+ # 4.0: Get location directly from available NCBI metadata
+ def infer_location_fromNCBI(accession):
+   try:
+     handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
+     text = handle.read()
+     handle.close()
+     match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
+     if match:
+       return match.group(2), match.group(0)  # group(2) is the value, e.g. "Brunei"
+     return "Not found", "Not found"
+
+   except Exception as e:
+     print("❌ Entrez error:", e)
+     return "Not found", "Not found"
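+ # The regex above matches GenBank source qualifiers such as
+ #   /geo_loc_name="Brunei"
+ # in which case the function returns ("Brunei", '/geo_loc_name="Brunei"').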
+
+ ### ANCIENT/MODERN FLAG
+ from Bio import Entrez
+ import re
+
+ def flag_ancient_modern(accession, textsToExtract, isolate=None):
+   """
+   Try to classify a sample as Ancient or Modern using:
+   1. NCBI accession (if available)
+   2. Supplementary text or context fallback
+   """
+   context = ""
+   label, explain = "", ""
+
+   try:
+     # Check if we can fetch metadata from NCBI using the accession
+     handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
+     text = handle.read()
+     handle.close()
+
+     isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
+     if isolate_source:
+       context += isolate_source.group(0) + " "
+
+     specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
+     if specimen:
+       context += specimen.group(0) + " "
+
+     if context.strip():
+       label, explain = detect_ancient_flag(context)
+       if label != "Unknown":
+         return label, explain + " from NCBI\n(" + context + ")"
+
+     # If no useful NCBI metadata, check supplementary texts
+     if textsToExtract:
+       labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
+
+       for source in textsToExtract:
+         text_block = textsToExtract[source]
+         context = extract_relevant_paragraphs(text_block, accession, isolate=isolate)  # reduce to informative paragraph(s)
+         label, explain = detect_ancient_flag(context)
+
+         if label == "Ancient":
+           labels["ancient"][0] += 1
+           labels["ancient"][1] += f"{source}:\n{explain}\n\n"
+         elif label == "Modern":
+           labels["modern"][0] += 1
+           labels["modern"][1] += f"{source}:\n{explain}\n\n"
+         else:
+           labels["unknown"] += 1
+
+       if max(labels["modern"][0], labels["ancient"][0]) > 0:
+         if labels["modern"][0] > labels["ancient"][0]:
+           return "Modern", labels["modern"][1]
+         else:
+           return "Ancient", labels["ancient"][1]
+       else:
+         return "Unknown", "No strong keywords detected"
+     else:
+       print("No DOI or PubMed ID available for inference.")
+       return "", ""
+
+   except Exception as e:
+     print("Error:", e)
+     return "", ""
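+ # Call shape (illustrative; the accession and isolate are placeholders): textsToExtract is
+ # the mapping returned by get_paper_text.
+ #   label, explain = flag_ancient_modern("KU131308", texts, isolate="LA101")
+ #   # label is "Ancient", "Modern", "Unknown", or "" when neither NCBI nor texts are usable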
+
+ def detect_ancient_flag(context_snippet):
+   context = context_snippet.lower()
+
+   ancient_keywords = [
+     "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
+     "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
+     "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
+     "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
+   ]
+
+   modern_keywords = [
+     "modern", "hospital", "clinical", "consent", "blood", "buccal", "unrelated", "blood sample", "buccal sample", "informed consent", "donor", "healthy", "patient",
+     "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
+     "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
+     "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
+     "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
+   ]
+
+   ancient_hits = [k for k in ancient_keywords if k in context]
+   modern_hits = [k for k in modern_keywords if k in context]
+
+   if ancient_hits and not modern_hits:
+     return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
+   elif modern_hits and not ancient_hits:
+     return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
+   elif ancient_hits and modern_hits:
+     if len(ancient_hits) >= len(modern_hits):
+       return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
+     else:
+       return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
+
+   # Fallback to QA
+   answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
+   if answer.startswith("Error"):
+     return "Unknown", answer
+   if "ancient" in answer.lower():
+     return "Ancient", f"Leaning ancient based on QA: {answer}"
+   elif "modern" in answer.lower():
+     return "Modern", f"Leaning modern based on QA: {answer}"
+   else:
+     return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
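+ # Keyword behaviour sketch (illustrative snippets):
+ #   detect_ancient_flag("radiocarbon dated bone from the excavation")               # -> ("Ancient", ...)
+ #   detect_ancient_flag("buccal samples from healthy donors with informed consent")  # -> ("Modern", ...)
+ # When no keyword fires, the function defers to the QA model above.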
+
+ # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
+ def classify_sample_location(accession):
+   outputs = {}
+   keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
+   # Step 1: get pubmed id and isolate
+   pubmedID, isolate = get_info_from_accession(accession)
+   '''if not pubmedID:
+     return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
+   if not isolate:
+     isolate = "UNKNOWN_ISOLATE"
+   # Step 2: get doi
+   doi = get_doi_from_pubmed_id(pubmedID)
+   '''if not doi:
+     return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
+   # Step 3: get text
+   '''textsToExtract = { "doiLink": "paperText",
+                         "file1.pdf": "text1",
+                         "file2.doc": "text2",
+                         "file3.xlsx": excelText3 }'''
+   if doi and pubmedID:
+     textsToExtract = get_paper_text(doi, pubmedID)
+   else:
+     textsToExtract = {}
+   '''if not textsToExtract:
+     return {"error": f"No texts extracted for DOI {doi}"}'''
+   if isolate not in [None, "UNKNOWN_ISOLATE"]:
+     label, explain = flag_ancient_modern(accession, textsToExtract, isolate)
+   else:
+     label, explain = flag_ancient_modern(accession, textsToExtract)
+   # Step 4: prediction
+   outputs[accession] = {}
+   outputs[isolate] = {}
+   # 4.0: Infer from NCBI
+   location, outputNCBI = infer_location_fromNCBI(accession)
+   NCBI_result = {
+     "source": "NCBI",
+     "sample_id": accession,
+     "predicted_location": location,
+     "context_snippet": outputNCBI}
+   outputs[accession]["NCBI"] = {"NCBI": NCBI_result}
+   if textsToExtract:
+     long_text = ""
+     for key in textsToExtract:
+       text = textsToExtract[key]
+       # try accession number first
+       outputs[accession][key] = {}
+       keyword = accession
+       context = extract_context(text, keyword, window=500)
+       # 4.1: Using a HuggingFace model (question-answering)
+       location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
+       qa_result = {
+         "source": key,
+         "sample_id": keyword,
+         "predicted_location": location,
+         "context_snippet": context
+       }
+       outputs[keyword][key]["QAModel"] = qa_result
+       # 4.2: Infer from haplogroup
+       haplo_result = classify_mtDNA_sample_from_haplo(context)
+       outputs[keyword][key]["haplogroup"] = haplo_result
+       # try isolate
+       keyword = isolate
+       outputs[isolate][key] = {}
+       context = extract_context(text, keyword, window=500)
+       # 4.1.1: Using a HuggingFace model (question-answering)
+       location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
+       qa_result = {
+         "source": key,
+         "sample_id": keyword,
+         "predicted_location": location,
+         "context_snippet": context
+       }
+       outputs[keyword][key]["QAModel"] = qa_result
+       # 4.2.1: Infer from haplogroup
+       haplo_result = classify_mtDNA_sample_from_haplo(context)
+       outputs[keyword][key]["haplogroup"] = haplo_result
+       # add to the combined long text
+       long_text += text + ". \n"
+     # 4.3: UpgradeClassify
+     # try sample_id as accession number
+     sample_id = accession
+     if sample_id:
+       filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
+       locations = infer_location_for_sample(sample_id.upper(), filtered_context)
+       if locations != "No clear location found in top matches":
+         outputs[sample_id]["upgradeClassifier"] = {}
+         outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
+           "source": "From these sources combined: " + ", ".join(list(textsToExtract.keys())),
+           "sample_id": sample_id,
+           "predicted_location": ", ".join(locations),
+           "context_snippet": "First 1000 words: \n" + filtered_context[:1000]
+         }
+     # try sample_id as isolate name
+     sample_id = isolate
+     if sample_id:
+       filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
+       locations = infer_location_for_sample(sample_id.upper(), filtered_context)
+       if locations != "No clear location found in top matches":
+         outputs[sample_id]["upgradeClassifier"] = {}
+         outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
+           "source": "From these sources combined: " + ", ".join(list(textsToExtract.keys())),
+           "sample_id": sample_id,
+           "predicted_location": ", ".join(locations),
+           "context_snippet": "First 1000 words: \n" + filtered_context[:1000]
+         }
  return outputs, label, explain
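# End-to-end usage sketch (illustrative; the accession is a placeholder):
#   outputs, label, explain = classify_sample_location("KU131308")
#   # outputs is keyed by accession and isolate, then by evidence source (NCBI, each
#   # paper/supplementary file, upgradeClassifier); label/explain carry the ancient/modern flag.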