VyLala committed on
Commit 4a80798 · verified · 1 Parent(s): e15fc9b

Upload 54 files


Update 8_20_2025

NER/PDF/pdf.py CHANGED
@@ -1,193 +1,330 @@
1
- #!pip install pdfreader
2
- import pdfreader
3
- from pdfreader import PDFDocument, SimplePDFViewer
4
- #!pip install bs4
5
- from bs4 import BeautifulSoup
6
- import requests
7
- from NER import cleanText
8
- #!pip install tabula-py
9
- import tabula
10
- import fitz # PyMuPDF
11
- import os
12
-
13
- class PDF():
14
- def __init__(self, pdf, saveFolder, doi=None):
15
- self.pdf = pdf
16
- self.doi = doi
17
- self.saveFolder = saveFolder
18
-
19
- def openPDFFile(self):
20
- if "https" in self.pdf:
21
- name = self.pdf.split("/")[-1]
22
- name = self.downloadPDF(self.saveFolder)
23
- if name != "no pdfLink to download":
24
- fileToOpen = os.path.join(self.saveFolder, name)
25
- else:
26
- fileToOpen = self.pdf
27
- else:
28
- fileToOpen = self.pdf
29
- return open(fileToOpen, "rb")
30
-
31
- def downloadPDF(self, saveFolder):
32
- pdfLink = ''
33
- if ".pdf" not in self.pdf and "https" not in self.pdf:
34
- r = requests.get(self.pdf)
35
- soup = BeautifulSoup(r.content, 'html.parser')
36
- links = soup.find_all("a")
37
- for link in links:
38
- if ".pdf" in link.get("href", ""):
39
- if self.doi in link.get("href"):
40
- pdfLink = link.get("href")
41
- break
42
- else:
43
- pdfLink = self.pdf
44
-
45
- if pdfLink != '':
46
- response = requests.get(pdfLink)
47
- name = pdfLink.split("/")[-1]
48
- print("inside download PDF and name and link are: ", pdfLink, name)
49
- print("saveFolder is: ", saveFolder)
50
- with open(os.path.join(saveFolder, name), 'wb') as pdf:
51
- print("len of response content: ", len(response.content))
52
- pdf.write(response.content)
53
- print("pdf downloaded")
54
- return name
55
- else:
56
- return "no pdfLink to download"
57
-
58
- def extractText(self):
59
- try:
60
- fileToOpen = self.openPDFFile().name
61
- try:
62
- doc = fitz.open(fileToOpen)
63
- text = ""
64
- for page in doc:
65
- text += page.get_text("text") + "\n\n"
66
- doc.close()
67
-
68
- if len(text.strip()) < 100:
69
- print("Fallback to PDFReader due to weak text extraction.")
70
- text = self.extractTextWithPDFReader()
71
- return text
72
- except Exception as e:
73
- print("Failed with PyMuPDF, fallback to PDFReader:", e)
74
- return self.extractTextWithPDFReader()
75
- except:
76
- return ""
77
- def extract_text_excluding_tables(self):
78
- fileToOpen = self.openPDFFile().name
79
- text = ""
80
- try:
81
- doc = fitz.open(fileToOpen)
82
- for page in doc:
83
- blocks = page.get_text("dict")["blocks"]
84
-
85
- for block in blocks:
86
- if block["type"] == 0: # text block
87
- lines = block.get("lines", [])
88
-
89
- if not lines:
90
- continue
91
- avg_words_per_line = sum(len(l["spans"]) for l in lines) / len(lines)
92
- if avg_words_per_line > 1: # Heuristic: paragraph-like blocks
93
- for line in lines:
94
- text += " ".join(span["text"] for span in line["spans"]) + "\n"
95
- doc.close()
96
- if len(text.strip()) < 100:
97
- print("Fallback to PDFReader due to weak text extraction.")
98
- text = self.extractTextWithPDFReader()
99
- return text
100
- except Exception as e:
101
- print("Failed with PyMuPDF, fallback to PDFReader:", e)
102
- return self.extractTextWithPDFReader()
103
-
104
- def extractTextWithPDFReader(self):
105
- jsonPage = {}
106
- try:
107
- pdf = self.openPDFFile()
108
- print("open pdf file")
109
- print(pdf)
110
- doc = PDFDocument(pdf)
111
- viewer = SimplePDFViewer(pdf)
112
- all_pages = [p for p in doc.pages()]
113
- cl = cleanText.cleanGenText()
114
- pdfText = ""
115
- for page in range(1, len(all_pages)):
116
- viewer.navigate(page)
117
- viewer.render()
118
- if str(page) not in jsonPage:
119
- jsonPage[str(page)] = {}
120
- text = "".join(viewer.canvas.strings)
121
- clean, filteredWord = cl.textPreprocessing(text)
122
- jsonPage[str(page)]["normalText"] = [text]
123
- jsonPage[str(page)]["cleanText"] = [' '.join(filteredWord)]
124
- jsonPage[str(page)]["image"] = [viewer.canvas.images]
125
- jsonPage[str(page)]["form"] = [viewer.canvas.forms]
126
- jsonPage[str(page)]["content"] = [viewer.canvas.text_content]
127
- jsonPage[str(page)]["inline_image"] = [viewer.canvas.inline_images]
128
- pdf.close()
129
- except:
130
- jsonPage = {}
131
- return self.mergeTextinJson(jsonPage)
132
-
133
- def extractTable(self,pages="all",saveFile=None,outputFormat=None):
134
- '''pages (str, int, iterable of int, optional) –
135
- An optional values specifying pages to extract from. It allows str,`int`, iterable of :int. Default: 1
136
- Examples: '1-2,3', 'all', [1,2]'''
137
- df = []
138
- if "https" in self.pdf:
139
- name = self.pdf.split("/")[-1]
140
- name = self.downloadPDF(self.saveFolder)
141
- if name != "no pdfLink to download":
142
- fileToOpen = self.saveFolder + "/" + name
143
- else: fileToOpen = self.pdf
144
- else: fileToOpen = self.pdf
145
- try:
146
- df = tabula.read_pdf(fileToOpen, pages=pages)
147
- # saveFile: "/content/drive/MyDrive/CollectData/NER/PDF/tableS1.csv"
148
- # outputFormat: "csv"
149
- #tabula.convert_into(self.pdf, saveFile, output_format=outputFormat, pages=pages)
150
- except:# ValueError:
151
- df = []
152
- print("No tables found in PDF file")
153
- return df
154
-
155
- def mergeTextinJson(self, jsonPDF):
156
- try:
157
- cl = cleanText.cleanGenText()
158
- pdfText = ""
159
- if jsonPDF:
160
- for page in jsonPDF:
161
- if len(jsonPDF[page]["normalText"]) > 0:
162
- for i in range(len(jsonPDF[page]["normalText"])):
163
- text = jsonPDF[page]["normalText"][i]
164
- if len(text) > 0:
165
- text = cl.removeTabWhiteSpaceNewLine(text)
166
- text = cl.removeExtraSpaceBetweenWords(text)
167
- jsonPDF[page]["normalText"][i] = text
168
- if i - 1 > 0:
169
- if jsonPDF[page]["normalText"][i - 1][-1] != ".":
170
- pdfText += ". "
171
- pdfText += jsonPDF[page]["normalText"][i]
172
- if len(jsonPDF[page]["normalText"][i]) > 0:
173
- if jsonPDF[page]["normalText"][i][-1] != ".":
174
- pdfText += "."
175
- pdfText += "\n\n"
176
- return pdfText
177
- except:
178
- return ""
179
-
180
- def getReference(self):
181
- pass
182
-
183
- def getSupMaterial(self):
184
- pass
185
-
186
- def removeHeaders(self):
187
- pass
188
-
189
- def removeFooters(self):
190
- pass
191
-
192
- def removeReference(self):
193
- pass
1
+ #!pip install pdfreader
2
+ import pdfreader
3
+ from pdfreader import PDFDocument, SimplePDFViewer
4
+ #!pip install bs4
5
+ from bs4 import BeautifulSoup
6
+ import requests
7
+ from NER import cleanText
8
+ #!pip install tabula-py
9
+ import tabula
10
+ import fitz # PyMuPDF
11
+ import os
12
+
13
+ class PDF():
14
+ def __init__(self, pdf, saveFolder, doi=None):
15
+ self.pdf = pdf
16
+ self.doi = doi
17
+ self.saveFolder = saveFolder
18
+
19
+ def openPDFFile(self):
20
+ if "https" in self.pdf:
21
+ name = self.pdf.split("/")[-1]
22
+ name = self.downloadPDF(self.saveFolder)
23
+ if name != "no pdfLink to download":
24
+ fileToOpen = os.path.join(self.saveFolder, name)
25
+ else:
26
+ fileToOpen = self.pdf
27
+ else:
28
+ fileToOpen = self.pdf
29
+ return open(fileToOpen, "rb")
30
+
31
+ def downloadPDF(self, saveFolder):
32
+ pdfLink = ''
33
+ if ".pdf" not in self.pdf and "https" not in self.pdf:
34
+ r = requests.get(self.pdf)
35
+ soup = BeautifulSoup(r.content, 'html.parser')
36
+ links = soup.find_all("a")
37
+ for link in links:
38
+ if ".pdf" in link.get("href", ""):
39
+ if self.doi in link.get("href"):
40
+ pdfLink = link.get("href")
41
+ break
42
+ else:
43
+ pdfLink = self.pdf
44
+
45
+ if pdfLink != '':
46
+ response = requests.get(pdfLink)
47
+ name = pdfLink.split("/")[-1]
48
+ print("inside download PDF and name and link are: ", pdfLink, name)
49
+ print("saveFolder is: ", saveFolder)
50
+ with open(os.path.join(saveFolder, name), 'wb') as pdf:
51
+ print("len of response content: ", len(response.content))
52
+ pdf.write(response.content)
53
+ print("pdf downloaded")
54
+ return name
55
+ else:
56
+ return "no pdfLink to download"
57
+
58
+ def extractText(self):
59
+ try:
60
+ fileToOpen = self.openPDFFile().name
61
+ try:
62
+ doc = fitz.open(fileToOpen)
63
+ text = ""
64
+ for page in doc:
65
+ text += page.get_text("text") + "\n\n"
66
+ doc.close()
67
+
68
+ if len(text.strip()) < 100:
69
+ print("Fallback to PDFReader due to weak text extraction.")
70
+ text = self.extractTextWithPDFReader()
71
+ return text
72
+ except Exception as e:
73
+ print("Failed with PyMuPDF, fallback to PDFReader:", e)
74
+ return self.extractTextWithPDFReader()
75
+ except:
76
+ return ""
77
+ def extract_text_excluding_tables(self):
78
+ fileToOpen = self.openPDFFile().name
79
+ text = ""
80
+ try:
81
+ doc = fitz.open(fileToOpen)
82
+ for page in doc:
83
+ blocks = page.get_text("dict")["blocks"]
84
+
85
+ for block in blocks:
86
+ if block["type"] == 0: # text block
87
+ lines = block.get("lines", [])
88
+
89
+ if not lines:
90
+ continue
91
+ avg_words_per_line = sum(len(l["spans"]) for l in lines) / len(lines)
92
+ if avg_words_per_line > 1: # Heuristic: paragraph-like blocks
93
+ for line in lines:
94
+ text += " ".join(span["text"] for span in line["spans"]) + "\n"
95
+ doc.close()
96
+ if len(text.strip()) < 100:
97
+ print("Fallback to PDFReader due to weak text extraction.")
98
+ text = self.extractTextWithPDFReader()
99
+ return text
100
+ except Exception as e:
101
+ print("Failed with PyMuPDF, fallback to PDFReader:", e)
102
+ return self.extractTextWithPDFReader()
103
+
104
+ def extractTextWithPDFReader(self):
105
+ jsonPage = {}
106
+ try:
107
+ pdf = self.openPDFFile()
108
+ print("open pdf file")
109
+ print(pdf)
110
+ doc = PDFDocument(pdf)
111
+ viewer = SimplePDFViewer(pdf)
112
+ all_pages = [p for p in doc.pages()]
113
+ cl = cleanText.cleanGenText()
114
+ pdfText = ""
115
+ for page in range(1, len(all_pages)):
116
+ viewer.navigate(page)
117
+ viewer.render()
118
+ if str(page) not in jsonPage:
119
+ jsonPage[str(page)] = {}
120
+ text = "".join(viewer.canvas.strings)
121
+ clean, filteredWord = cl.textPreprocessing(text)
122
+ jsonPage[str(page)]["normalText"] = [text]
123
+ jsonPage[str(page)]["cleanText"] = [' '.join(filteredWord)]
124
+ jsonPage[str(page)]["image"] = [viewer.canvas.images]
125
+ jsonPage[str(page)]["form"] = [viewer.canvas.forms]
126
+ jsonPage[str(page)]["content"] = [viewer.canvas.text_content]
127
+ jsonPage[str(page)]["inline_image"] = [viewer.canvas.inline_images]
128
+ pdf.close()
129
+ except:
130
+ jsonPage = {}
131
+ return self.mergeTextinJson(jsonPage)
132
+
133
+ def extractTable(self,pages="all",saveFile=None,outputFormat=None):
134
+ '''pages (str, int, iterable of int, optional) –
135
+ An optional values specifying pages to extract from. It allows str,`int`, iterable of :int. Default: 1
136
+ Examples: '1-2,3', 'all', [1,2]'''
137
+ df = []
138
+ if "https" in self.pdf:
139
+ name = self.pdf.split("/")[-1]
140
+ name = self.downloadPDF(self.saveFolder)
141
+ if name != "no pdfLink to download":
142
+ fileToOpen = self.saveFolder + "/" + name
143
+ else: fileToOpen = self.pdf
144
+ else: fileToOpen = self.pdf
145
+ try:
146
+ df = tabula.read_pdf(fileToOpen, pages=pages)
147
+ # saveFile: "/content/drive/MyDrive/CollectData/NER/PDF/tableS1.csv"
148
+ # outputFormat: "csv"
149
+ #tabula.convert_into(self.pdf, saveFile, output_format=outputFormat, pages=pages)
150
+ except:# ValueError:
151
+ df = []
152
+ print("No tables found in PDF file")
153
+ return df
154
+
155
+ def mergeTextinJson(self, jsonPDF):
156
+ try:
157
+ cl = cleanText.cleanGenText()
158
+ pdfText = ""
159
+ if jsonPDF:
160
+ for page in jsonPDF:
161
+ if len(jsonPDF[page]["normalText"]) > 0:
162
+ for i in range(len(jsonPDF[page]["normalText"])):
163
+ text = jsonPDF[page]["normalText"][i]
164
+ if len(text) > 0:
165
+ text = cl.removeTabWhiteSpaceNewLine(text)
166
+ text = cl.removeExtraSpaceBetweenWords(text)
167
+ jsonPDF[page]["normalText"][i] = text
168
+ if i - 1 > 0:
169
+ if jsonPDF[page]["normalText"][i - 1][-1] != ".":
170
+ pdfText += ". "
171
+ pdfText += jsonPDF[page]["normalText"][i]
172
+ if len(jsonPDF[page]["normalText"][i]) > 0:
173
+ if jsonPDF[page]["normalText"][i][-1] != ".":
174
+ pdfText += "."
175
+ pdfText += "\n\n"
176
+ return pdfText
177
+ except:
178
+ return ""
179
+
180
+ import os
181
+ import requests
182
+ from bs4 import BeautifulSoup
183
+ import fitz # PyMuPDF
184
+ import tabula
185
+ from pdfreader import PDFDocument, SimplePDFViewer
186
+ from NER import cleanText
187
+
188
+ class PDFFast:
189
+ _cache = {} # cache for loaded documents
190
+
191
+ def __init__(self, pdf_path_or_url, saveFolder, doi=None):
192
+ self.pdf = pdf_path_or_url
193
+ self.saveFolder = saveFolder or "."
194
+ self.doi = doi
195
+ self.local_path = self._ensure_local()
196
+ self.doc = None # Lazy load in PyMuPDF
197
+
198
+ def _ensure_local(self):
199
+ """Download if URL, else return local path."""
200
+ try:
201
+ if self.pdf.startswith("http"):
202
+ name = os.path.basename(self.pdf.split("?")[0])
203
+ local_path = os.path.join(self.saveFolder, name)
204
+ if not os.path.exists(local_path):
205
+ pdf_link = self._resolve_pdf_link(self.pdf)
206
+ if not pdf_link:
207
+ raise FileNotFoundError(f"No PDF link found for {self.pdf}")
208
+ print(f"⬇ Downloading PDF: {pdf_link}")
209
+ r = requests.get(pdf_link, timeout=15)
210
+ r.raise_for_status()
211
+ with open(local_path, "wb") as f:
212
+ f.write(r.content)
213
+ return local_path
214
+ return self.pdf
215
+ except:
216
+ try:
217
+ import requests
218
+ if self.pdf.startswith("http"):
219
+ url = self.pdf
220
+ headers = {
221
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36",
222
+ "Accept": "application/pdf",
223
+ "Referer": "https://www.researchgate.net/",
224
+ }
225
+
226
+ r = requests.get(url, headers=headers)
227
+ r.raise_for_status()
228
+ local_path = os.path.join(self.saveFolder, os.path.basename(url.split("?")[0]))  # 'name' was undefined in this fallback branch
229
+
230
+ with open(local_path, "wb") as f:
231
+ f.write(r.content)
232
+ return local_path
233
+ except:
234
+ return self.pdf
235
+
236
+ def _resolve_pdf_link(self, url):
237
+ """If URL is HTML, parse for .pdf link."""
238
+ if url.lower().endswith(".pdf"):
239
+ return url
240
+ try:
241
+ r = requests.get(url, timeout=15)
242
+ soup = BeautifulSoup(r.content, "html.parser")
243
+ for link in soup.find_all("a"):
244
+ href = link.get("href", "")
245
+ if ".pdf" in href and (not self.doi or self.doi in href):
246
+ return href if href.startswith("http") else f"https://{r.url.split('/')[2]}{href}"
247
+ except Exception as e:
248
+ print(f"❌ Failed to resolve PDF link: {e}")
249
+ return None
250
+
251
+ def _load_doc(self):
252
+ """Load PyMuPDF document with caching."""
253
+ if self.local_path in PDFFast._cache:
254
+ return PDFFast._cache[self.local_path]
255
+ doc = fitz.open(self.local_path)
256
+ PDFFast._cache[self.local_path] = doc
257
+ return doc
258
+
259
+ def extract_text(self):
260
+ """Extract all text quickly with PyMuPDF."""
261
+ try:
262
+ doc = self._load_doc()
263
+ text = "\n\n".join(page.get_text(flags=1) for page in doc)
264
+ return text.strip() or self.extract_text_pdfreader()
265
+ except Exception as e:
266
+ print(f"⚠️ PyMuPDF failed: {e}")
267
+ return self.extract_text_pdfreader()
268
+
269
+ def extract_text_excluding_tables(self):
270
+ """Heuristic: skip table-like blocks."""
271
+ text_parts = []
272
+ try:
273
+ doc = self._load_doc()
274
+ for page in doc:
275
+ for block in page.get_text("dict")["blocks"]:
276
+ if block["type"] != 0: # skip non-text
277
+ continue
278
+ lines = block.get("lines", [])
279
+ avg_words = sum(len(l["spans"]) for l in lines) / max(1, len(lines))
280
+ if avg_words > 1:
281
+ for line in lines:
282
+ text_parts.append(" ".join(span["text"] for span in line["spans"]))
283
+ return "\n".join(text_parts).strip()
284
+ except Exception as e:
285
+ print(f"⚠️ Table-exclusion failed: {e}")
286
+ return self.extract_text_pdfreader()
287
+
288
+ def extract_text_pdfreader(self):
289
+ """Fallback using PDFReader."""
290
+ try:
291
+ with open(self.local_path, "rb") as f:
292
+ doc = PDFDocument(f)
293
+ viewer = SimplePDFViewer(f)
294
+ jsonPage = {}
295
+ cl = cleanText.cleanGenText()
296
+
297
+ all_pages = [p for p in doc.pages()]
298
+ for page_num in range(1, len(all_pages)):
299
+ viewer.navigate(page_num)
300
+ viewer.render()
301
+ text = "".join(viewer.canvas.strings)
302
+ clean, filtered = cl.textPreprocessing(text)
303
+ jsonPage[str(page_num)] = {
304
+ "normalText": [text],
305
+ "cleanText": [' '.join(filtered)],
306
+ "image": [viewer.canvas.images],
307
+ "form": [viewer.canvas.forms]
308
+ }
309
+ return self._merge_text(jsonPage)
310
+ except Exception as e:
311
+ print(f"❌ PDFReader failed: {e}")
312
+ return ""
313
+
314
+ def _merge_text(self, jsonPDF):
315
+ """Merge pages into one text string."""
316
+ cl = cleanText.cleanGenText()
317
+ pdfText = ""
318
+ for page in jsonPDF:
319
+ for text in jsonPDF[page]["normalText"]:
320
+ t = cl.removeExtraSpaceBetweenWords(cl.removeTabWhiteSpaceNewLine(text))
321
+ pdfText += t + "\n\n"
322
+ return pdfText.strip()
323
+
324
+ def extract_tables(self, pages="all"):
325
+ """Extract tables with Tabula."""
326
+ try:
327
+ return tabula.read_pdf(self.local_path, pages=pages)
328
+ except Exception:
329
+ print("⚠️ No tables found.")
330
+ return []
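A minimal usage sketch for the new PDFFast class added above (the URL and downloads folder are placeholders, not part of the commit):

    from NER.PDF.pdf import PDFFast

    pdf = PDFFast("https://example.org/paper.pdf", saveFolder="./downloads")
    text = pdf.extract_text()                   # PyMuPDF first, pdfreader fallback
    body = pdf.extract_text_excluding_tables()  # heuristic skip of table-like blocks
    tables = pdf.extract_tables(pages="all")    # tabula DataFrames, [] if none found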
NER/WordDoc/wordDoc.py CHANGED
@@ -175,4 +175,104 @@ class wordDoc(): # using python-docx
175
  def getReference(self):
176
  pass
177
  def getSupMaterial(self):
178
- pass
175
  def getReference(self):
176
  pass
177
  def getSupMaterial(self):
178
+ pass
179
+
180
+ import os
181
+ import requests
182
+ from spire.doc import Document
183
+ from spire.doc.common import *
184
+ from spire.xls import Workbook, FileFormat
185
+
186
+ class WordDocFast:
187
+ _cache = {} # Cache Document objects by file path/URL
188
+
189
+ def __init__(self, wordDoc, saveFolder):
190
+ self.wordDoc = wordDoc
191
+ self.saveFolder = saveFolder or "."
192
+ self.doc = self._load_document()
193
+
194
+ def _load_document(self):
195
+ # Use cache if available
196
+ if self.wordDoc in WordDocFast._cache:
197
+ return WordDocFast._cache[self.wordDoc]
198
+
199
+ local_path = self.wordDoc
200
+ if self.wordDoc.startswith("http"):
201
+ name = os.path.basename(self.wordDoc)
202
+ local_path = os.path.join(self.saveFolder, name)
203
+ if not os.path.exists(local_path):
204
+ r = requests.get(self.wordDoc, timeout=15)
205
+ r.raise_for_status()
206
+ with open(local_path, "wb") as f:
207
+ f.write(r.content)
208
+
209
+ doc = Document()
210
+ doc.LoadFromFile(local_path)
211
+ WordDocFast._cache[self.wordDoc] = doc
212
+ return doc
213
+
214
+ def extractText(self):
215
+ """Extract full text (faster than page-by-page parsing)."""
216
+ try:
217
+ return self.doc.GetText()
218
+ except:
219
+ try:
220
+ return self.extractTextBySections()
221
+ except:
222
+ print("extract word doc text failed")
223
+ return ''
224
+ def extractTextBySections(self):
225
+ """Stream text section-by-section (can be faster for large docs)."""
226
+ all_text = []
227
+ for s in range(self.doc.Sections.Count):
228
+ section = self.doc.Sections.get_Item(s)
229
+ for p in range(section.Paragraphs.Count):
230
+ text = section.Paragraphs.get_Item(p).Text.strip()
231
+ if text:
232
+ all_text.append(text)
233
+ return "\n".join(all_text)
234
+
235
+ def extractTablesAsList(self):
236
+ """Extract tables as list-of-lists (faster)."""
237
+ tables = []
238
+ for s in range(self.doc.Sections.Count):
239
+ section = self.doc.Sections.get_Item(s)
240
+ for t in range(section.Tables.Count):
241
+ table = section.Tables.get_Item(t)
242
+ table_data = []
243
+ for r in range(table.Rows.Count):
244
+ row_data = []
245
+ for c in range(table.Rows.get_Item(r).Cells.Count):
246
+ cell = table.Rows.get_Item(r).Cells.get_Item(c)
247
+ cell_text = " ".join(
248
+ cell.Paragraphs.get_Item(p).Text.strip()
249
+ for p in range(cell.Paragraphs.Count)
250
+ ).strip()
251
+ row_data.append(cell_text)
252
+ table_data.append(row_data)
253
+ tables.append(table_data)
254
+ return tables
255
+
256
+ def extractTablesAsExcel(self):
257
+ """Export tables to Excel."""
258
+ wb = Workbook()
259
+ wb.Worksheets.Clear()
260
+ for s in range(self.doc.Sections.Count):
261
+ section = self.doc.Sections.get_Item(s)
262
+ for t in range(section.Tables.Count):
263
+ table = section.Tables.get_Item(t)
264
+ ws = wb.Worksheets.Add(f"Table_{s+1}_{t+1}")
265
+ for r in range(table.Rows.Count):
266
+ row = table.Rows.get_Item(r)
267
+ for c in range(row.Cells.Count):
268
+ cell = row.Cells.get_Item(c)
269
+ cell_text = " ".join(
270
+ cell.Paragraphs.get_Item(p).Text
271
+ for p in range(cell.Paragraphs.Count)
272
+ ).strip()
273
+ ws.SetCellValue(r + 1, c + 1, cell_text)
274
+ name = os.path.basename(self.wordDoc) + ".xlsx"
275
+ out_path = os.path.join(self.saveFolder, name)
276
+ wb.SaveToFile(out_path, FileFormat.Version2016)
277
+ wb.Dispose()
278
+ return out_path
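A minimal usage sketch for WordDocFast as defined above (requires Spire.Doc; the .docx path is a placeholder):

    from NER.WordDoc.wordDoc import WordDocFast

    doc = WordDocFast("sample.docx", saveFolder="./downloads")  # placeholder path
    text = doc.extractText()                # full text via Document.GetText()
    tables = doc.extractTablesAsList()      # list of tables as row/cell strings
    xlsx_path = doc.extractTablesAsExcel()  # writes <name>.xlsx into saveFolder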
NER/html/extractHTML.py CHANGED
@@ -1,226 +1,249 @@
1
- # reference: https://www.crummy.com/software/BeautifulSoup/bs4/doc/#for-html-documents
2
- from bs4 import BeautifulSoup
3
- import requests
4
- from DefaultPackages import openFile, saveFile
5
- from NER import cleanText
6
- import pandas as pd
7
- class HTML():
8
- def __init__(self, htmlFile, htmlLink):
9
- self.htmlLink = htmlLink
10
- self.htmlFile = htmlFile
11
- # def openHTMLFile(self):
12
- # headers = {
13
- # "User-Agent": (
14
- # "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
15
- # "AppleWebKit/537.36 (KHTML, like Gecko) "
16
- # "Chrome/114.0.0.0 Safari/537.36"
17
- # ),
18
- # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
19
- # "Referer": self.htmlLink,
20
- # "Connection": "keep-alive"
21
- # }
22
-
23
- # session = requests.Session()
24
- # session.headers.update(headers)
25
-
26
- # if self.htmlLink != "None":
27
- # try:
28
- # r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
29
- # if r.status_code != 200:
30
- # print(f"❌ HTML GET failed: {r.status_code} — {self.htmlLink}")
31
- # return BeautifulSoup("", 'html.parser')
32
- # soup = BeautifulSoup(r.content, 'html.parser')
33
- # except Exception as e:
34
- # print(f"❌ Exception fetching HTML: {e}")
35
- # return BeautifulSoup("", 'html.parser')
36
- # else:
37
- # with open(self.htmlFile) as fp:
38
- # soup = BeautifulSoup(fp, 'html.parser')
39
- # return soup
40
- from lxml.etree import ParserError, XMLSyntaxError
41
-
42
- def openHTMLFile(self):
43
- not_need_domain = ['https://broadinstitute.github.io/picard/',
44
- 'https://software.broadinstitute.org/gatk/best-practices/',
45
- 'https://www.ncbi.nlm.nih.gov/genbank/',
46
- 'https://www.mitomap.org/']
47
- headers = {
48
- "User-Agent": (
49
- "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
50
- "AppleWebKit/537.36 (KHTML, like Gecko) "
51
- "Chrome/114.0.0.0 Safari/537.36"
52
- ),
53
- "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
54
- "Referer": self.htmlLink,
55
- "Connection": "keep-alive"
56
- }
57
-
58
- session = requests.Session()
59
- session.headers.update(headers)
60
- if self.htmlLink in not_need_domain:
61
- return BeautifulSoup("", 'html.parser')
62
- try:
63
- if self.htmlLink and self.htmlLink != "None":
64
- r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
65
- if r.status_code != 200 or not r.text.strip():
66
- print(f"❌ HTML GET failed ({r.status_code}) or empty page: {self.htmlLink}")
67
- return BeautifulSoup("", 'html.parser')
68
- soup = BeautifulSoup(r.content, 'html.parser')
69
- else:
70
- with open(self.htmlFile, encoding='utf-8') as fp:
71
- soup = BeautifulSoup(fp, 'html.parser')
72
- except (ParserError, XMLSyntaxError, OSError) as e:
73
- print(f"🚫 HTML parse error for {self.htmlLink}: {type(e).__name__}")
74
- return BeautifulSoup("", 'html.parser')
75
- except Exception as e:
76
- print(f"❌ General exception for {self.htmlLink}: {e}")
77
- return BeautifulSoup("", 'html.parser')
78
-
79
- return soup
80
-
81
- def getText(self):
82
- soup = self.openHTMLFile()
83
- s = soup.find_all("html")
84
- text = ""
85
- if s:
86
- for t in range(len(s)):
87
- text = s[t].get_text()
88
- cl = cleanText.cleanGenText()
89
- text = cl.removeExtraSpaceBetweenWords(text)
90
- return text
91
- def getListSection(self, scienceDirect=None):
92
- try:
93
- json = {}
94
- text = ""
95
- textJson, textHTML = "",""
96
- if scienceDirect == None:
97
- soup = self.openHTMLFile()
98
- # get list of section
99
- json = {}
100
- for h2Pos in range(len(soup.find_all('h2'))):
101
- if soup.find_all('h2')[h2Pos].text not in json:
102
- json[soup.find_all('h2')[h2Pos].text] = []
103
- if h2Pos + 1 < len(soup.find_all('h2')):
104
- content = soup.find_all('h2')[h2Pos].find_next("p")
105
- nexth2Content = soup.find_all('h2')[h2Pos+1].find_next("p")
106
- while content.text != nexth2Content.text:
107
- json[soup.find_all('h2')[h2Pos].text].append(content.text)
108
- content = content.find_next("p")
109
- else:
110
- content = soup.find_all('h2')[h2Pos].find_all_next("p",string=True)
111
- json[soup.find_all('h2')[h2Pos].text] = list(i.text for i in content)
112
- # format
113
- '''json = {'Abstract':[], 'Introduction':[], 'Methods'[],
114
- 'Results':[], 'Discussion':[], 'References':[],
115
- 'Acknowledgements':[], 'Author information':[], 'Ethics declarations':[],
116
- 'Additional information':[], 'Electronic supplementary material':[],
117
- 'Rights and permissions':[], 'About this article':[], 'Search':[], 'Navigation':[]}'''
118
- if scienceDirect!= None or len(json)==0:
119
- # Replace with your actual Elsevier API key
120
- api_key = os.environ["SCIENCE_DIRECT_API"]
121
- # ScienceDirect article DOI or PI (Example DOI)
122
- doi = self.htmlLink.split("https://doi.org/")[-1] #"10.1016/j.ajhg.2011.01.009"
123
- # Base URL for the Elsevier API
124
- base_url = "https://api.elsevier.com/content/article/doi/"
125
- # Set headers with API key
126
- headers = {
127
- "Accept": "application/json",
128
- "X-ELS-APIKey": api_key
129
- }
130
- # Make the API request
131
- response = requests.get(base_url + doi, headers=headers)
132
- # Check if the request was successful
133
- if response.status_code == 200:
134
- data = response.json()
135
- supp_data = data["full-text-retrieval-response"]#["coredata"]["link"]
136
- if "originalText" in list(supp_data.keys()):
137
- if type(supp_data["originalText"])==str:
138
- json["originalText"] = [supp_data["originalText"]]
139
- if type(supp_data["originalText"])==dict:
140
- json["originalText"] = [supp_data["originalText"][key] for key in supp_data["originalText"]]
141
- else:
142
- if type(supp_data)==dict:
143
- for key in supp_data:
144
- json[key] = [supp_data[key]]
145
-
146
- textJson = self.mergeTextInJson(json)
147
- textHTML = self.getText()
148
- if len(textHTML) > len(textJson):
149
- text = textHTML
150
- else: text = textJson
151
- return text #json
152
- except:
153
- print("failed all")
154
- return ""
155
- def getReference(self):
156
- # get reference to collect more next data
157
- ref = []
158
- json = self.getListSection()
159
- for key in json["References"]:
160
- ct = cleanText.cleanGenText(key)
161
- cleanText, filteredWord = ct.cleanText()
162
- if cleanText not in ref:
163
- ref.append(cleanText)
164
- return ref
165
- def getSupMaterial(self):
166
- # check if there is material or not
167
- json = {}
168
- soup = self.openHTMLFile()
169
- for h2Pos in range(len(soup.find_all('h2'))):
170
- if "supplementary" in soup.find_all('h2')[h2Pos].text.lower() or "material" in soup.find_all('h2')[h2Pos].text.lower() or "additional" in soup.find_all('h2')[h2Pos].text.lower() or "support" in soup.find_all('h2')[h2Pos].text.lower():
171
- #print(soup.find_all('h2')[h2Pos].find_next("a").get("href"))
172
- link, output = [],[]
173
- if soup.find_all('h2')[h2Pos].text not in json:
174
- json[soup.find_all('h2')[h2Pos].text] = []
175
- for l in soup.find_all('h2')[h2Pos].find_all_next("a",href=True):
176
- link.append(l["href"])
177
- if h2Pos + 1 < len(soup.find_all('h2')):
178
- nexth2Link = soup.find_all('h2')[h2Pos+1].find_next("a",href=True)["href"]
179
- if nexth2Link in link:
180
- link = link[:link.index(nexth2Link)]
181
- # only take links having "https" in that
182
- for i in link:
183
- if "https" in i: output.append(i)
184
- json[soup.find_all('h2')[h2Pos].text].extend(output)
185
- return json
186
- def extractTable(self):
187
- soup = self.openHTMLFile()
188
- df = []
189
- if len(soup)>0:
190
- try:
191
- df = pd.read_html(str(soup))
192
- except ValueError:
193
- df = []
194
- print("No tables found in HTML file")
195
- return df
196
- def mergeTextInJson(self,jsonHTML):
197
- cl = cleanText.cleanGenText()
198
- #cl = cleanGenText()
199
- htmlText = ""
200
- for sec in jsonHTML:
201
- # section is "\n\n"
202
- if len(jsonHTML[sec]) > 0:
203
- for i in range(len(jsonHTML[sec])):
204
- # same section is just a dot.
205
- text = jsonHTML[sec][i]
206
- if len(text)>0:
207
- #text = cl.removeTabWhiteSpaceNewLine(text)
208
- #text = cl.removeExtraSpaceBetweenWords(text)
209
- text, filteredWord = cl.textPreprocessing(text, keepPeriod=True)
210
- jsonHTML[sec][i] = text
211
- if i-1 >= 0:
212
- if len(jsonHTML[sec][i-1])>0:
213
- if jsonHTML[sec][i-1][-1] != ".":
214
- htmlText += ". "
215
- htmlText += jsonHTML[sec][i]
216
- if len(jsonHTML[sec][i]) > 0:
217
- if jsonHTML[sec][i][-1]!=".":
218
- htmlText += "."
219
- htmlText += "\n\n"
220
- return htmlText
221
- def removeHeaders(self):
222
- pass
223
- def removeFooters(self):
224
- pass
225
- def removeReferences(self):
226
- pass
1
+ # reference: https://www.crummy.com/software/BeautifulSoup/bs4/doc/#for-html-documents
2
+ from bs4 import BeautifulSoup
3
+ import requests
4
+ from DefaultPackages import openFile, saveFile
5
+ from NER import cleanText
6
+ import os, pandas as pd  # os is needed for os.environ in getListSection
7
+ class HTML():
8
+ def __init__(self, htmlFile, htmlLink):
9
+ self.htmlLink = htmlLink
10
+ self.htmlFile = htmlFile
11
+ # def openHTMLFile(self):
12
+ # headers = {
13
+ # "User-Agent": (
14
+ # "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
15
+ # "AppleWebKit/537.36 (KHTML, like Gecko) "
16
+ # "Chrome/114.0.0.0 Safari/537.36"
17
+ # ),
18
+ # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
19
+ # "Referer": self.htmlLink,
20
+ # "Connection": "keep-alive"
21
+ # }
22
+
23
+ # session = requests.Session()
24
+ # session.headers.update(headers)
25
+
26
+ # if self.htmlLink != "None":
27
+ # try:
28
+ # r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
29
+ # if r.status_code != 200:
30
+ # print(f"❌ HTML GET failed: {r.status_code} — {self.htmlLink}")
31
+ # return BeautifulSoup("", 'html.parser')
32
+ # soup = BeautifulSoup(r.content, 'html.parser')
33
+ # except Exception as e:
34
+ # print(f"❌ Exception fetching HTML: {e}")
35
+ # return BeautifulSoup("", 'html.parser')
36
+ # else:
37
+ # with open(self.htmlFile) as fp:
38
+ # soup = BeautifulSoup(fp, 'html.parser')
39
+ # return soup
40
+ from lxml.etree import ParserError, XMLSyntaxError
41
+
42
+ def openHTMLFile(self):
43
+ not_need_domain = ['https://broadinstitute.github.io/picard/',
44
+ 'https://software.broadinstitute.org/gatk/best-practices/',
45
+ 'https://www.ncbi.nlm.nih.gov/genbank/',
46
+ 'https://www.mitomap.org/']
47
+ headers = {
48
+ "User-Agent": (
49
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
50
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
51
+ "Chrome/114.0.0.0 Safari/537.36"
52
+ ),
53
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
54
+ "Referer": self.htmlLink,
55
+ "Connection": "keep-alive"
56
+ }
57
+
58
+ session = requests.Session()
59
+ session.headers.update(headers)
60
+ if self.htmlLink in not_need_domain:
61
+ return BeautifulSoup("", 'html.parser')
62
+ try:
63
+ if self.htmlLink and self.htmlLink != "None":
64
+ r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
65
+ if r.status_code != 200 or not r.text.strip():
66
+ print(f"❌ HTML GET failed ({r.status_code}) or empty page: {self.htmlLink}")
67
+ return BeautifulSoup("", 'html.parser')
68
+ soup = BeautifulSoup(r.content, 'html.parser')
69
+ else:
70
+ with open(self.htmlFile, encoding='utf-8') as fp:
71
+ soup = BeautifulSoup(fp, 'html.parser')
72
+ except (ParserError, XMLSyntaxError, OSError) as e:
73
+ print(f"🚫 HTML parse error for {self.htmlLink}: {type(e).__name__}")
74
+ return BeautifulSoup("", 'html.parser')
75
+ except Exception as e:
76
+ print(f"❌ General exception for {self.htmlLink}: {e}")
77
+ return BeautifulSoup("", 'html.parser')
78
+
79
+ return soup
80
+
81
+ def getText(self):
82
+ try:
83
+ soup = self.openHTMLFile()
84
+ s = soup.find_all("html")
85
+ text = ""
86
+ if s:
87
+ for t in range(len(s)):
88
+ text = s[t].get_text()
89
+ cl = cleanText.cleanGenText()
90
+ text = cl.removeExtraSpaceBetweenWords(text)
91
+ return text
92
+ except:
93
+ print("failed get text from html")
94
+ return ""
95
+ def getListSection(self, scienceDirect=None):
96
+ try:
97
+ json = {}
98
+ text = ""
99
+ textJson, textHTML = "",""
100
+ if scienceDirect == None:
101
+ # soup = self.openHTMLFile()
102
+ # # get list of section
103
+ # json = {}
104
+ # for h2Pos in range(len(soup.find_all('h2'))):
105
+ # if soup.find_all('h2')[h2Pos].text not in json:
106
+ # json[soup.find_all('h2')[h2Pos].text] = []
107
+ # if h2Pos + 1 < len(soup.find_all('h2')):
108
+ # content = soup.find_all('h2')[h2Pos].find_next("p")
109
+ # nexth2Content = soup.find_all('h2')[h2Pos+1].find_next("p")
110
+ # while content.text != nexth2Content.text:
111
+ # json[soup.find_all('h2')[h2Pos].text].append(content.text)
112
+ # content = content.find_next("p")
113
+ # else:
114
+ # content = soup.find_all('h2')[h2Pos].find_all_next("p",string=True)
115
+ # json[soup.find_all('h2')[h2Pos].text] = list(i.text for i in content)
116
+
117
+ soup = self.openHTMLFile()
118
+ h2_tags = soup.find_all('h2')
119
+ json = {}
120
+
121
+ for idx, h2 in enumerate(h2_tags):
122
+ section_title = h2.get_text(strip=True)
123
+ json.setdefault(section_title, [])
124
+
125
+ # Get paragraphs until next H2
126
+ next_h2 = h2_tags[idx+1] if idx+1 < len(h2_tags) else None
127
+ for p in h2.find_all_next("p"):
128
+ if next_h2 and p == next_h2:
129
+ break
130
+ json[section_title].append(p.get_text(strip=True))
131
+ # format
132
+ '''json = {'Abstract':[], 'Introduction':[], 'Methods'[],
133
+ 'Results':[], 'Discussion':[], 'References':[],
134
+ 'Acknowledgements':[], 'Author information':[], 'Ethics declarations':[],
135
+ 'Additional information':[], 'Electronic supplementary material':[],
136
+ 'Rights and permissions':[], 'About this article':[], 'Search':[], 'Navigation':[]}'''
137
+ if scienceDirect!= None or len(json)==0:
138
+ # Replace with your actual Elsevier API key
139
+ api_key = os.environ["SCIENCE_DIRECT_API"]
140
+ # ScienceDirect article DOI or PI (Example DOI)
141
+ doi = self.htmlLink.split("https://doi.org/")[-1] #"10.1016/j.ajhg.2011.01.009"
142
+ # Base URL for the Elsevier API
143
+ base_url = "https://api.elsevier.com/content/article/doi/"
144
+ # Set headers with API key
145
+ headers = {
146
+ "Accept": "application/json",
147
+ "X-ELS-APIKey": api_key
148
+ }
149
+ # Make the API request
150
+ response = requests.get(base_url + doi, headers=headers)
151
+ # Check if the request was successful
152
+ if response.status_code == 200:
153
+ data = response.json()
154
+ supp_data = data["full-text-retrieval-response"]#["coredata"]["link"]
155
+ # if "originalText" in list(supp_data.keys()):
156
+ # if type(supp_data["originalText"])==str:
157
+ # json["originalText"] = [supp_data["originalText"]]
158
+ # if type(supp_data["originalText"])==dict:
159
+ # json["originalText"] = [supp_data["originalText"][key] for key in supp_data["originalText"]]
160
+ # else:
161
+ # if type(supp_data)==dict:
162
+ # for key in supp_data:
163
+ # json[key] = [supp_data[key]]
164
+ if type(data)==dict:
165
+ json["fullText"] = data
166
+ textJson = self.mergeTextInJson(json)
167
+ textHTML = self.getText()
168
+ if len(textHTML) > len(textJson):
169
+ text = textHTML
170
+ else: text = textJson
171
+ return text #json
172
+ except:
173
+ print("failed all")
174
+ return ""
175
+ def getReference(self):
176
+ # get reference to collect more next data
177
+ ref = []
178
+ json = self.getListSection()
179
+ for key in json["References"]:
180
+ ct = cleanText.cleanGenText(key)
181
+ cleanText, filteredWord = ct.cleanText()
182
+ if cleanText not in ref:
183
+ ref.append(cleanText)
184
+ return ref
185
+ def getSupMaterial(self):
186
+ # check if there is material or not
187
+ json = {}
188
+ soup = self.openHTMLFile()
189
+ for h2Pos in range(len(soup.find_all('h2'))):
190
+ if "supplementary" in soup.find_all('h2')[h2Pos].text.lower() or "material" in soup.find_all('h2')[h2Pos].text.lower() or "additional" in soup.find_all('h2')[h2Pos].text.lower() or "support" in soup.find_all('h2')[h2Pos].text.lower():
191
+ #print(soup.find_all('h2')[h2Pos].find_next("a").get("href"))
192
+ link, output = [],[]
193
+ if soup.find_all('h2')[h2Pos].text not in json:
194
+ json[soup.find_all('h2')[h2Pos].text] = []
195
+ for l in soup.find_all('h2')[h2Pos].find_all_next("a",href=True):
196
+ link.append(l["href"])
197
+ if h2Pos + 1 < len(soup.find_all('h2')):
198
+ nexth2Link = soup.find_all('h2')[h2Pos+1].find_next("a",href=True)["href"]
199
+ if nexth2Link in link:
200
+ link = link[:link.index(nexth2Link)]
201
+ # only take links having "https" in that
202
+ for i in link:
203
+ if "https" in i: output.append(i)
204
+ json[soup.find_all('h2')[h2Pos].text].extend(output)
205
+ return json
206
+ def extractTable(self):
207
+ soup = self.openHTMLFile()
208
+ df = []
209
+ if len(soup)>0:
210
+ try:
211
+ df = pd.read_html(str(soup))
212
+ except ValueError:
213
+ df = []
214
+ print("No tables found in HTML file")
215
+ return df
216
+ def mergeTextInJson(self,jsonHTML):
217
+ try:
218
+ #cl = cleanText.cleanGenText()
219
+ htmlText = ""
220
+ if jsonHTML:
221
+ # try:
222
+ # for sec, entries in jsonHTML.items():
223
+ # for i, entry in enumerate(entries):
224
+ # # Only process if it's actually text
225
+ # if isinstance(entry, str):
226
+ # if entry.strip():
227
+ # entry, filteredWord = cl.textPreprocessing(entry, keepPeriod=True)
228
+ # else:
229
+ # # Skip or convert dicts/lists to string if needed
230
+ # entry = str(entry)
231
+
232
+ # jsonHTML[sec][i] = entry
233
+
234
+ # # Add spacing between sentences
235
+ # if i - 1 >= 0 and jsonHTML[sec][i - 1] and jsonHTML[sec][i - 1][-1] != ".":
236
+ # htmlText += ". "
237
+ # htmlText += entry
238
+
239
+ # # Add final period if needed
240
+ # if entries and isinstance(entries[-1], str) and entries[-1] and entries[-1][-1] != ".":
241
+ # htmlText += "."
242
+ # htmlText += "\n\n"
243
+ # except:
244
+ htmlText += str(jsonHTML)
245
+ return htmlText
246
+ except:
247
+ print("failed merge text in json")
248
+ return ""
249
+
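A minimal usage sketch for the updated HTML helper (the article URL is a placeholder; the ScienceDirect fallback only runs when the SCIENCE_DIRECT_API environment variable is set):

    from NER.html.extractHTML import HTML

    page = HTML("", "https://example.org/article.html")  # htmlFile is unused when a link is given
    text = page.getListSection()  # section-aware text, falls back to plain getText()
    tables = page.extractTable()  # pandas.read_html DataFrames, [] if none found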
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
data_preprocess.py CHANGED
@@ -1,747 +1,778 @@
1
- import re
2
- import os
3
- #import streamlit as st
4
- import subprocess
5
- import re
6
- from Bio import Entrez
7
- from docx import Document
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- #from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- import pandas as pd
19
- import model
20
- import pipeline
21
- import tempfile
22
- import nltk
23
- nltk.download('punkt_tab')
24
- def download_excel_file(url, save_path="temp.xlsx"):
25
- if "view.officeapps.live.com" in url:
26
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
- real_url = urllib.parse.unquote(parsed_url["src"][0])
28
- response = requests.get(real_url)
29
- with open(save_path, "wb") as f:
30
- f.write(response.content)
31
- return save_path
32
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
- response = requests.get(url)
34
- response.raise_for_status() # Raises error if download fails
35
- with open(save_path, "wb") as f:
36
- f.write(response.content)
37
- print(len(response.content))
38
- return save_path
39
- else:
40
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
41
- return url
42
- def extract_text(link,saveFolder):
43
- try:
44
- text = ""
45
- name = link.split("/")[-1]
46
- print("name: ", name)
47
- #file_path = Path(saveFolder) / name
48
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
49
- print("this is local temp path: ", local_temp_path)
50
- if os.path.exists(local_temp_path):
51
- input_to_class = local_temp_path
52
- print("exist")
53
- else:
54
- #input_to_class = link # Let the class handle downloading
55
- # 1. Check if file exists in shared Google Drive folder
56
- file_id = pipeline.find_drive_file(name, saveFolder)
57
- if file_id:
58
- print("📥 Downloading from Google Drive...")
59
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
60
- else:
61
- print("🌐 Downloading from web link...")
62
- response = requests.get(link)
63
- with open(local_temp_path, 'wb') as f:
64
- f.write(response.content)
65
- print("✅ Saved locally.")
66
-
67
- # 2. Upload to Drive so it's available for later
68
- pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
69
-
70
- input_to_class = local_temp_path
71
- print(input_to_class)
72
- # pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
73
- # pdf
74
- if link.endswith(".pdf"):
75
- # if file_path.is_file():
76
- # link = saveFolder + "/" + name
77
- # print("File exists.")
78
- #p = pdf.PDF(local_temp_path, saveFolder)
79
- print("inside pdf and input to class: ", input_to_class)
80
- print("save folder in extract text: ", saveFolder)
81
- p = pdf.PDF(input_to_class, saveFolder)
82
- #p = pdf.PDF(link,saveFolder)
83
- #text = p.extractTextWithPDFReader()
84
- text = p.extractText()
85
- print("text from pdf:")
86
- print(text)
87
- #text_exclude_table = p.extract_text_excluding_tables()
88
- # worddoc
89
- elif link.endswith(".doc") or link.endswith(".docx"):
90
- #d = wordDoc.wordDoc(local_temp_path,saveFolder)
91
- d = wordDoc.wordDoc(input_to_class,saveFolder)
92
- text = d.extractTextByPage()
93
- # html
94
- else:
95
- if link.split(".")[-1].lower() not in "xlsx":
96
- if "http" in link or "html" in link:
97
- print("html link: ", link)
98
- html = extractHTML.HTML("",link)
99
- text = html.getListSection() # the text already clean
100
- print("text html: ")
101
- print(text)
102
- # Cleanup: delete the local temp file
103
- if name:
104
- if os.path.exists(local_temp_path):
105
- os.remove(local_temp_path)
106
- print(f"🧹 Deleted local temp file: {local_temp_path}")
107
- print("done extract text")
108
- except:
109
- text = ""
110
- return text
111
-
112
- def extract_table(link,saveFolder):
113
- try:
114
- table = []
115
- name = link.split("/")[-1]
116
- #file_path = Path(saveFolder) / name
117
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
118
- if os.path.exists(local_temp_path):
119
- input_to_class = local_temp_path
120
- print("exist")
121
- else:
122
- #input_to_class = link # Let the class handle downloading
123
- # 1. Check if file exists in shared Google Drive folder
124
- file_id = pipeline.find_drive_file(name, saveFolder)
125
- if file_id:
126
- print("📥 Downloading from Google Drive...")
127
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
128
- else:
129
- print("🌐 Downloading from web link...")
130
- response = requests.get(link)
131
- with open(local_temp_path, 'wb') as f:
132
- f.write(response.content)
133
- print("✅ Saved locally.")
134
-
135
- # 2. Upload to Drive so it's available for later
136
- pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
137
-
138
- input_to_class = local_temp_path
139
- print(input_to_class)
140
- #pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
141
- # pdf
142
- if link.endswith(".pdf"):
143
- # if file_path.is_file():
144
- # link = saveFolder + "/" + name
145
- # print("File exists.")
146
- #p = pdf.PDF(local_temp_path,saveFolder)
147
- p = pdf.PDF(input_to_class,saveFolder)
148
- table = p.extractTable()
149
- # worddoc
150
- elif link.endswith(".doc") or link.endswith(".docx"):
151
- #d = wordDoc.wordDoc(local_temp_path,saveFolder)
152
- d = wordDoc.wordDoc(input_to_class,saveFolder)
153
- table = d.extractTableAsList()
154
- # excel
155
- elif link.split(".")[-1].lower() in "xlsx":
156
- # download excel file if it not downloaded yet
157
- savePath = saveFolder +"/"+ link.split("/")[-1]
158
- excelPath = download_excel_file(link, savePath)
159
- try:
160
- #xls = pd.ExcelFile(excelPath)
161
- xls = pd.ExcelFile(local_temp_path)
162
- table_list = []
163
- for sheet_name in xls.sheet_names:
164
- df = pd.read_excel(xls, sheet_name=sheet_name)
165
- cleaned_table = df.fillna("").astype(str).values.tolist()
166
- table_list.append(cleaned_table)
167
- table = table_list
168
- except Exception as e:
169
- print("❌ Failed to extract tables from Excel:", e)
170
- # html
171
- elif "http" in link or "html" in link:
172
- html = extractHTML.HTML("",link)
173
- table = html.extractTable() # table is a list
174
- table = clean_tables_format(table)
175
- # Cleanup: delete the local temp file
176
- if os.path.exists(local_temp_path):
177
- os.remove(local_temp_path)
178
- print(f"🧹 Deleted local temp file: {local_temp_path}")
179
- except:
180
- table = []
181
- return table
182
-
183
- def clean_tables_format(tables):
184
- """
185
- Ensures all tables are in consistent format: List[List[List[str]]]
186
- Cleans by:
187
- - Removing empty strings and rows
188
- - Converting all cells to strings
189
- - Handling DataFrames and list-of-lists
190
- """
191
- cleaned = []
192
- if tables:
193
- for table in tables:
194
- standardized = []
195
-
196
- # Case 1: Pandas DataFrame
197
- if isinstance(table, pd.DataFrame):
198
- table = table.fillna("").astype(str).values.tolist()
199
-
200
- # Case 2: List of Lists
201
- if isinstance(table, list) and all(isinstance(row, list) for row in table):
202
- for row in table:
203
- filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
204
- if filtered_row:
205
- standardized.append(filtered_row)
206
-
207
- if standardized:
208
- cleaned.append(standardized)
209
-
210
- return cleaned
211
-
212
- import json
213
- def normalize_text_for_comparison(s: str) -> str:
214
- """
215
- Normalizes text for robust comparison by:
216
- 1. Converting to lowercase.
217
- 2. Replacing all types of newlines with a single consistent newline (\n).
218
- 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
219
- 4. Stripping leading/trailing whitespace from the entire string.
220
- """
221
- s = s.lower()
222
- s = s.replace('\r\n', '\n') # Handle Windows newlines
223
- s = s.replace('\r', '\n') # Handle Mac classic newlines
224
-
225
- # Replace sequences of whitespace (including multiple newlines) with a single space
226
- # This might be too aggressive if you need to preserve paragraph breaks,
227
- # but good for exact word-sequence matching.
228
- s = re.sub(r'\s+', ' ', s)
229
-
230
- return s.strip()
231
- def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
232
- """
233
- Merge cleaned text and table into one string for LLM input.
234
- - Avoids duplicating tables already in text
235
- - Extracts only relevant rows from large tables
236
- - Skips or saves oversized tables
237
- """
238
- import importlib
239
- json = importlib.import_module("json")
240
-
241
- def estimate_tokens(text_str):
242
- try:
243
- enc = tiktoken.get_encoding(tokenizer)
244
- return len(enc.encode(text_str))
245
- except:
246
- return len(text_str) // 4 # Fallback estimate
247
-
248
- def is_table_relevant(table, keywords, accession_id=None):
249
- flat = " ".join(" ".join(row).lower() for row in table)
250
- if accession_id and accession_id.lower() in flat:
251
- return True
252
- return any(kw.lower() in flat for kw in keywords)
253
- preview, preview1 = "",""
254
- llm_input = "## Document Text\n" + text.strip() + "\n"
255
- clean_text = normalize_text_for_comparison(text)
256
-
257
- if tables:
258
- for idx, table in enumerate(tables):
259
- keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
260
- if accession_id: keywords += [accession_id.lower()]
261
- if isolate: keywords += [isolate.lower()]
262
- if is_table_relevant(table, keywords, accession_id):
263
- if len(table) > 0:
264
- for tab in table:
265
- preview = " ".join(tab) if tab else ""
266
- preview1 = "\n".join(tab) if tab else ""
267
- clean_preview = normalize_text_for_comparison(preview)
268
- clean_preview1 = normalize_text_for_comparison(preview1)
269
- if clean_preview not in clean_text:
270
- if clean_preview1 not in clean_text:
271
- table_str = json.dumps([tab], indent=2)
272
- llm_input += f"## Table {idx+1}\n{table_str}\n"
273
- return llm_input.strip()
274
-
275
- def preprocess_document(link, saveFolder, accession=None, isolate=None):
276
- try:
277
- text = extract_text(link, saveFolder)
278
- print("text and link")
279
- print(link)
280
- print(text)
281
- except: text = ""
282
- try:
283
- tables = extract_table(link, saveFolder)
284
- except: tables = []
285
- if accession: accession = accession
286
- if isolate: isolate = isolate
287
- try:
288
- final_input = merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
289
- except: final_input = ""
290
- return text, tables, final_input
291
-
292
- def extract_sentences(text):
293
- sentences = re.split(r'(?<=[.!?])\s+', text)
294
- return [s.strip() for s in sentences if s.strip()]
295
-
296
- def is_irrelevant_number_sequence(text):
297
- if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
298
- return False
299
- word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
300
- number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
301
- total_tokens = len(re.findall(r'\S+', text))
302
- if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
303
- return True
304
- elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
305
- return True
306
- return False
307
-
308
- def remove_isolated_single_digits(sentence):
309
- tokens = sentence.split()
310
- filtered_tokens = []
311
- for token in tokens:
312
- if token == '0' or token == '1':
313
- pass
314
- else:
315
- filtered_tokens.append(token)
316
- return ' '.join(filtered_tokens).strip()
317
-
318
- def get_contextual_sentences_BFS(text_content, keyword, depth=2):
319
- def extract_codes(sentence):
320
- # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
321
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
322
- sentences = extract_sentences(text_content)
323
- relevant_sentences = set()
324
- initial_keywords = set()
325
-
326
- # Define a regex to capture codes like A1YU101 or KM1
327
- # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
328
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
329
-
330
- # Attempt to parse the keyword into its prefix and numerical part using re.search
331
- keyword_match = code_pattern.search(keyword)
332
-
333
- keyword_prefix = None
334
- keyword_num = None
335
-
336
- if keyword_match:
337
- keyword_prefix = keyword_match.group(1).lower()
338
- keyword_num = int(keyword_match.group(2))
339
-
340
- for sentence in sentences:
341
- sentence_added = False
342
-
343
- # 1. Check for exact match of the keyword
344
- if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
345
- relevant_sentences.add(sentence.strip())
346
- initial_keywords.add(keyword.lower())
347
- sentence_added = True
348
-
349
- # 2. Check for range patterns (e.g., A1YU101-A1YU137)
350
- # The range pattern should be broad enough to capture the full code string within the range.
351
- range_matches = re.finditer(r'([A-Z0-9]+-\d+)', sentence, re.IGNORECASE) # More specific range pattern if needed, or rely on full code pattern below
352
- range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
353
-
354
- for r_match in range_matches:
355
- start_code_str = r_match.group(1)
356
- end_code_str = r_match.group(2)
357
-
358
- # CRITICAL FIX: Use code_pattern.search for start_match and end_match
359
- start_match = code_pattern.search(start_code_str)
360
- end_match = code_pattern.search(end_code_str)
361
-
362
- if keyword_prefix and keyword_num is not None and start_match and end_match:
363
- start_prefix = start_match.group(1).lower()
364
- end_prefix = end_match.group(1).lower()
365
- start_num = int(start_match.group(2))
366
- end_num = int(end_match.group(2))
367
-
368
- # Check if the keyword's prefix matches and its number is within the range
369
- if keyword_prefix == start_prefix and \
370
- keyword_prefix == end_prefix and \
371
- start_num <= keyword_num <= end_num:
372
- relevant_sentences.add(sentence.strip())
373
- initial_keywords.add(start_code_str.lower())
374
- initial_keywords.add(end_code_str.lower())
375
- sentence_added = True
376
- break # Only need to find one matching range per sentence
377
-
378
- # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
379
- # to initial_keywords to ensure graph traversal from related terms.
380
- if sentence_added:
381
- for word in extract_codes(sentence):
382
- initial_keywords.add(word.lower())
383
-
384
-
385
- # Build word_to_sentences mapping for all sentences
386
- word_to_sentences = {}
387
- for sent in sentences:
388
- codes_in_sent = set(extract_codes(sent))
389
- for code in codes_in_sent:
390
- word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
391
-
392
-
393
- # Build the graph
394
- graph = {}
395
- for sent in sentences:
396
- codes = set(extract_codes(sent))
397
- for word1 in codes:
398
- word1_lower = word1.lower()
399
- graph.setdefault(word1_lower, set())
400
- for word2 in codes:
401
- word2_lower = word2.lower()
402
- if word1_lower != word2_lower:
403
- graph[word1_lower].add(word2_lower)
404
-
405
-
406
- # Perform BFS/graph traversal
407
- queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
408
- visited_words = set(initial_keywords)
409
-
410
- while queue:
411
- current_word, level = queue.pop(0)
412
- if level >= depth:
413
- continue
414
-
415
- relevant_sentences.update(word_to_sentences.get(current_word, []))
416
-
417
- for neighbor in graph.get(current_word, []):
418
- if neighbor not in visited_words:
419
- visited_words.add(neighbor)
420
- queue.append((neighbor, level + 1))
421
-
422
- final_sentences = set()
423
- for sentence in relevant_sentences:
424
- if not is_irrelevant_number_sequence(sentence):
425
- processed_sentence = remove_isolated_single_digits(sentence)
426
- if processed_sentence:
427
- final_sentences.add(processed_sentence)
428
-
429
- return "\n".join(sorted(list(final_sentences)))
430
-
431
-
432
-
433
- def get_contextual_sentences_DFS(text_content, keyword, depth=2):
434
- sentences = extract_sentences(text_content)
435
-
436
- # Build word-to-sentences mapping
437
- word_to_sentences = {}
438
- for sent in sentences:
439
- words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
440
- for word in words_in_sent:
441
- word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
442
-
443
- # Function to extract codes in a sentence
444
- def extract_codes(sentence):
445
- # Only codes like 'KSK1', 'MG272794', not pure numbers
446
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
447
-
448
- # DFS with priority based on distance to keyword and early stop if country found
449
- def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
450
- country = "unknown"
451
- if current_depth > max_depth:
452
- return country, False
453
-
454
- if current_word not in word_to_sentences:
455
- return country, False
456
-
457
- for sentence in word_to_sentences[current_word]:
458
- if sentence == parent_sentence:
459
- continue # avoid reusing the same sentence
460
-
461
- collected_sentences.add(sentence)
462
-
463
- #print("current_word:", current_word)
464
- small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
465
- #print(small_sen)
466
- country = model.get_country_from_text(small_sen)
467
- #print("small context country:", country)
468
- if country.lower() != "unknown":
469
- return country, True
470
- else:
471
- country = model.get_country_from_text(sentence)
472
- #print("full sentence country:", country)
473
- if country.lower() != "unknown":
474
- return country, True
475
-
476
- codes_in_sentence = extract_codes(sentence)
477
- idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
478
- if idx is None:
479
- continue
480
-
481
- sorted_children = sorted(
482
- [code for code in codes_in_sentence if code.lower() not in visited_words],
483
- key=lambda x: (abs(codes_in_sentence.index(x) - idx),
484
- 0 if codes_in_sentence.index(x) > idx else 1)
485
- )
486
-
487
- #print("sorted_children:", sorted_children)
488
- for child in sorted_children:
489
- child_lower = child.lower()
490
- if child_lower not in visited_words:
491
- visited_words.add(child_lower)
492
- country, should_stop = dfs_traverse(
493
- child_lower, current_depth + 1, max_depth,
494
- visited_words, collected_sentences, parent_sentence=sentence
495
- )
496
- if should_stop:
497
- return country, True
498
-
499
- return country, False
500
-
501
- # Begin DFS
502
- collected_sentences = set()
503
- visited_words = set([keyword.lower()])
504
- country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
505
-
506
- # Filter irrelevant sentences
507
- final_sentences = set()
508
- for sentence in collected_sentences:
509
- if not is_irrelevant_number_sequence(sentence):
510
- processed = remove_isolated_single_digits(sentence)
511
- if processed:
512
- final_sentences.add(processed)
513
- if not final_sentences:
514
- return country, text_content
515
- return country, "\n".join(sorted(list(final_sentences)))
516
-
517
- # Helper function for normalizing text for overlap comparison
518
- def normalize_for_overlap(s: str) -> str:
519
- s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
520
- s = re.sub(r'\s+', ' ', s).strip()
521
- return s
522
-
523
- def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
524
- if not text1: return text2
525
- if not text2: return text1
526
-
527
- # Case 1: text2 is fully contained in text1 or vice-versa
528
- if text2 in text1:
529
- return text1
530
- if text1 in text2:
531
- return text2
532
-
533
- # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
534
- # This is what your function was primarily designed for.
535
- # It looks for the overlap at the "junction" of text1 and text2.
536
-
537
- max_junction_overlap = 0
538
- for i in range(min(len(text1), len(text2)), 0, -1):
539
- suffix1 = text1[-i:]
540
- prefix2 = text2[:i]
541
- # Prioritize exact match, then normalized match
542
- if suffix1 == prefix2:
543
- max_junction_overlap = i
544
- break
545
- elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
546
- max_junction_overlap = i
547
- break # Take the first (longest) normalized match
548
-
549
- if max_junction_overlap > 0:
550
- merged_text = text1 + text2[max_junction_overlap:]
551
- return re.sub(r'\s+', ' ', merged_text).strip()
552
-
553
- # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
554
- # This addresses your specific test case where the overlap is at the very beginning of both strings.
555
- # This is often used when trying to deduplicate content that shares a common start.
556
-
557
- longest_common_prefix_len = 0
558
- min_len = min(len(text1), len(text2))
559
- for i in range(min_len):
560
- if text1[i] == text2[i]:
561
- longest_common_prefix_len = i + 1
562
- else:
563
- break
564
-
565
- # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
566
- # AND the remaining parts are distinct, then apply this merge.
567
- # This is a heuristic and might need fine-tuning.
568
- if longest_common_prefix_len > 0 and \
569
- text1[longest_common_prefix_len:].strip() and \
570
- text2[longest_common_prefix_len:].strip():
571
-
572
- # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
573
- # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
574
- # common prefix is "Hi, I am Vy."
575
- # Remaining text1: " Nice to meet you."
576
- # Remaining text2: " Goodbye Vy."
577
- # So we merge common_prefix + remaining_text1 + remaining_text2
578
-
579
- common_prefix_str = text1[:longest_common_prefix_len]
580
- remainder_text1 = text1[longest_common_prefix_len:]
581
- remainder_text2 = text2[longest_common_prefix_len:]
582
-
583
- merged_text = common_prefix_str + remainder_text1 + remainder_text2
584
- return re.sub(r'\s+', ' ', merged_text).strip()
585
-
586
-
587
- # If neither specific overlap type is found, just concatenate
588
- merged_text = text1 + text2
589
- return re.sub(r'\s+', ' ', merged_text).strip()
590
-
591
- from docx import Document
592
- from pipeline import upload_file_to_drive
593
- # def save_text_to_docx(text_content: str, file_path: str):
594
- # """
595
- # Saves a given text string into a .docx file.
596
-
597
- # Args:
598
- # text_content (str): The text string to save.
599
- # file_path (str): The full path including the filename where the .docx file will be saved.
600
- # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
601
- # """
602
- # try:
603
- # document = Document()
604
-
605
- # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
606
- # for paragraph_text in text_content.split('\n'):
607
- # document.add_paragraph(paragraph_text)
608
-
609
- # document.save(file_path)
610
- # print(f"Text successfully saved to '{file_path}'")
611
- # except Exception as e:
612
- # print(f"Error saving text to docx file: {e}")
613
- # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
614
- # """
615
- # Saves a given text string into a .docx file locally, then uploads to Google Drive.
616
-
617
- # Args:
618
- # text_content (str): The text string to save.
619
- # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
620
- # drive_folder_id (str): Google Drive folder ID where to upload the file.
621
- # """
622
- # try:
623
- # # Save to temporary local path first
624
- # print("file name: ", filename)
625
- # print("length text content: ", len(text_content))
626
- # local_path = os.path.join(tempfile.gettempdir(), filename)
627
- # document = Document()
628
- # for paragraph_text in text_content.split('\n'):
629
- # document.add_paragraph(paragraph_text)
630
- # document.save(local_path)
631
- # print(f"✅ Text saved locally to: {local_path}")
632
-
633
- # # ✅ Upload to Drive
634
- # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
635
- # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
636
-
637
- # except Exception as e:
638
- # print(f"❌ Error saving or uploading DOCX: {e}")
639
- def save_text_to_docx(text_content: str, full_local_path: str):
640
- document = Document()
641
- for paragraph_text in text_content.split('\n'):
642
- document.add_paragraph(paragraph_text)
643
- document.save(full_local_path)
644
- print(f"✅ Saved DOCX locally: {full_local_path}")
645
-
646
-
647
-
648
- '''Two scenarios:
649
- - quick look, find the keyword, deep-dive, and get the location directly, then stop
650
- - quick look, find the keyword, deep-dive but no location is found, so hold the related codes and
651
- search the other files iteratively for each related code until a location is found, then stop'''
652
- def extract_context(text, keyword, window=500):
653
- # firstly try accession number
654
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
655
-
656
- # Attempt to parse the keyword into its prefix and numerical part using re.search
657
- keyword_match = code_pattern.search(keyword)
658
-
659
- keyword_prefix = None
660
- keyword_num = None
661
-
662
- if keyword_match:
663
- keyword_prefix = keyword_match.group(1).lower()
664
- keyword_num = int(keyword_match.group(2))
665
- text = text.lower()
666
- idx = text.find(keyword.lower())
667
- if idx == -1:
668
- if keyword_prefix:
669
- idx = text.find(keyword_prefix)
670
- if idx == -1:
671
- return "Sample ID not found."
672
- return text[max(0, idx-window): idx+window]
673
- return text[max(0, idx-window): idx+window]
674
- def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
675
- cache = {}
676
- country = "unknown"
677
- output = ""
678
- tem_output, small_output = "",""
679
- keyword_appear = (False,"")
680
- keywords = []
681
- if isolate: keywords.append(isolate)
682
- if accession: keywords.append(accession)
683
- for f in filePaths:
684
- # scenario 1: direct location: truncate the context and then use the QA model?
685
- if keywords:
686
- for keyword in keywords:
687
- text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
688
- if keyword in final_input:
689
- context = extract_context(final_input, keyword)
690
- # quick look if country already in context and if yes then return
691
- country = model.get_country_from_text(context)
692
- if country != "unknown":
693
- return country, context, final_input
694
- else:
695
- country = model.get_country_from_text(final_input)
696
- if country != "unknown":
697
- return country, context, final_input
698
- else: # might be cross-ref
699
- keyword_appear = (True, f)
700
- cache[f] = context
701
- small_output = merge_texts_skipping_overlap(output, context) + "\n"
702
- chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
703
- countryBFS = model.get_country_from_text(chunkBFS)
704
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
705
- output = merge_texts_skipping_overlap(output, final_input)
706
- if countryDFS != "unknown" and countryBFS != "unknown":
707
- if len(chunkDFS) <= len(chunkBFS):
708
- return countryDFS, chunkDFS, output
709
- else:
710
- return countryBFS, chunkBFS, output
711
- else:
712
- if countryDFS != "unknown":
713
- return countryDFS, chunkDFS, output
714
- if countryBFS != "unknown":
715
- return countryBFS, chunkBFS, output
716
- else:
717
- # scenario 2:
718
- '''cross-ref: ex: A1YU101 keyword in file 2 which includes KM1 but KM1 in file 1
719
- but if we look at file 1 first then maybe we can have lookup dict which country
720
- such as Thailand as the key and its re'''
721
- cache[f] = final_input
722
- if keyword_appear[0] == True:
723
- for c in cache:
724
- if c!=keyword_appear[1]:
725
- if cache[c].lower() not in output.lower():
726
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
727
- chunkBFS = get_contextual_sentences_BFS(output, keyword)
728
- countryBFS = model.get_country_from_text(chunkBFS)
729
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
730
- if countryDFS != "unknown" and countryBFS != "unknown":
731
- if len(chunkDFS) <= len(chunkBFS):
732
- return countryDFS, chunkDFS, output
733
- else:
734
- return countryBFS, chunkBFS, output
735
- else:
736
- if countryDFS != "unknown":
737
- return countryDFS, chunkDFS, output
738
- if countryBFS != "unknown":
739
- return countryBFS, chunkBFS, output
740
- else:
741
- if cache[f].lower() not in output.lower():
742
- output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
743
- if len(output) == 0 or keyword_appear[0]==False:
744
- for c in cache:
745
- if cache[c].lower() not in output.lower():
746
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
747
  return country, "", output
 
1
+ import re
2
+ import os
3
+ #import streamlit as st
4
+ import subprocess
5
+ import re
6
+ from Bio import Entrez
7
+ from docx import Document
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ #from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ import pandas as pd
19
+ import model
20
+ import pipeline
21
+ import tempfile
22
+ import nltk
23
+ nltk.download('punkt_tab')
24
+ def download_excel_file(url, save_path="temp.xlsx"):
25
+ if "view.officeapps.live.com" in url:
26
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
28
+ response = requests.get(real_url)
29
+ with open(save_path, "wb") as f:
30
+ f.write(response.content)
31
+ return save_path
32
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
+ response = requests.get(url)
34
+ response.raise_for_status() # Raises error if download fails
35
+ with open(save_path, "wb") as f:
36
+ f.write(response.content)
37
+ print(len(response.content))
38
+ return save_path
39
+ else:
40
+ print("URL must point directly to an .xls or .xlsx file, or it has already been downloaded.")
41
+ return url
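A minimal usage sketch for download_excel_file above (the viewer URL is made up and only illustrates unwrapping the src parameter; run from within this module):

# Hypothetical Office web-viewer link wrapping a direct .xlsx URL
viewer_url = ("https://view.officeapps.live.com/op/view.aspx?"
              "src=https%3A%2F%2Fexample.org%2Fdata%2Fsamples.xlsx")
local_path = download_excel_file(viewer_url, save_path="samples.xlsx")
print(local_path)  # "samples.xlsx" if a direct or viewer link was handled, else the original URL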
42
+ def extract_text(link,saveFolder):
43
+ try:
44
+ text = ""
45
+ name = link.split("/")[-1]
46
+ print("name: ", name)
47
+ #file_path = Path(saveFolder) / name
48
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
49
+ print("this is local temp path: ", local_temp_path)
50
+ if os.path.exists(local_temp_path):
51
+ input_to_class = local_temp_path
52
+ print("exist")
53
+ else:
54
+ #input_to_class = link # Let the class handle downloading
55
+ # 1. Check if file exists in shared Google Drive folder
56
+ file_id = pipeline.find_drive_file(name, saveFolder)
57
+ if file_id:
58
+ print("📥 Downloading from Google Drive...")
59
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
60
+ else:
61
+ print("🌐 Downloading from web link...")
62
+ response = requests.get(link)
63
+ with open(local_temp_path, 'wb') as f:
64
+ f.write(response.content)
65
+ print("✅ Saved locally.")
66
+
67
+ # 2. Upload to Drive so it's available for later
68
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
69
+
70
+ input_to_class = local_temp_path
71
+ print(input_to_class)
72
+ # pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
73
+ # pdf
74
+ if link.endswith(".pdf"):
75
+ # if file_path.is_file():
76
+ # link = saveFolder + "/" + name
77
+ # print("File exists.")
78
+ #p = pdf.PDF(local_temp_path, saveFolder)
79
+ print("inside pdf and input to class: ", input_to_class)
80
+ print("save folder in extract text: ", saveFolder)
81
+ #p = pdf.PDF(input_to_class, saveFolder)
82
+ #p = pdf.PDF(link,saveFolder)
83
+ #text = p.extractTextWithPDFReader()
84
+ #text = p.extractText()
85
+ p = pdf.PDFFast(input_to_class, saveFolder)
86
+ text = p.extract_text()
87
+
88
+ print("len text from pdf:")
89
+ print(len(text))
90
+ #text_exclude_table = p.extract_text_excluding_tables()
91
+ # worddoc
92
+ elif link.endswith(".doc") or link.endswith(".docx"):
93
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
94
+ # d = wordDoc.wordDoc(input_to_class,saveFolder)
95
+ # text = d.extractTextByPage()
96
+ d = wordDoc.WordDocFast(input_to_class, saveFolder)
97
+ text = d.extractText()
98
+
99
+ # html
100
+ else:
101
+ if link.split(".")[-1].lower() not in ("xls", "xlsx"):  # explicit tuple avoids accidental substring matches
102
+ if "http" in link or "html" in link:
103
+ print("html link: ", link)
104
+ html = extractHTML.HTML("",link)
105
+ text = html.getListSection() # the text already clean
106
+ print("len text html: ")
107
+ print(len(text))
108
+ # Cleanup: delete the local temp file
109
+ if name:
110
+ if os.path.exists(local_temp_path):
111
+ os.remove(local_temp_path)
112
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
113
+ print("done extract text")
114
+ except:
115
+ text = ""
116
+ return text
117
+
118
+ def extract_table(link,saveFolder):
119
+ try:
120
+ table = []
121
+ name = link.split("/")[-1]
122
+ #file_path = Path(saveFolder) / name
123
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
124
+ if os.path.exists(local_temp_path):
125
+ input_to_class = local_temp_path
126
+ print("exist")
127
+ else:
128
+ #input_to_class = link # Let the class handle downloading
129
+ # 1. Check if file exists in shared Google Drive folder
130
+ file_id = pipeline.find_drive_file(name, saveFolder)
131
+ if file_id:
132
+ print("📥 Downloading from Google Drive...")
133
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
134
+ else:
135
+ print("🌐 Downloading from web link...")
136
+ response = requests.get(link)
137
+ with open(local_temp_path, 'wb') as f:
138
+ f.write(response.content)
139
+ print("✅ Saved locally.")
140
+
141
+ # 2. Upload to Drive so it's available for later
142
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
143
+
144
+ input_to_class = local_temp_path
145
+ print(input_to_class)
146
+ #pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
147
+ # pdf
148
+ if link.endswith(".pdf"):
149
+ # if file_path.is_file():
150
+ # link = saveFolder + "/" + name
151
+ # print("File exists.")
152
+ #p = pdf.PDF(local_temp_path,saveFolder)
153
+ p = pdf.PDF(input_to_class,saveFolder)
154
+ table = p.extractTable()
155
+ # worddoc
156
+ elif link.endswith(".doc") or link.endswith(".docx"):
157
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
158
+ # d = wordDoc.wordDoc(input_to_class,saveFolder)
159
+ # table = d.extractTableAsList()
160
+ d = wordDoc.WordDocFast(input_to_class, saveFolder)
161
+ table = d.extractTableAsList()
162
+ # excel
163
+ elif link.split(".")[-1].lower() in ("xls", "xlsx"):
164
+ # download excel file if it not downloaded yet
165
+ savePath = saveFolder +"/"+ link.split("/")[-1]
166
+ excelPath = download_excel_file(link, savePath)
167
+ try:
168
+ #xls = pd.ExcelFile(excelPath)
169
+ xls = pd.ExcelFile(local_temp_path)
170
+ table_list = []
171
+ for sheet_name in xls.sheet_names:
172
+ df = pd.read_excel(xls, sheet_name=sheet_name)
173
+ cleaned_table = df.fillna("").astype(str).values.tolist()
174
+ table_list.append(cleaned_table)
175
+ table = table_list
176
+ except Exception as e:
177
+ print("❌ Failed to extract tables from Excel:", e)
178
+ # html
179
+ elif "http" in link or "html" in link:
180
+ html = extractHTML.HTML("",link)
181
+ table = html.extractTable() # table is a list
182
+ table = clean_tables_format(table)
183
+ # Cleanup: delete the local temp file
184
+ if os.path.exists(local_temp_path):
185
+ os.remove(local_temp_path)
186
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
187
+ except:
188
+ table = []
189
+ return table
190
+
191
+ def clean_tables_format(tables):
192
+ """
193
+ Ensures all tables are in consistent format: List[List[List[str]]]
194
+ Cleans by:
195
+ - Removing empty strings and rows
196
+ - Converting all cells to strings
197
+ - Handling DataFrames and list-of-lists
198
+ """
199
+ cleaned = []
200
+ if tables:
201
+ for table in tables:
202
+ standardized = []
203
+
204
+ # Case 1: Pandas DataFrame
205
+ if isinstance(table, pd.DataFrame):
206
+ table = table.fillna("").astype(str).values.tolist()
207
+
208
+ # Case 2: List of Lists
209
+ if isinstance(table, list) and all(isinstance(row, list) for row in table):
210
+ for row in table:
211
+ filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
212
+ if filtered_row:
213
+ standardized.append(filtered_row)
214
+
215
+ if standardized:
216
+ cleaned.append(standardized)
217
+
218
+ return cleaned
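A small self-contained example of the normalization performed by clean_tables_format (values are invented):

import pandas as pd
raw_tables = [
    pd.DataFrame({"Sample": ["BRU18", None], "Country": ["Brunei", ""]}),   # DataFrame input
    [["ID", "Region"], ["KSK1", "Kanchanaburi"], ["", ""]],                 # list-of-lists input
]
print(clean_tables_format(raw_tables))
# [[['BRU18', 'Brunei']], [['ID', 'Region'], ['KSK1', 'Kanchanaburi']]]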
219
+
220
+ import json
221
+ def normalize_text_for_comparison(s: str) -> str:
222
+ """
223
+ Normalizes text for robust comparison by:
224
+ 1. Converting to lowercase.
225
+ 2. Replacing all types of newlines with a single consistent newline (\n).
226
+ 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
227
+ 4. Stripping leading/trailing whitespace from the entire string.
228
+ """
229
+ s = s.lower()
230
+ s = s.replace('\r\n', '\n') # Handle Windows newlines
231
+ s = s.replace('\r', '\n') # Handle Mac classic newlines
232
+
233
+ # Replace sequences of whitespace (including multiple newlines) with a single space
234
+ # This might be too aggressive if you need to preserve paragraph breaks,
235
+ # but good for exact word-sequence matching.
236
+ s = re.sub(r'\s+', ' ', s)
237
+
238
+ return s.strip()
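For example, two renderings of the same table row collapse to one normalized string (toy values):

a = "BRU18\r\nBrunei  Borneo"
b = "bru18 brunei borneo"
print(normalize_text_for_comparison(a) == b)  # True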
239
+ def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
240
+ """
241
+ Merge cleaned text and table into one string for LLM input.
242
+ - Avoids duplicating tables already in text
243
+ - Extracts only relevant rows from large tables
244
+ - Skips or saves oversized tables
245
+ """
246
+ import importlib
247
+ json = importlib.import_module("json")
248
+
249
+ def estimate_tokens(text_str):
250
+ try:
251
+ enc = tiktoken.get_encoding(tokenizer)
252
+ return len(enc.encode(text_str))
253
+ except:
254
+ return len(text_str) // 4 # Fallback estimate
255
+
256
+ def is_table_relevant(table, keywords, accession_id=None):
257
+ flat = " ".join(" ".join(row).lower() for row in table)
258
+ if accession_id and accession_id.lower() in flat:
259
+ return True
260
+ return any(kw.lower() in flat for kw in keywords)
261
+ preview, preview1 = "",""
262
+ llm_input = "## Document Text\n" + text.strip() + "\n"
263
+ clean_text = normalize_text_for_comparison(text)
264
+
265
+ if tables:
266
+ for idx, table in enumerate(tables):
267
+ keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
268
+ if accession_id: keywords += [accession_id.lower()]
269
+ if isolate: keywords += [isolate.lower()]
270
+ if is_table_relevant(table, keywords, accession_id):
271
+ if len(table) > 0:
272
+ for tab in table:
273
+ preview = " ".join(tab) if tab else ""
274
+ preview1 = "\n".join(tab) if tab else ""
275
+ clean_preview = normalize_text_for_comparison(preview)
276
+ clean_preview1 = normalize_text_for_comparison(preview1)
277
+ if clean_preview not in clean_text:
278
+ if clean_preview1 not in clean_text:
279
+ table_str = json.dumps([tab], indent=2)
280
+ llm_input += f"## Table {idx+1}\n{table_str}\n"
281
+ return llm_input.strip()
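An illustrative call to merge_text_and_tables (the sentence and table below are invented):

doc_text = "Sample KSK1 was collected in Kanchanaburi Province, Thailand."
tables = [[["Sample", "Country"], ["KSK1", "Thailand"]]]
prompt = merge_text_and_tables(doc_text, tables, accession_id="MG272794", isolate="KSK1")
print(prompt)
# "## Document Text" followed by the sentence, then a "## Table 1" block for each
# row that is not already present verbatim in the text.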
282
+
283
+ def preprocess_document(link, saveFolder, accession=None, isolate=None, article_text=None):
284
+ if article_text:
285
+ print("article text already available")
286
+ text = article_text
287
+ else:
288
+ try:
289
+ print("start preprocess and extract text")
290
+ text = extract_text(link, saveFolder)
291
+ except: text = ""
292
+ try:
293
+ print("extract table start")
294
+ success, the_output = pipeline.run_with_timeout(extract_table,args=(link,saveFolder),timeout=10)
295
+ print("Returned from timeout logic")
296
+ if success:
297
+ tables = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
298
+ print("extract table succeeded")
299
+ else:
300
+ print("extract table did not succeed")
301
+ tables = []
302
+ #tables = extract_table(link, saveFolder)
303
+ except: tables = []
304
+ if accession: accession = accession
305
+ if isolate: isolate = isolate
306
+ try:
307
+ # print("merge text and table start")
308
+ # success, the_output = pipeline.run_with_timeout(merge_text_and_tables,kwargs={"text":text,"tables":tables,"accession_id":accession, "isolate":isolate},timeout=30)
309
+ # print("Returned from timeout logic")
310
+ # if success:
311
+ # final_input = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
312
+ # print("yes succeed")
313
+ # else:
314
+ # print("not suceed")
315
+ print("just merge text and tables")
316
+ final_input = text + ", ".join(str(t) for t in tables)  # stringify each table so join does not fail on nested lists
317
+ #final_input = pipeline.timeout(merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
318
+ except:
319
+ print("could not build final_input in preprocess_document")
320
+ final_input = ""
321
+ return text, tables, final_input
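A sketch of how preprocess_document might be called; the link and Drive folder ID are placeholders, and it assumes the Google Drive helpers in pipeline.py are configured:

text, tables, final_input = preprocess_document(
    "https://example.org/supp/esm_tables.xls",   # placeholder supplementary link
    "1AbCdEfGhIjKlMnOpQrStUv",                   # placeholder Drive folder ID
    accession="KU131308", isolate="BRU18",
)
print(len(text), len(tables), len(final_input))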
322
+
323
+ def extract_sentences(text):
324
+ sentences = re.split(r'(?<=[.!?])\s+', text)
325
+ return [s.strip() for s in sentences if s.strip()]
326
+
327
+ def is_irrelevant_number_sequence(text):
328
+ if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
329
+ return False
330
+ word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
331
+ number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
332
+ total_tokens = len(re.findall(r'\S+', text))
333
+ if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
334
+ return True
335
+ elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
336
+ return True
337
+ return False
338
+
339
+ def remove_isolated_single_digits(sentence):
340
+ tokens = sentence.split()
341
+ filtered_tokens = []
342
+ for token in tokens:
343
+ if token == '0' or token == '1':
344
+ pass
345
+ else:
346
+ filtered_tokens.append(token)
347
+ return ' '.join(filtered_tokens).strip()
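The two filters above behave like this on toy inputs:

print(is_irrelevant_number_sequence("12 34.5 678 9"))             # True: almost all numeric tokens
print(is_irrelevant_number_sequence("KSK1 sampled in Thailand"))  # False: contains a sample code
print(remove_isolated_single_digits("Haplogroup B4a 1 found in 0 samples"))
# "Haplogroup B4a found in samples"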
348
+
349
+ def get_contextual_sentences_BFS(text_content, keyword, depth=2):
350
+ def extract_codes(sentence):
351
+ # Match codes like 'A1YU101', 'KM1', 'MO6' at least 2 letters + numbers
352
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
353
+ sentences = extract_sentences(text_content)
354
+ relevant_sentences = set()
355
+ initial_keywords = set()
356
+
357
+ # Define a regex to capture codes like A1YU101 or KM1
358
+ # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
359
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
360
+
361
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
362
+ keyword_match = code_pattern.search(keyword)
363
+
364
+ keyword_prefix = None
365
+ keyword_num = None
366
+
367
+ if keyword_match:
368
+ keyword_prefix = keyword_match.group(1).lower()
369
+ keyword_num = int(keyword_match.group(2))
370
+
371
+ for sentence in sentences:
372
+ sentence_added = False
373
+
374
+ # 1. Check for exact match of the keyword
375
+ if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
376
+ relevant_sentences.add(sentence.strip())
377
+ initial_keywords.add(keyword.lower())
378
+ sentence_added = True
379
+
380
+ # 2. Check for range patterns (e.g., A1YU101-A1YU137)
381
+ # The range pattern should be broad enough to capture the full code string within the range.
382
+ # (an earlier one-group pattern r'([A-Z0-9]+-\d+)' was superseded by the two-group pattern below)
383
+ range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
384
+
385
+ for r_match in range_matches:
386
+ start_code_str = r_match.group(1)
387
+ end_code_str = r_match.group(2)
388
+
389
+ # CRITICAL FIX: Use code_pattern.search for start_match and end_match
390
+ start_match = code_pattern.search(start_code_str)
391
+ end_match = code_pattern.search(end_code_str)
392
+
393
+ if keyword_prefix and keyword_num is not None and start_match and end_match:
394
+ start_prefix = start_match.group(1).lower()
395
+ end_prefix = end_match.group(1).lower()
396
+ start_num = int(start_match.group(2))
397
+ end_num = int(end_match.group(2))
398
+
399
+ # Check if the keyword's prefix matches and its number is within the range
400
+ if keyword_prefix == start_prefix and \
401
+ keyword_prefix == end_prefix and \
402
+ start_num <= keyword_num <= end_num:
403
+ relevant_sentences.add(sentence.strip())
404
+ initial_keywords.add(start_code_str.lower())
405
+ initial_keywords.add(end_code_str.lower())
406
+ sentence_added = True
407
+ break # Only need to find one matching range per sentence
408
+
409
+ # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
410
+ # to initial_keywords to ensure graph traversal from related terms.
411
+ if sentence_added:
412
+ for word in extract_codes(sentence):
413
+ initial_keywords.add(word.lower())
414
+
415
+
416
+ # Build word_to_sentences mapping for all sentences
417
+ word_to_sentences = {}
418
+ for sent in sentences:
419
+ codes_in_sent = set(extract_codes(sent))
420
+ for code in codes_in_sent:
421
+ word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
422
+
423
+
424
+ # Build the graph
425
+ graph = {}
426
+ for sent in sentences:
427
+ codes = set(extract_codes(sent))
428
+ for word1 in codes:
429
+ word1_lower = word1.lower()
430
+ graph.setdefault(word1_lower, set())
431
+ for word2 in codes:
432
+ word2_lower = word2.lower()
433
+ if word1_lower != word2_lower:
434
+ graph[word1_lower].add(word2_lower)
435
+
436
+
437
+ # Perform BFS/graph traversal
438
+ queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
439
+ visited_words = set(initial_keywords)
440
+
441
+ while queue:
442
+ current_word, level = queue.pop(0)
443
+ if level >= depth:
444
+ continue
445
+
446
+ relevant_sentences.update(word_to_sentences.get(current_word, []))
447
+
448
+ for neighbor in graph.get(current_word, []):
449
+ if neighbor not in visited_words:
450
+ visited_words.add(neighbor)
451
+ queue.append((neighbor, level + 1))
452
+
453
+ final_sentences = set()
454
+ for sentence in relevant_sentences:
455
+ if not is_irrelevant_number_sequence(sentence):
456
+ processed_sentence = remove_isolated_single_digits(sentence)
457
+ if processed_sentence:
458
+ final_sentences.add(processed_sentence)
459
+
460
+ return "\n".join(sorted(list(final_sentences)))
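A toy example of the BFS expansion: the keyword A1YU105 is only covered by a range, and the location is reached through the shared code KM1 (text is invented):

text = ("Samples A1YU101-A1YU137 were genotyped together with KM1. "
        "KM1 was collected in Chiang Mai, Thailand.")
print(get_contextual_sentences_BFS(text, "A1YU105", depth=2))
# Both sentences are returned: the range sentence matches the keyword,
# and the second sentence is pulled in via the shared code KM1.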
461
+
462
+
463
+
464
+ def get_contextual_sentences_DFS(text_content, keyword, depth=2):
465
+ sentences = extract_sentences(text_content)
466
+
467
+ # Build word-to-sentences mapping
468
+ word_to_sentences = {}
469
+ for sent in sentences:
470
+ words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
471
+ for word in words_in_sent:
472
+ word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
473
+
474
+ # Function to extract codes in a sentence
475
+ def extract_codes(sentence):
476
+ # Only codes like 'KSK1', 'MG272794', not pure numbers
477
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
478
+
479
+ # DFS with priority based on distance to keyword and early stop if country found
480
+ def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
481
+ country = "unknown"
482
+ if current_depth > max_depth:
483
+ return country, False
484
+
485
+ if current_word not in word_to_sentences:
486
+ return country, False
487
+
488
+ for sentence in word_to_sentences[current_word]:
489
+ if sentence == parent_sentence:
490
+ continue # avoid reusing the same sentence
491
+
492
+ collected_sentences.add(sentence)
493
+
494
+ #print("current_word:", current_word)
495
+ small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
496
+ #print(small_sen)
497
+ country = model.get_country_from_text(small_sen)
498
+ #print("small context country:", country)
499
+ if country.lower() != "unknown":
500
+ return country, True
501
+ else:
502
+ country = model.get_country_from_text(sentence)
503
+ #print("full sentence country:", country)
504
+ if country.lower() != "unknown":
505
+ return country, True
506
+
507
+ codes_in_sentence = extract_codes(sentence)
508
+ idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
509
+ if idx is None:
510
+ continue
511
+
512
+ sorted_children = sorted(
513
+ [code for code in codes_in_sentence if code.lower() not in visited_words],
514
+ key=lambda x: (abs(codes_in_sentence.index(x) - idx),
515
+ 0 if codes_in_sentence.index(x) > idx else 1)
516
+ )
517
+
518
+ #print("sorted_children:", sorted_children)
519
+ for child in sorted_children:
520
+ child_lower = child.lower()
521
+ if child_lower not in visited_words:
522
+ visited_words.add(child_lower)
523
+ country, should_stop = dfs_traverse(
524
+ child_lower, current_depth + 1, max_depth,
525
+ visited_words, collected_sentences, parent_sentence=sentence
526
+ )
527
+ if should_stop:
528
+ return country, True
529
+
530
+ return country, False
531
+
532
+ # Begin DFS
533
+ collected_sentences = set()
534
+ visited_words = set([keyword.lower()])
535
+ country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
536
+
537
+ # Filter irrelevant sentences
538
+ final_sentences = set()
539
+ for sentence in collected_sentences:
540
+ if not is_irrelevant_number_sequence(sentence):
541
+ processed = remove_isolated_single_digits(sentence)
542
+ if processed:
543
+ final_sentences.add(processed)
544
+ if not final_sentences:
545
+ return country, text_content
546
+ return country, "\n".join(sorted(list(final_sentences)))
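A sketch of the DFS variant; it depends on model.get_country_from_text, so the output below is only the expected outcome if that model resolves the second sentence (text is invented):

text = ("Sample KSK1 shares a haplotype with MG272794. "
        "MG272794 was sequenced from a modern individual in Thailand.")
country, context = get_contextual_sentences_DFS(text, "KSK1", depth=2)
print(country)  # expected "Thailand" if the model identifies it
print(context)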
547
+
548
+ # Helper function for normalizing text for overlap comparison
549
+ def normalize_for_overlap(s: str) -> str:
550
+ s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
551
+ s = re.sub(r'\s+', ' ', s).strip()
552
+ return s
553
+
554
+ def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
555
+ if not text1: return text2
556
+ if not text2: return text1
557
+
558
+ # Case 1: text2 is fully contained in text1 or vice-versa
559
+ if text2 in text1:
560
+ return text1
561
+ if text1 in text2:
562
+ return text2
563
+
564
+ # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
565
+ # This is what your function was primarily designed for.
566
+ # It looks for the overlap at the "junction" of text1 and text2.
567
+
568
+ max_junction_overlap = 0
569
+ for i in range(min(len(text1), len(text2)), 0, -1):
570
+ suffix1 = text1[-i:]
571
+ prefix2 = text2[:i]
572
+ # Prioritize exact match, then normalized match
573
+ if suffix1 == prefix2:
574
+ max_junction_overlap = i
575
+ break
576
+ elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
577
+ max_junction_overlap = i
578
+ break # Take the first (longest) normalized match
579
+
580
+ if max_junction_overlap > 0:
581
+ merged_text = text1 + text2[max_junction_overlap:]
582
+ return re.sub(r'\s+', ' ', merged_text).strip()
583
+
584
+ # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
585
+ # This addresses your specific test case where the overlap is at the very beginning of both strings.
586
+ # This is often used when trying to deduplicate content that shares a common start.
587
+
588
+ longest_common_prefix_len = 0
589
+ min_len = min(len(text1), len(text2))
590
+ for i in range(min_len):
591
+ if text1[i] == text2[i]:
592
+ longest_common_prefix_len = i + 1
593
+ else:
594
+ break
595
+
596
+ # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
597
+ # AND the remaining parts are distinct, then apply this merge.
598
+ # This is a heuristic and might need fine-tuning.
599
+ if longest_common_prefix_len > 0 and \
600
+ text1[longest_common_prefix_len:].strip() and \
601
+ text2[longest_common_prefix_len:].strip():
602
+
603
+ # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
604
+ # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
605
+ # common prefix is "Hi, I am Vy."
606
+ # Remaining text1: " Nice to meet you."
607
+ # Remaining text2: " Goodbye Vy."
608
+ # So we merge common_prefix + remaining_text1 + remaining_text2
609
+
610
+ common_prefix_str = text1[:longest_common_prefix_len]
611
+ remainder_text1 = text1[longest_common_prefix_len:]
612
+ remainder_text2 = text2[longest_common_prefix_len:]
613
+
614
+ merged_text = common_prefix_str + remainder_text1 + remainder_text2
615
+ return re.sub(r'\s+', ' ', merged_text).strip()
616
+
617
+
618
+ # If neither specific overlap type is found, just concatenate
619
+ merged_text = text1 + text2
620
+ return re.sub(r'\s+', ' ', merged_text).strip()
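The simplest case handled by merge_texts_skipping_overlap is full containment, shown here with invented strings; otherwise the longest exact or normalized suffix/prefix overlap at the junction is skipped before concatenation:

t1 = "Sample BRU18 was collected in Brunei."
t2 = "BRU18 was collected in Brunei"
print(merge_texts_skipping_overlap(t1, t2))  # returns t1 unchanged, since t2 is contained in t1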
621
+
622
+ from docx import Document
623
+ from pipeline import upload_file_to_drive
624
+ # def save_text_to_docx(text_content: str, file_path: str):
625
+ # """
626
+ # Saves a given text string into a .docx file.
627
+
628
+ # Args:
629
+ # text_content (str): The text string to save.
630
+ # file_path (str): The full path including the filename where the .docx file will be saved.
631
+ # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
632
+ # """
633
+ # try:
634
+ # document = Document()
635
+
636
+ # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
637
+ # for paragraph_text in text_content.split('\n'):
638
+ # document.add_paragraph(paragraph_text)
639
+
640
+ # document.save(file_path)
641
+ # print(f"Text successfully saved to '{file_path}'")
642
+ # except Exception as e:
643
+ # print(f"Error saving text to docx file: {e}")
644
+ # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
645
+ # """
646
+ # Saves a given text string into a .docx file locally, then uploads to Google Drive.
647
+
648
+ # Args:
649
+ # text_content (str): The text string to save.
650
+ # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
651
+ # drive_folder_id (str): Google Drive folder ID where to upload the file.
652
+ # """
653
+ # try:
654
+ # # Save to temporary local path first
655
+ # print("file name: ", filename)
656
+ # print("length text content: ", len(text_content))
657
+ # local_path = os.path.join(tempfile.gettempdir(), filename)
658
+ # document = Document()
659
+ # for paragraph_text in text_content.split('\n'):
660
+ # document.add_paragraph(paragraph_text)
661
+ # document.save(local_path)
662
+ # print(f"✅ Text saved locally to: {local_path}")
663
+
664
+ # # Upload to Drive
665
+ # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
666
+ # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
667
+
668
+ # except Exception as e:
669
+ # print(f"❌ Error saving or uploading DOCX: {e}")
670
+ def save_text_to_docx(text_content: str, full_local_path: str):
671
+ document = Document()
672
+ for paragraph_text in text_content.split('\n'):
673
+ document.add_paragraph(paragraph_text)
674
+ document.save(full_local_path)
675
+ print(f"✅ Saved DOCX locally: {full_local_path}")
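A minimal call to save_text_to_docx (the output path is illustrative):

import os, tempfile
out_path = os.path.join(tempfile.gettempdir(), "BRU18_merged_document.docx")
save_text_to_docx("Line one\nLine two", out_path)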
676
+
677
+
678
+
679
+ '''Two scenarios:
680
+ - quick look, find the keyword, deep-dive, and get the location directly, then stop
681
+ - quick look, find the keyword, deep-dive but no location is found, so hold the related codes and
682
+ search the other files iteratively for each related code until a location is found, then stop'''
683
+ def extract_context(text, keyword, window=500):
684
+ # firstly try accession number
685
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
686
+
687
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
688
+ keyword_match = code_pattern.search(keyword)
689
+
690
+ keyword_prefix = None
691
+ keyword_num = None
692
+
693
+ if keyword_match:
694
+ keyword_prefix = keyword_match.group(1).lower()
695
+ keyword_num = int(keyword_match.group(2))
696
+ text = text.lower()
697
+ idx = text.find(keyword.lower())
698
+ if idx == -1:
699
+ if keyword_prefix:
700
+ idx = text.find(keyword_prefix)
701
+ if idx == -1:
702
+ return "Sample ID not found."
703
+ return text[max(0, idx-window): idx+window]
704
+ return text[max(0, idx-window): idx+window]
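extract_context in isolation, on an invented sentence; note that the helper lowercases the text before slicing the window:

text = "The isolate KSK1 (accession MG272794) was sampled in Kanchanaburi, Thailand."
print(extract_context(text, "MG272794", window=30))
# prints a lowercased window of up to 30 characters on each side of the accession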
705
+ def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
706
+ cache = {}
707
+ country = "unknown"
708
+ output = ""
709
+ tem_output, small_output = "",""
710
+ keyword_appear = (False,"")
711
+ keywords = []
712
+ if isolate: keywords.append(isolate)
713
+ if accession: keywords.append(accession)
714
+ for f in filePaths:
715
+ # scenario 1: direct location: truncate the context and then use the QA model?
716
+ if keywords:
717
+ for keyword in keywords:
718
+ text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
719
+ if keyword in final_input:
720
+ context = extract_context(final_input, keyword)
721
+ # quick look if country already in context and if yes then return
722
+ country = model.get_country_from_text(context)
723
+ if country != "unknown":
724
+ return country, context, final_input
725
+ else:
726
+ country = model.get_country_from_text(final_input)
727
+ if country != "unknown":
728
+ return country, context, final_input
729
+ else: # might be cross-ref
730
+ keyword_appear = (True, f)
731
+ cache[f] = context
732
+ small_output = merge_texts_skipping_overlap(output, context) + "\n"
733
+ chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
734
+ countryBFS = model.get_country_from_text(chunkBFS)
735
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
736
+ output = merge_texts_skipping_overlap(output, final_input)
737
+ if countryDFS != "unknown" and countryBFS != "unknown":
738
+ if len(chunkDFS) <= len(chunkBFS):
739
+ return countryDFS, chunkDFS, output
740
+ else:
741
+ return countryBFS, chunkBFS, output
742
+ else:
743
+ if countryDFS != "unknown":
744
+ return countryDFS, chunkDFS, output
745
+ if countryBFS != "unknown":
746
+ return countryBFS, chunkBFS, output
747
+ else:
748
+ # scenario 2:
749
+ '''cross-ref: ex: A1YU101 keyword in file 2 which includes KM1 but KM1 in file 1
750
+ but if we look at file 1 first then maybe we can have lookup dict which country
751
+ such as Thailand as the key and its re'''
752
+ cache[f] = final_input
753
+ if keyword_appear[0] == True:
754
+ for c in cache:
755
+ if c!=keyword_appear[1]:
756
+ if cache[c].lower() not in output.lower():
757
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
758
+ chunkBFS = get_contextual_sentences_BFS(output, keyword)
759
+ countryBFS = model.get_country_from_text(chunkBFS)
760
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
761
+ if countryDFS != "unknown" and countryBFS != "unknown":
762
+ if len(chunkDFS) <= len(chunkBFS):
763
+ return countryDFS, chunkDFS, output
764
+ else:
765
+ return countryBFS, chunkBFS, output
766
+ else:
767
+ if countryDFS != "unknown":
768
+ return countryDFS, chunkDFS, output
769
+ if countryBFS != "unknown":
770
+ return countryBFS, chunkBFS, output
771
+ else:
772
+ if cache[f].lower() not in output.lower():
773
+ output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
774
+ if len(output) == 0 or keyword_appear[0]==False:
775
+ for c in cache:
776
+ if cache[c].lower() not in output.lower():
777
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
778
  return country, "", output
model.py CHANGED
The diff for this file is too large to render. See raw diff
 
mtdna_backend.py CHANGED
@@ -1,928 +1,945 @@
1
- import gradio as gr
2
- from collections import Counter
3
- import csv
4
- import os
5
- from functools import lru_cache
6
- #import app
7
- from mtdna_classifier import classify_sample_location
8
- import data_preprocess, model, pipeline
9
- import subprocess
10
- import json
11
- import pandas as pd
12
- import io
13
- import re
14
- import tempfile
15
- import gspread
16
- from oauth2client.service_account import ServiceAccountCredentials
17
- from io import StringIO
18
- import hashlib
19
- import threading
20
-
21
- # @lru_cache(maxsize=3600)
22
- # def classify_sample_location_cached(accession):
23
- # return classify_sample_location(accession)
24
-
25
- #@lru_cache(maxsize=3600)
26
- def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
27
- print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
- print("len of save df: ", len(save_df))
29
- return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
30
-
31
- # Count and suggest final location
32
- # def compute_final_suggested_location(rows):
33
- # candidates = [
34
- # row.get("Predicted Location", "").strip()
35
- # for row in rows
36
- # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
37
- # ] + [
38
- # row.get("Inferred Region", "").strip()
39
- # for row in rows
40
- # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
41
- # ]
42
-
43
- # if not candidates:
44
- # return Counter(), ("Unknown", 0)
45
- # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
46
- # tokens = []
47
- # for item in candidates:
48
- # # Split by comma, whitespace, and newlines
49
- # parts = re.split(r'[\s,]+', item)
50
- # tokens.extend(parts)
51
-
52
- # # Step 2: Clean and normalize tokens
53
- # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
54
-
55
- # # Step 3: Count
56
- # counts = Counter(tokens)
57
-
58
- # # Step 4: Get most common
59
- # top_location, count = counts.most_common(1)[0]
60
- # return counts, (top_location, count)
61
-
62
- # Store feedback (with required fields)
63
-
64
- def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
65
- if not answer1.strip() or not answer2.strip():
66
- return "⚠️ Please answer both questions before submitting."
67
-
68
- try:
69
- # ✅ Step: Load credentials from Hugging Face secret
70
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
71
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
72
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
73
-
74
- # Connect to Google Sheet
75
- client = gspread.authorize(creds)
76
- sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
77
-
78
- # Append feedback
79
- sheet.append_row([accession, answer1, answer2, contact])
80
- return "✅ Feedback submitted. Thank you!"
81
-
82
- except Exception as e:
83
- return f"❌ Error submitting feedback: {e}"
84
-
85
- # helper function to extract accessions
86
- def extract_accessions_from_input(file=None, raw_text=""):
87
- print(f"RAW TEXT RECEIVED: {raw_text}")
88
- accessions = []
89
- seen = set()
90
- if file:
91
- try:
92
- if file.name.endswith(".csv"):
93
- df = pd.read_csv(file)
94
- elif file.name.endswith(".xlsx"):
95
- df = pd.read_excel(file)
96
- else:
97
- return [], "Unsupported file format. Please upload CSV or Excel."
98
- for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
99
- if acc not in seen:
100
- accessions.append(acc)
101
- seen.add(acc)
102
- except Exception as e:
103
- return [], f"Failed to read file: {e}"
104
-
105
- if raw_text:
106
- text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
107
- for acc in text_ids:
108
- if acc not in seen:
109
- accessions.append(acc)
110
- seen.add(acc)
111
-
112
- return list(accessions), None
113
- # Add a new helper to backend: `filter_unprocessed_accessions()`
114
- def get_incomplete_accessions(file_path):
115
- df = pd.read_excel(file_path)
116
-
117
- incomplete_accessions = []
118
- for _, row in df.iterrows():
119
- sample_id = str(row.get("Sample ID", "")).strip()
120
-
121
- # Skip if no sample ID
122
- if not sample_id:
123
- continue
124
-
125
- # Drop the Sample ID and check if the rest is empty
126
- other_cols = row.drop(labels=["Sample ID"], errors="ignore")
127
- if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
128
- # Extract the accession number from the sample ID using regex
129
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
130
- if match:
131
- incomplete_accessions.append(match.group(0))
132
- print(len(incomplete_accessions))
133
- return incomplete_accessions
134
-
135
- # GOOGLE_SHEET_NAME = "known_samples"
136
- # USAGE_DRIVE_FILENAME = "user_usage_log.json"
137
-
138
- def summarize_results(accession, stop_flag=None):
139
- # Early bail
140
- if stop_flag is not None and stop_flag.value:
141
- print(f"🛑 Skipping {accession} before starting.")
142
- return []
143
- # try cache first
144
- cached = check_known_output(accession)
145
- if cached:
146
- print(f"✅ Using cached result for {accession}")
147
- return [[
148
- cached["Sample ID"] or "unknown",
149
- cached["Predicted Country"] or "unknown",
150
- cached["Country Explanation"] or "unknown",
151
- cached["Predicted Sample Type"] or "unknown",
152
- cached["Sample Type Explanation"] or "unknown",
153
- cached["Sources"] or "No Links",
154
- cached["Time cost"]
155
- ]]
156
- # only run when nothing in the cache
157
- try:
158
- print("try gemini pipeline: ",accession)
159
- # ✅ Load credentials from Hugging Face secret
160
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
161
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
162
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
163
- client = gspread.authorize(creds)
164
-
165
- spreadsheet = client.open("known_samples")
166
- sheet = spreadsheet.sheet1
167
-
168
- data = sheet.get_all_values()
169
- if not data:
170
- print("⚠️ Google Sheet 'known_samples' is empty.")
171
- return None
172
-
173
- save_df = pd.DataFrame(data[1:], columns=data[0])
174
- print("before pipeline, len of save df: ", len(save_df))
175
- outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
176
- if stop_flag is not None and stop_flag.value:
177
- print(f"🛑 Skipped {accession} mid-pipeline.")
178
- return []
179
- # outputs = {'KU131308': {'isolate':'BRU18',
180
- # 'country': {'brunei': ['ncbi',
181
- # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
182
- # 'sample_type': {'modern':
183
- # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
184
- # 'query_cost': 9.754999999999999e-05,
185
- # 'time_cost': '24.776 seconds',
186
- # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
187
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
188
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
189
- except Exception as e:
190
- return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
191
-
192
- if accession not in outputs:
193
- print("no accession in output ", accession)
194
- return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
195
-
196
- row_score = []
197
- rows = []
198
- save_rows = []
199
- for key in outputs:
200
- pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
201
- for section, results in outputs[key].items():
202
- if section == "country" or section =="sample_type":
203
- pred_output = []#"\n".join(list(results.keys()))
204
- output_explanation = ""
205
- for result, content in results.items():
206
- if len(result) == 0: result = "unknown"
207
- if len(content) == 0: output_explanation = "unknown"
208
- else:
209
- output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
210
- pred_output.append(result)
211
- pred_output = "\n".join(pred_output)
212
- if section == "country":
213
- pred_country, country_explanation = pred_output, output_explanation
214
- elif section == "sample_type":
215
- pred_sample, sample_explanation = pred_output, output_explanation
216
- if outputs[key]["isolate"].lower()!="unknown":
217
- label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
218
- else: label = key
219
- if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
220
- row = {
221
- "Sample ID": label or "unknown",
222
- "Predicted Country": pred_country or "unknown",
223
- "Country Explanation": country_explanation or "unknown",
224
- "Predicted Sample Type":pred_sample or "unknown",
225
- "Sample Type Explanation":sample_explanation or "unknown",
226
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
227
- "Time cost": outputs[key]["time_cost"]
228
- }
229
- #row_score.append(row)
230
- rows.append(list(row.values()))
231
-
232
- save_row = {
233
- "Sample ID": label or "unknown",
234
- "Predicted Country": pred_country or "unknown",
235
- "Country Explanation": country_explanation or "unknown",
236
- "Predicted Sample Type":pred_sample or "unknown",
237
- "Sample Type Explanation":sample_explanation or "unknown",
238
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
239
- "Query_cost": outputs[key]["query_cost"] or "",
240
- "Time cost": outputs[key]["time_cost"] or "",
241
- "file_chunk":outputs[key]["file_chunk"] or "",
242
- "file_all_output":outputs[key]["file_all_output"] or ""
243
- }
244
- #row_score.append(row)
245
- save_rows.append(list(save_row.values()))
246
-
247
- # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
248
- # summary_lines = [f"### 🧭 Location Summary:\n"]
249
- # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
250
- # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
251
- # summary = "\n".join(summary_lines)
252
-
253
- # save the new running sample to known excel file
254
- # try:
255
- # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
256
- # if os.path.exists(KNOWN_OUTPUT_PATH):
257
- # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
258
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
259
- # else:
260
- # df_combined = df_new
261
- # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
262
- # except Exception as e:
263
- # print(f"⚠️ Failed to save known output: {e}")
264
- # try:
265
- # df_new = pd.DataFrame(save_rows, columns=[
266
- # "Sample ID", "Predicted Country", "Country Explanation",
267
- # "Predicted Sample Type", "Sample Type Explanation",
268
- # "Sources", "Query_cost", "Time cost"
269
- # ])
270
-
271
- # # ✅ Google Sheets API setup
272
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
273
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
274
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
275
- # client = gspread.authorize(creds)
276
-
277
- # # Open the known_samples sheet
278
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
279
- # sheet = spreadsheet.sheet1
280
-
281
- # # ✅ Read old data
282
- # existing_data = sheet.get_all_values()
283
- # if existing_data:
284
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
285
- # else:
286
- # df_old = pd.DataFrame(columns=df_new.columns)
287
-
288
- # # ✅ Combine and remove duplicates
289
- # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
290
-
291
- # # Clear and write back
292
- # sheet.clear()
293
- # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
294
-
295
- # except Exception as e:
296
- # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
297
- try:
298
- # Prepare as DataFrame
299
- df_new = pd.DataFrame(save_rows, columns=[
300
- "Sample ID", "Predicted Country", "Country Explanation",
301
- "Predicted Sample Type", "Sample Type Explanation",
302
- "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
303
- ])
304
-
305
- # ✅ Setup Google Sheets
306
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
307
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
308
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
309
- client = gspread.authorize(creds)
310
- spreadsheet = client.open("known_samples")
311
- sheet = spreadsheet.sheet1
312
-
313
- # Read existing data
314
- existing_data = sheet.get_all_values()
315
-
316
- if existing_data:
317
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
318
-
319
- else:
320
-
321
- df_old = pd.DataFrame(columns=[
322
- "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
323
- "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
324
- "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
325
- ])
326
-
327
-
328
- # Index by Sample ID
329
- df_old.set_index("Sample ID", inplace=True)
330
- df_new.set_index("Sample ID", inplace=True)
331
-
332
- # ✅ Update only matching fields
333
- update_columns = [
334
- "Predicted Country", "Predicted Sample Type", "Country Explanation",
335
- "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
336
- ]
337
- for idx, row in df_new.iterrows():
338
- if idx not in df_old.index:
339
- df_old.loc[idx] = "" # new row, fill empty first
340
- for col in update_columns:
341
- if pd.notna(row[col]) and row[col] != "":
342
- df_old.at[idx, col] = row[col]
343
-
344
- # ✅ Reset and write back
345
- df_old.reset_index(inplace=True)
346
- sheet.clear()
347
- sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
348
- print("✅ Match results saved to known_samples.")
349
-
350
- except Exception as e:
351
- print(f" Failed to update known_samples: {e}")
352
-
353
-
354
- return rows#, summary, labelAncient_Modern, explain_label
355
-
356
- # save the batch input in excel file
357
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
358
- # with pd.ExcelWriter(filename) as writer:
359
- # # Save table
360
- # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
361
- # df.to_excel(writer, sheet_name="Detailed Results", index=False)
362
- # try:
363
- # df_old = pd.read_excel(filename)
364
- # except:
365
- # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
366
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
367
- # # if os.path.exists(filename):
368
- # # df_old = pd.read_excel(filename)
369
- # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
370
- # # else:
371
- # # df_combined = df_new
372
- # df_combined.to_excel(filename, index=False)
373
- # # # Save summary
374
- # # summary_df = pd.DataFrame({"Summary": [summary_text]})
375
- # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
376
-
377
- # # # Save flag
378
- # # flag_df = pd.DataFrame({"Flag": [flag_text]})
379
- # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
380
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
381
- # df_new = pd.DataFrame(all_rows, columns=[
382
- # "Sample ID", "Predicted Country", "Country Explanation",
383
- # "Predicted Sample Type", "Sample Type Explanation",
384
- # "Sources", "Time cost"
385
- # ])
386
-
387
- # try:
388
- # if os.path.exists(filename):
389
- # df_old = pd.read_excel(filename)
390
- # else:
391
- # df_old = pd.DataFrame(columns=df_new.columns)
392
- # except Exception as e:
393
- # print(f"⚠️ Warning reading old Excel file: {e}")
394
- # df_old = pd.DataFrame(columns=df_new.columns)
395
-
396
- # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
397
- # df_old.set_index("Sample ID", inplace=True)
398
- # df_new.set_index("Sample ID", inplace=True)
399
-
400
- # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
401
-
402
- # df_combined = df_old.reset_index()
403
-
404
- # try:
405
- # df_combined.to_excel(filename, index=False)
406
- # except Exception as e:
407
- # print(f"❌ Failed to write Excel file {filename}: {e}")
408
- def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
409
- df_new = pd.DataFrame(all_rows, columns=[
410
- "Sample ID", "Predicted Country", "Country Explanation",
411
- "Predicted Sample Type", "Sample Type Explanation",
412
- "Sources", "Time cost"
413
- ])
414
-
415
- if is_resume and os.path.exists(filename):
416
- try:
417
- df_old = pd.read_excel(filename)
418
- except Exception as e:
419
- print(f"⚠️ Warning reading old Excel file: {e}")
420
- df_old = pd.DataFrame(columns=df_new.columns)
421
-
422
- # Set index and update existing rows
423
- df_old.set_index("Sample ID", inplace=True)
424
- df_new.set_index("Sample ID", inplace=True)
425
- df_old.update(df_new)
426
-
427
- df_combined = df_old.reset_index()
428
- else:
429
- # If not resuming or file doesn't exist, just use new rows
430
- df_combined = df_new
431
-
432
- try:
433
- df_combined.to_excel(filename, index=False)
434
- except Exception as e:
435
- print(f"❌ Failed to write Excel file {filename}: {e}")
436
-
437
-
438
- # save the batch input in JSON file
439
- def save_to_json(all_rows, summary_text, flag_text, filename):
440
- output_dict = {
441
- "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
442
- # "Summary_Text": summary_text,
443
- # "Ancient_Modern_Flag": flag_text
444
- }
445
-
446
- # If all_rows is a DataFrame, convert it
447
- if isinstance(all_rows, pd.DataFrame):
448
- output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
449
-
450
- with open(filename, "w") as external_file:
451
- json.dump(output_dict, external_file, indent=2)
452
-
453
- # save the batch input in Text file
454
- def save_to_txt(all_rows, summary_text, flag_text, filename):
455
- if isinstance(all_rows, pd.DataFrame):
456
- detailed_results = all_rows.to_dict(orient="records")
457
- output = ""
458
- #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
459
- output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
460
- for r in detailed_results:
461
- output += ",".join([str(v) for v in r.values()]) + "\n\n"
462
- with open(filename, "w") as f:
463
- f.write("=== Detailed Results ===\n")
464
- f.write(output + "\n")
465
-
466
- # f.write("\n=== Summary ===\n")
467
- # f.write(summary_text + "\n")
468
-
469
- # f.write("\n=== Ancient/Modern Flag ===\n")
470
- # f.write(flag_text + "\n")
471
-
472
- def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
473
- tmp_dir = tempfile.mkdtemp()
474
-
475
- #html_table = all_rows.value # assuming this is stored somewhere
476
-
477
- # Parse back to DataFrame
478
- #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
479
- all_rows = pd.read_html(StringIO(all_rows))[0]
480
- print(all_rows)
481
-
482
- if output_type == "Excel":
483
- file_path = f"{tmp_dir}/batch_output.xlsx"
484
- save_to_excel(all_rows, summary_text, flag_text, file_path)
485
- elif output_type == "JSON":
486
- file_path = f"{tmp_dir}/batch_output.json"
487
- save_to_json(all_rows, summary_text, flag_text, file_path)
488
- print("Done with JSON")
489
- elif output_type == "TXT":
490
- file_path = f"{tmp_dir}/batch_output.txt"
491
- save_to_txt(all_rows, summary_text, flag_text, file_path)
492
- else:
493
- return gr.update(visible=False) # invalid option
494
-
495
- return gr.update(value=file_path, visible=True)
496
- # save cost by checking the known outputs
497
-
498
- # def check_known_output(accession):
499
- # if not os.path.exists(KNOWN_OUTPUT_PATH):
500
- # return None
501
-
502
- # try:
503
- # df = pd.read_excel(KNOWN_OUTPUT_PATH)
504
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
505
- # if match:
506
- # accession = match.group(0)
507
-
508
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
509
- # if not matched.empty:
510
- # return matched.iloc[0].to_dict() # Return the cached row
511
- # except Exception as e:
512
- # print(f"⚠️ Failed to load known samples: {e}")
513
- # return None
514
-
515
- # def check_known_output(accession):
516
- # try:
517
- # # ✅ Load credentials from Hugging Face secret
518
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
519
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
520
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
521
- # client = gspread.authorize(creds)
522
-
523
- # # Open the known_samples sheet
524
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
525
- # sheet = spreadsheet.sheet1
526
-
527
- # # Read all rows
528
- # data = sheet.get_all_values()
529
- # if not data:
530
- # return None
531
-
532
- # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
533
-
534
- # # ✅ Normalize accession pattern
535
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
536
- # if match:
537
- # accession = match.group(0)
538
-
539
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
540
- # if not matched.empty:
541
- # return matched.iloc[0].to_dict()
542
-
543
- # except Exception as e:
544
- # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
545
- # return None
546
- def check_known_output(accession):
547
- print("inside check known output function")
548
- try:
549
- # Load credentials from Hugging Face secret
550
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
551
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
552
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
553
- client = gspread.authorize(creds)
554
-
555
- spreadsheet = client.open("known_samples")
556
- sheet = spreadsheet.sheet1
557
-
558
- data = sheet.get_all_values()
559
- if not data:
560
- print("⚠️ Google Sheet 'known_samples' is empty.")
561
- return None
562
-
563
- df = pd.DataFrame(data[1:], columns=data[0])
564
- if "Sample ID" not in df.columns:
565
- print("❌ Column 'Sample ID' not found in Google Sheet.")
566
- return None
567
-
568
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
569
- if match:
570
- accession = match.group(0)
571
-
572
- matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
573
- if not matched.empty:
574
- #return matched.iloc[0].to_dict()
575
- row = matched.iloc[0]
576
- country = row.get("Predicted Country", "").strip().lower()
577
- sample_type = row.get("Predicted Sample Type", "").strip().lower()
578
-
579
- if country and country != "unknown" and sample_type and sample_type != "unknown":
580
- return row.to_dict()
581
- else:
582
- print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
583
- return None
584
- else:
585
- print(f"🔍 Accession {accession} not found in known_samples.")
586
- return None
587
-
588
- except Exception as e:
589
- import traceback
590
- print("❌ Exception occurred during check_known_output:")
591
- traceback.print_exc()
592
- return None
593
-
594
-
595
- def hash_user_id(user_input):
596
- return hashlib.sha256(user_input.encode()).hexdigest()
597
-
598
- # ✅ Load and save usage count
599
-
600
- # def load_user_usage():
601
- # if not os.path.exists(USER_USAGE_TRACK_FILE):
602
- # return {}
603
-
604
- # try:
605
- # with open(USER_USAGE_TRACK_FILE, "r") as f:
606
- # content = f.read().strip()
607
- # if not content:
608
- # return {} # file is empty
609
- # return json.loads(content)
610
- # except (json.JSONDecodeError, ValueError):
611
- # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
612
- # return {} # fallback to empty dict
613
- # def load_user_usage():
614
- # try:
615
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
616
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
617
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
618
- # client = gspread.authorize(creds)
619
-
620
- # sheet = client.open("user_usage_log").sheet1
621
- # data = sheet.get_all_records() # Assumes columns: email, usage_count
622
-
623
- # usage = {}
624
- # for row in data:
625
- # email = row.get("email", "").strip().lower()
626
- # count = int(row.get("usage_count", 0))
627
- # if email:
628
- # usage[email] = count
629
- # return usage
630
- # except Exception as e:
631
- # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
632
- # return {}
633
- # def load_user_usage():
634
- # try:
635
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
636
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
637
-
638
- # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
639
- # if not found:
640
- # return {} # not found, start fresh
641
-
642
- # #file_id = found[0]["id"]
643
- # file_id = found
644
- # content = pipeline.download_drive_file_content(file_id)
645
- # return json.loads(content.strip()) if content.strip() else {}
646
-
647
- # except Exception as e:
648
- # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
649
- # return {}
650
- def load_user_usage():
651
- try:
652
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
653
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
654
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
655
- client = gspread.authorize(creds)
656
-
657
- sheet = client.open("user_usage_log").sheet1
658
- data = sheet.get_all_values()
659
- print("data: ", data)
660
- print("🧪 Raw header row from sheet:", data[0])
661
- print("🧪 Character codes in each header:")
662
- for h in data[0]:
663
- print([ord(c) for c in h])
664
-
665
- if not data or len(data) < 2:
666
- print("⚠️ Sheet is empty or missing rows.")
667
- return {}
668
-
669
- headers = [h.strip().lower() for h in data[0]]
670
- if "email" not in headers or "usage_count" not in headers:
671
- print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
672
- return {}
673
-
674
- permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
675
- df = pd.DataFrame(data[1:], columns=headers)
676
-
677
- usage = {}
678
- permitted = {}
679
- for _, row in df.iterrows():
680
- email = row.get("email", "").strip().lower()
681
- try:
682
- #count = int(row.get("usage_count", 0))
683
- try:
684
- count = int(float(row.get("usage_count", 0)))
685
- except Exception:
686
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
687
- count = 0
688
-
689
- if email:
690
- usage[email] = count
691
- if permitted_index is not None:
692
- try:
693
- permitted_count = int(float(row.get("permitted_samples", 50)))
694
- permitted[email] = permitted_count
695
- except:
696
- permitted[email] = 50
697
-
698
- except ValueError:
699
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
700
- return usage, permitted
701
-
702
- except Exception as e:
703
- print(f" Error in load_user_usage: {e}")
704
- return {}, {}
705
-
706
-
707
-
708
- # def save_user_usage(usage):
709
- # with open(USER_USAGE_TRACK_FILE, "w") as f:
710
- # json.dump(usage, f, indent=2)
711
-
712
- # def save_user_usage(usage_dict):
713
- # try:
714
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
715
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
716
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
717
- # client = gspread.authorize(creds)
718
-
719
- # sheet = client.open("user_usage_log").sheet1
720
- # sheet.clear() # clear old contents first
721
-
722
- # # Write header + rows
723
- # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
724
- # sheet.update(rows)
725
- # except Exception as e:
726
- # print(f" Failed to save user usage to Google Sheets: {e}")
727
- # def save_user_usage(usage_dict):
728
- # try:
729
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
730
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
731
-
732
- # import tempfile
733
- # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
734
- # print("💾 Saving this usage dict:", usage_dict)
735
- # with open(tmp_path, "w") as f:
736
- # json.dump(usage_dict, f, indent=2)
737
-
738
- # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
739
-
740
- # except Exception as e:
741
- # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
742
- # def save_user_usage(usage_dict):
743
- # try:
744
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
745
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
746
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
747
- # client = gspread.authorize(creds)
748
-
749
- # spreadsheet = client.open("user_usage_log")
750
- # sheet = spreadsheet.sheet1
751
-
752
- # # Step 1: Convert new usage to DataFrame
753
- # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
754
- # df_new["email"] = df_new["email"].str.strip().str.lower()
755
-
756
- # # Step 2: Load existing data
757
- # existing_data = sheet.get_all_values()
758
- # print("🧪 Sheet existing_data:", existing_data)
759
-
760
- # # Try to load old data
761
- # if existing_data and len(existing_data[0]) >= 1:
762
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
763
-
764
- # # Fix missing columns
765
- # if "email" not in df_old.columns:
766
- # df_old["email"] = ""
767
- # if "usage_count" not in df_old.columns:
768
- # df_old["usage_count"] = 0
769
-
770
- # df_old["email"] = df_old["email"].str.strip().str.lower()
771
- # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
772
- # else:
773
- # df_old = pd.DataFrame(columns=["email", "usage_count"])
774
-
775
- # # Step 3: Merge
776
- # df_combined = pd.concat([df_old, df_new], ignore_index=True)
777
- # df_combined = df_combined.groupby("email", as_index=False).sum()
778
-
779
- # # Step 4: Write back
780
- # sheet.clear()
781
- # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
782
- # print(" Saved user usage to user_usage_log sheet.")
783
-
784
- # except Exception as e:
785
- # print(f" Failed to save user usage to Google Sheets: {e}")
786
- def save_user_usage(usage_dict):
787
- try:
788
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
789
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
790
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
791
- client = gspread.authorize(creds)
792
-
793
- spreadsheet = client.open("user_usage_log")
794
- sheet = spreadsheet.sheet1
795
-
796
- # Build new df
797
- df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
798
- df_new["email"] = df_new["email"].str.strip().str.lower()
799
- df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
800
-
801
- # Read existing data
802
- existing_data = sheet.get_all_values()
803
- if existing_data and len(existing_data[0]) >= 2:
804
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
805
- df_old["email"] = df_old["email"].str.strip().str.lower()
806
- df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
807
- else:
808
- df_old = pd.DataFrame(columns=["email", "usage_count"])
809
-
810
- # Overwrite specific emails only
811
- df_old = df_old.set_index("email")
812
- for email, count in usage_dict.items():
813
- email = email.strip().lower()
814
- df_old.loc[email, "usage_count"] = count
815
- df_old = df_old.reset_index()
816
-
817
- # Save
818
- sheet.clear()
819
- sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
820
- print("✅ Saved user usage to user_usage_log sheet.")
821
-
822
- except Exception as e:
823
- print(f" Failed to save user usage to Google Sheets: {e}")
824
-
825
-
826
-
827
-
828
- # def increment_usage(user_id, num_samples=1):
829
- # usage = load_user_usage()
830
- # if user_id not in usage:
831
- # usage[user_id] = 0
832
- # usage[user_id] += num_samples
833
- # save_user_usage(usage)
834
- # return usage[user_id]
835
- # def increment_usage(email: str, count: int):
836
- # usage = load_user_usage()
837
- # email_key = email.strip().lower()
838
- # usage[email_key] = usage.get(email_key, 0) + count
839
- # save_user_usage(usage)
840
- # return usage[email_key]
841
- def increment_usage(email: str, count: int = 1):
842
- usage, permitted = load_user_usage()
843
- email_key = email.strip().lower()
844
- #usage[email_key] = usage.get(email_key, 0) + count
845
- current = usage.get(email_key, 0)
846
- new_value = current + count
847
- max_allowed = permitted.get(email_key) or 50
848
- usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
849
- print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
850
- print("max allow is: ", max_allowed)
851
- save_user_usage(usage)
852
- return usage[email_key], max_allowed
853
-
854
-
855
- # run the batch
856
- def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
857
- stop_flag=None, output_file_path=None,
858
- limited_acc=50, yield_callback=None):
859
- if user_email:
860
- limited_acc += 10
861
- accessions, error = extract_accessions_from_input(file, raw_text)
862
- if error:
863
- #return [], "", "", f"Error: {error}"
864
- return [], f"Error: {error}", 0, "", ""
865
- if resume_file:
866
- accessions = get_incomplete_accessions(resume_file)
867
- tmp_dir = tempfile.mkdtemp()
868
- if not output_file_path:
869
- if resume_file:
870
- output_file_path = os.path.join(tmp_dir, resume_file)
871
- else:
872
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
873
-
874
- all_rows = []
875
- # all_summaries = []
876
- # all_flags = []
877
- progress_lines = []
878
- warning = ""
879
- if len(accessions) > limited_acc:
880
- accessions = accessions[:limited_acc]
881
- warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
882
- for i, acc in enumerate(accessions):
883
- if stop_flag and stop_flag.value:
884
- line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
885
- progress_lines.append(line)
886
- if yield_callback:
887
- yield_callback(line)
888
- print("🛑 User requested stop.")
889
- break
890
- print(f"[{i+1}/{len(accessions)}] Processing {acc}")
891
- try:
892
- # rows, summary, label, explain = summarize_results(acc)
893
- rows = summarize_results(acc)
894
- all_rows.extend(rows)
895
- # all_summaries.append(f"**{acc}**\n{summary}")
896
- # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
897
- #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
898
- save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
899
- line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
900
- progress_lines.append(line)
901
- if yield_callback:
902
- yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
903
- except Exception as e:
904
- print(f"❌ Failed to process {acc}: {e}")
905
- continue
906
- #all_summaries.append(f"**{acc}**: Failed - {e}")
907
- #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
908
- limited_acc -= 1
909
- """for row in all_rows:
910
- source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
911
-
912
- if source_column.startswith("http"): # Check if the source is a URL
913
- # Wrap it with HTML anchor tags to make it clickable
914
- row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
915
- if not warning:
916
- warning = f"You only have {limited_acc} left"
917
- if user_email.strip():
918
- user_hash = hash_user_id(user_email)
919
- total_queries = increment_usage(user_hash, len(all_rows))
920
- else:
921
- total_queries = 0
922
- yield_callback("✅ Finished!")
923
-
924
- # summary_text = "\n\n---\n\n".join(all_summaries)
925
- # flag_text = "\n\n---\n\n".join(all_flags)
926
- #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
927
- #return all_rows, gr.update(visible=True), gr.update(visible=False)
 
928
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
 
1
+ import gradio as gr
2
+ from collections import Counter
3
+ import csv
4
+ import os
5
+ from functools import lru_cache
6
+ #import app
7
+ from mtdna_classifier import classify_sample_location
8
+ import data_preprocess, model, pipeline
9
+ import subprocess
10
+ import json
11
+ import pandas as pd
12
+ import io
13
+ import re
14
+ import tempfile
15
+ import gspread
16
+ from oauth2client.service_account import ServiceAccountCredentials
17
+ from io import StringIO
18
+ import hashlib
19
+ import threading
20
+
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ #@lru_cache(maxsize=3600)
26
+ def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
27
+ print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
+ print("len of save df: ", len(save_df))
29
+ return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
30
+
31
+ # Count and suggest final location
32
+ # def compute_final_suggested_location(rows):
33
+ # candidates = [
34
+ # row.get("Predicted Location", "").strip()
35
+ # for row in rows
36
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
37
+ # ] + [
38
+ # row.get("Inferred Region", "").strip()
39
+ # for row in rows
40
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
41
+ # ]
42
+
43
+ # if not candidates:
44
+ # return Counter(), ("Unknown", 0)
45
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
46
+ # tokens = []
47
+ # for item in candidates:
48
+ # # Split by comma, whitespace, and newlines
49
+ # parts = re.split(r'[\s,]+', item)
50
+ # tokens.extend(parts)
51
+
52
+ # # Step 2: Clean and normalize tokens
53
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
54
+
55
+ # # Step 3: Count
56
+ # counts = Counter(tokens)
57
+
58
+ # # Step 4: Get most common
59
+ # top_location, count = counts.most_common(1)[0]
60
+ # return counts, (top_location, count)
61
+
62
+ # Store feedback (with required fields)
63
+
64
+ def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
65
+ if not answer1.strip() or not answer2.strip():
66
+ return "⚠️ Please answer both questions before submitting."
67
+
68
+ try:
69
+ # ✅ Step: Load credentials from Hugging Face secret
70
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
71
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
72
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
73
+
74
+ # Connect to Google Sheet
75
+ client = gspread.authorize(creds)
76
+ sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
77
+
78
+ # Append feedback
79
+ sheet.append_row([accession, answer1, answer2, contact])
80
+ return "✅ Feedback submitted. Thank you!"
81
+
82
+ except Exception as e:
83
+ return f"❌ Error submitting feedback: {e}"
84
+
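# Minimal call sketch for the feedback hook above; assumes GCP_CREDS_JSON is set and a sheet
# named "feedback_mtdna" exists. The accession, answers and contact below are placeholders.
msg = store_feedback_to_google_sheets(
    accession="KU131308",
    answer1="Yes, the predicted country matched the paper.",
    answer2="The sample type looked correct as well.",
    contact="researcher@example.com",
)
print(msg)  # "✅ Feedback submitted. Thank you!" on success, otherwise the error text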
85
+ import re
86
+
87
+ ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')
88
+
89
+ def is_valid_accession(acc):
90
+ return bool(ACCESSION_REGEX.match(acc))
91
+
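# Illustrative check (IDs made up, not from the sheet): ACCESSION_REGEX accepts one to four
# uppercase letters, an optional underscore, exactly six digits, and an optional ".N" version.
for acc in ["KU131308", "NC_012920", "MH120567.1", "AB12345", "ku131308"]:
    print(acc, is_valid_accession(acc))
# KU131308 True, NC_012920 True, MH120567.1 True,
# AB12345 False (only five digits), ku131308 False (lowercase is rejected)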
92
+ # helper function to extract accessions
93
+ def extract_accessions_from_input(file=None, raw_text=""):
94
+ print(f"RAW TEXT RECEIVED: {raw_text}")
95
+ accessions, invalid_accessions = [], []
96
+ seen = set()
97
+ if file:
98
+ try:
99
+ if file.name.endswith(".csv"):
100
+ df = pd.read_csv(file)
101
+ elif file.name.endswith(".xlsx"):
102
+ df = pd.read_excel(file)
103
+ else:
104
+ return [], "Unsupported file format. Please upload CSV or Excel."
105
+ for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
106
+ if acc not in seen:
107
+ if is_valid_accession(acc):
108
+ accessions.append(acc)
109
+ seen.add(acc)
110
+ else:
111
+ invalid_accessions.append(acc)
112
+
113
+ except Exception as e:
114
+ return [],[], f"Failed to read file: {e}"
115
+
116
+ if raw_text:
117
+ try:
118
+ text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
119
+ for acc in text_ids:
120
+ if acc not in seen:
121
+ if is_valid_accession(acc):
122
+ accessions.append(acc)
123
+ seen.add(acc)
124
+ else:
125
+ invalid_accessions.append(acc)
126
+ except Exception as e:
127
+ return [],[], f"Failed to read file: {e}"
128
+
129
+ return list(accessions), list(invalid_accessions), None
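# Usage sketch of the three-value return (valid, invalid, error); the accessions are made up.
accs, bad, err = extract_accessions_from_input(raw_text="KU131308, AB12345; NC_012920")
if err:
    print("input error:", err)
else:
    print("valid:", accs)   # ['KU131308', 'NC_012920']
    print("invalid:", bad)  # ['AB12345']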
130
+ # ✅ Helper for resuming a batch: collect accessions whose output rows are still unfilled
131
+ def get_incomplete_accessions(file_path):
132
+ df = pd.read_excel(file_path)
133
+
134
+ incomplete_accessions = []
135
+ for _, row in df.iterrows():
136
+ sample_id = str(row.get("Sample ID", "")).strip()
137
+
138
+ # Skip if no sample ID
139
+ if not sample_id:
140
+ continue
141
+
142
+ # Drop the Sample ID and check if the rest is empty
143
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
144
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
145
+ # Extract the accession number from the sample ID using regex
146
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
147
+ if match:
148
+ incomplete_accessions.append(match.group(0))
149
+ print(len(incomplete_accessions))
150
+ return incomplete_accessions
151
+
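# Hedged sketch of the "incomplete" rule with a throwaway two-row spreadsheet
# (path and IDs are hypothetical; openpyxl is assumed for Excel I/O):
# only the row whose non-ID columns are all empty is returned for resuming.
import os, tempfile
import pandas as pd
tmp_xlsx = os.path.join(tempfile.mkdtemp(), "resume_example.xlsx")
pd.DataFrame({
    "Sample ID": ["KU131308(Isolate: BRU18)", "MH120567"],
    "Predicted Country": ["", "thailand"],
    "Time cost": ["", "12.1 seconds"],
}).to_excel(tmp_xlsx, index=False)
print(get_incomplete_accessions(tmp_xlsx))  # ['KU131308']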
152
+ # GOOGLE_SHEET_NAME = "known_samples"
153
+ # USAGE_DRIVE_FILENAME = "user_usage_log.json"
154
+
155
+ def summarize_results(accession, stop_flag=None):
156
+ # Early bail
157
+ if stop_flag is not None and stop_flag.value:
158
+ print(f"🛑 Skipping {accession} before starting.")
159
+ return []
160
+ # try cache first
161
+ cached = check_known_output(accession)
162
+ if cached:
163
+ print(f"✅ Using cached result for {accession}")
164
+ return [[
165
+ cached["Sample ID"] or "unknown",
166
+ cached["Predicted Country"] or "unknown",
167
+ cached["Country Explanation"] or "unknown",
168
+ cached["Predicted Sample Type"] or "unknown",
169
+ cached["Sample Type Explanation"] or "unknown",
170
+ cached["Sources"] or "No Links",
171
+ cached["Time cost"]
172
+ ]]
173
+ # only run when nothing in the cache
174
+ try:
175
+ print("try gemini pipeline: ",accession)
176
+ # Load credentials from Hugging Face secret
177
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
178
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
179
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
180
+ client = gspread.authorize(creds)
181
+
182
+ spreadsheet = client.open("known_samples")
183
+ sheet = spreadsheet.sheet1
184
+
185
+ data = sheet.get_all_values()
186
+ if not data:
187
+ print("⚠️ Google Sheet 'known_samples' is empty.")
188
+ return []
189
+
190
+ save_df = pd.DataFrame(data[1:], columns=data[0])
191
+ print("before pipeline, len of save df: ", len(save_df))
192
+ outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
193
+ if stop_flag is not None and stop_flag.value:
194
+ print(f"🛑 Skipped {accession} mid-pipeline.")
195
+ return []
196
+ # outputs = {'KU131308': {'isolate':'BRU18',
197
+ # 'country': {'brunei': ['ncbi',
198
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
199
+ # 'sample_type': {'modern':
200
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
201
+ # 'query_cost': 9.754999999999999e-05,
202
+ # 'time_cost': '24.776 seconds',
203
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
204
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
205
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
206
+ except Exception as e:
207
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
208
+
209
+ if accession not in outputs:
210
+ print("no accession in output ", accession)
211
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
212
+
213
+ row_score = []
214
+ rows = []
215
+ save_rows = []
216
+ for key in outputs:
217
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
218
+ for section, results in outputs[key].items():
219
+ if section == "country" or section =="sample_type":
220
+ pred_output = []#"\n".join(list(results.keys()))
221
+ output_explanation = ""
222
+ for result, content in results.items():
223
+ if len(result) == 0: result = "unknown"
224
+ if len(content) == 0: output_explanation = "unknown"
225
+ else:
226
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
227
+ pred_output.append(result)
228
+ pred_output = "\n".join(pred_output)
229
+ if section == "country":
230
+ pred_country, country_explanation = pred_output, output_explanation
231
+ elif section == "sample_type":
232
+ pred_sample, sample_explanation = pred_output, output_explanation
233
+ if outputs[key]["isolate"].lower()!="unknown":
234
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
235
+ else: label = key
236
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
237
+ row = {
238
+ "Sample ID": label or "unknown",
239
+ "Predicted Country": pred_country or "unknown",
240
+ "Country Explanation": country_explanation or "unknown",
241
+ "Predicted Sample Type":pred_sample or "unknown",
242
+ "Sample Type Explanation":sample_explanation or "unknown",
243
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
244
+ "Time cost": outputs[key]["time_cost"]
245
+ }
246
+ #row_score.append(row)
247
+ rows.append(list(row.values()))
248
+
249
+ save_row = {
250
+ "Sample ID": label or "unknown",
251
+ "Predicted Country": pred_country or "unknown",
252
+ "Country Explanation": country_explanation or "unknown",
253
+ "Predicted Sample Type":pred_sample or "unknown",
254
+ "Sample Type Explanation":sample_explanation or "unknown",
255
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
256
+ "Query_cost": outputs[key]["query_cost"] or "",
257
+ "Time cost": outputs[key]["time_cost"] or "",
258
+ "file_chunk":outputs[key]["file_chunk"] or "",
259
+ "file_all_output":outputs[key]["file_all_output"] or ""
260
+ }
261
+ #row_score.append(row)
262
+ save_rows.append(list(save_row.values()))
263
+
264
+ # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
265
+ # summary_lines = [f"### 🧭 Location Summary:\n"]
266
+ # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
267
+ # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
268
+ # summary = "\n".join(summary_lines)
269
+
270
+ # save the new running sample to known excel file
271
+ # try:
272
+ # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
273
+ # if os.path.exists(KNOWN_OUTPUT_PATH):
274
+ # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
275
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
276
+ # else:
277
+ # df_combined = df_new
278
+ # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
279
+ # except Exception as e:
280
+ # print(f"⚠️ Failed to save known output: {e}")
281
+ # try:
282
+ # df_new = pd.DataFrame(save_rows, columns=[
283
+ # "Sample ID", "Predicted Country", "Country Explanation",
284
+ # "Predicted Sample Type", "Sample Type Explanation",
285
+ # "Sources", "Query_cost", "Time cost"
286
+ # ])
287
+
288
+ # # ✅ Google Sheets API setup
289
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
290
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
291
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
292
+ # client = gspread.authorize(creds)
293
+
294
+ # # ✅ Open the known_samples sheet
295
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
296
+ # sheet = spreadsheet.sheet1
297
+
298
+ # # Read old data
299
+ # existing_data = sheet.get_all_values()
300
+ # if existing_data:
301
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
302
+ # else:
303
+ # df_old = pd.DataFrame(columns=df_new.columns)
304
+
305
+ # #Combine and remove duplicates
306
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
307
+
308
+ # # Clear and write back
309
+ # sheet.clear()
310
+ # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
311
+
312
+ # except Exception as e:
313
+ # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
314
+ try:
315
+ # Prepare as DataFrame
316
+ df_new = pd.DataFrame(save_rows, columns=[
317
+ "Sample ID", "Predicted Country", "Country Explanation",
318
+ "Predicted Sample Type", "Sample Type Explanation",
319
+ "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
320
+ ])
321
+
322
+ # Setup Google Sheets
323
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
324
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
325
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
326
+ client = gspread.authorize(creds)
327
+ spreadsheet = client.open("known_samples")
328
+ sheet = spreadsheet.sheet1
329
+
330
+ # Read existing data
331
+ existing_data = sheet.get_all_values()
332
+
333
+ if existing_data:
334
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
335
+
336
+ else:
337
+
338
+ df_old = pd.DataFrame(columns=[
339
+ "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
340
+ "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
341
+ "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
342
+ ])
343
+
344
+
345
+ # ✅ Index by Sample ID
346
+ df_old.set_index("Sample ID", inplace=True)
347
+ df_new.set_index("Sample ID", inplace=True)
348
+
349
+ # ✅ Update only matching fields
350
+ update_columns = [
351
+ "Predicted Country", "Predicted Sample Type", "Country Explanation",
352
+ "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
353
+ ]
354
+ for idx, row in df_new.iterrows():
355
+ if idx not in df_old.index:
356
+ df_old.loc[idx] = "" # new row, fill empty first
357
+ for col in update_columns:
358
+ if pd.notna(row[col]) and row[col] != "":
359
+ df_old.at[idx, col] = row[col]
360
+
361
+ # Reset and write back
362
+ df_old.reset_index(inplace=True)
363
+ sheet.clear()
364
+ sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
365
+ print(" Match results saved to known_samples.")
366
+
367
+ except Exception as e:
368
+ print(f"❌ Failed to update known_samples: {e}")
369
+
370
+
371
+ return rows#, summary, labelAncient_Modern, explain_label
372
+
373
+ # save the batch input in excel file
374
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
375
+ # with pd.ExcelWriter(filename) as writer:
376
+ # # Save table
377
+ # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
378
+ # df.to_excel(writer, sheet_name="Detailed Results", index=False)
379
+ # try:
380
+ # df_old = pd.read_excel(filename)
381
+ # except:
382
+ # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
383
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
384
+ # # if os.path.exists(filename):
385
+ # # df_old = pd.read_excel(filename)
386
+ # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
387
+ # # else:
388
+ # # df_combined = df_new
389
+ # df_combined.to_excel(filename, index=False)
390
+ # # # Save summary
391
+ # # summary_df = pd.DataFrame({"Summary": [summary_text]})
392
+ # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
393
+
394
+ # # # Save flag
395
+ # # flag_df = pd.DataFrame({"Flag": [flag_text]})
396
+ # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
397
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
398
+ # df_new = pd.DataFrame(all_rows, columns=[
399
+ # "Sample ID", "Predicted Country", "Country Explanation",
400
+ # "Predicted Sample Type", "Sample Type Explanation",
401
+ # "Sources", "Time cost"
402
+ # ])
403
+
404
+ # try:
405
+ # if os.path.exists(filename):
406
+ # df_old = pd.read_excel(filename)
407
+ # else:
408
+ # df_old = pd.DataFrame(columns=df_new.columns)
409
+ # except Exception as e:
410
+ # print(f"⚠️ Warning reading old Excel file: {e}")
411
+ # df_old = pd.DataFrame(columns=df_new.columns)
412
+
413
+ # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
414
+ # df_old.set_index("Sample ID", inplace=True)
415
+ # df_new.set_index("Sample ID", inplace=True)
416
+
417
+ # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
418
+
419
+ # df_combined = df_old.reset_index()
420
+
421
+ # try:
422
+ # df_combined.to_excel(filename, index=False)
423
+ # except Exception as e:
424
+ # print(f" Failed to write Excel file {filename}: {e}")
425
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
426
+ df_new = pd.DataFrame(all_rows, columns=[
427
+ "Sample ID", "Predicted Country", "Country Explanation",
428
+ "Predicted Sample Type", "Sample Type Explanation",
429
+ "Sources", "Time cost"
430
+ ])
431
+
432
+ if is_resume and os.path.exists(filename):
433
+ try:
434
+ df_old = pd.read_excel(filename)
435
+ except Exception as e:
436
+ print(f"⚠️ Warning reading old Excel file: {e}")
437
+ df_old = pd.DataFrame(columns=df_new.columns)
438
+
439
+ # Set index and update existing rows
440
+ df_old.set_index("Sample ID", inplace=True)
441
+ df_new.set_index("Sample ID", inplace=True)
442
+ df_old.update(df_new)
443
+
444
+ df_combined = df_old.reset_index()
445
+ else:
446
+ # If not resuming or file doesn't exist, just use new rows
447
+ df_combined = df_new
448
+
449
+ try:
450
+ df_combined.to_excel(filename, index=False)
451
+ except Exception as e:
452
+ print(f"❌ Failed to write Excel file {filename}: {e}")
453
+
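# Sketch of the two write modes (filenames hypothetical; row values echo the commented example
# output above): without is_resume the file is overwritten with the new rows; with is_resume=True
# an existing file is updated in place by Sample ID, and IDs not already present are not appended.
rows = [[
    "KU131308(Isolate: BRU18)", "brunei", "Method: ncbi", "modern",
    "Method: rag_llm", "https://doi.org/10.1007/s00439-015-1620-z", "24.776 seconds",
]]
save_to_excel(rows, "", "", "fresh_run.xlsx")                          # plain write
save_to_excel(rows, "", "", "batch_output_live.xlsx", is_resume=True)  # fill blanks by Sample ID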
454
+
455
+ # save the batch input in JSON file
456
+ def save_to_json(all_rows, summary_text, flag_text, filename):
457
+ output_dict = {
458
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
459
+ # "Summary_Text": summary_text,
460
+ # "Ancient_Modern_Flag": flag_text
461
+ }
462
+
463
+ # If all_rows is a DataFrame, convert it
464
+ if isinstance(all_rows, pd.DataFrame):
465
+ output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
466
+
467
+ with open(filename, "w") as external_file:
468
+ json.dump(output_dict, external_file, indent=2)
469
+
470
+ # save the batch input in Text file
471
+ def save_to_txt(all_rows, summary_text, flag_text, filename):
472
+ if isinstance(all_rows, pd.DataFrame):
473
+ detailed_results = all_rows.to_dict(orient="records")
474
+ output = ""
475
+ #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
476
+ output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
477
+ for r in detailed_results:
478
+ output += ",".join([str(v) for v in r.values()]) + "\n\n"
479
+ with open(filename, "w") as f:
480
+ f.write("=== Detailed Results ===\n")
481
+ f.write(output + "\n")
482
+
483
+ # f.write("\n=== Summary ===\n")
484
+ # f.write(summary_text + "\n")
485
+
486
+ # f.write("\n=== Ancient/Modern Flag ===\n")
487
+ # f.write(flag_text + "\n")
488
+
489
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
490
+ tmp_dir = tempfile.mkdtemp()
491
+
492
+ #html_table = all_rows.value # assuming this is stored somewhere
493
+
494
+ # Parse back to DataFrame
495
+ #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
496
+ all_rows = pd.read_html(StringIO(all_rows))[0]
497
+ print(all_rows)
498
+
499
+ if output_type == "Excel":
500
+ file_path = f"{tmp_dir}/batch_output.xlsx"
501
+ save_to_excel(all_rows, summary_text, flag_text, file_path)
502
+ elif output_type == "JSON":
503
+ file_path = f"{tmp_dir}/batch_output.json"
504
+ save_to_json(all_rows, summary_text, flag_text, file_path)
505
+ print("Done with JSON")
506
+ elif output_type == "TXT":
507
+ file_path = f"{tmp_dir}/batch_output.txt"
508
+ save_to_txt(all_rows, summary_text, flag_text, file_path)
509
+ else:
510
+ return gr.update(visible=False) # invalid option
511
+
512
+ return gr.update(value=file_path, visible=True)
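# save_batch_output takes the rendered HTML table (as a string) and parses it back into a
# DataFrame with pd.read_html; a minimal illustrative call (lxml or html5lib is assumed).
html_table = ("<table><tr><th>Sample ID</th><th>Predicted Country</th></tr>"
              "<tr><td>KU131308</td><td>brunei</td></tr></table>")
download = save_batch_output(html_table, "JSON")  # gr.update pointing at the temp output file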
513
+ # save cost by checking the known outputs
514
+
515
+ # def check_known_output(accession):
516
+ # if not os.path.exists(KNOWN_OUTPUT_PATH):
517
+ # return None
518
+
519
+ # try:
520
+ # df = pd.read_excel(KNOWN_OUTPUT_PATH)
521
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
522
+ # if match:
523
+ # accession = match.group(0)
524
+
525
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
526
+ # if not matched.empty:
527
+ # return matched.iloc[0].to_dict() # Return the cached row
528
+ # except Exception as e:
529
+ # print(f"⚠️ Failed to load known samples: {e}")
530
+ # return None
531
+
532
+ # def check_known_output(accession):
533
+ # try:
534
+ # # ✅ Load credentials from Hugging Face secret
535
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
536
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
537
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
538
+ # client = gspread.authorize(creds)
539
+
540
+ # # Open the known_samples sheet
541
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
542
+ # sheet = spreadsheet.sheet1
543
+
544
+ # # Read all rows
545
+ # data = sheet.get_all_values()
546
+ # if not data:
547
+ # return None
548
+
549
+ # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
550
+
551
+ # # Normalize accession pattern
552
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
553
+ # if match:
554
+ # accession = match.group(0)
555
+
556
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
557
+ # if not matched.empty:
558
+ # return matched.iloc[0].to_dict()
559
+
560
+ # except Exception as e:
561
+ # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
562
+ # return None
563
+ def check_known_output(accession):
564
+ print("inside check known output function")
565
+ try:
566
+ # ✅ Load credentials from Hugging Face secret
567
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
568
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
569
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
570
+ client = gspread.authorize(creds)
571
+
572
+ spreadsheet = client.open("known_samples")
573
+ sheet = spreadsheet.sheet1
574
+
575
+ data = sheet.get_all_values()
576
+ if not data:
577
+ print("⚠️ Google Sheet 'known_samples' is empty.")
578
+ return None
579
+
580
+ df = pd.DataFrame(data[1:], columns=data[0])
581
+ if "Sample ID" not in df.columns:
582
+ print(" Column 'Sample ID' not found in Google Sheet.")
583
+ return None
584
+
585
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
586
+ if match:
587
+ accession = match.group(0)
588
+
589
+ matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
590
+ if not matched.empty:
591
+ #return matched.iloc[0].to_dict()
592
+ row = matched.iloc[0]
593
+ country = row.get("Predicted Country", "").strip().lower()
594
+ sample_type = row.get("Predicted Sample Type", "").strip().lower()
595
+
596
+ if country and country != "unknown" and sample_type and sample_type != "unknown":
597
+ return row.to_dict()
598
+ else:
599
+ print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
600
+ return None
601
+ else:
602
+ print(f"🔍 Accession {accession} not found in known_samples.")
603
+ return None
604
+
605
+ except Exception as e:
606
+ import traceback
607
+ print("❌ Exception occurred during check_known_output:")
608
+ traceback.print_exc()
609
+ return None
610
+
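# The cache-hit rule above reduces to this predicate; is_usable_cache is a hypothetical helper
# and the dicts are illustrative rows, not real sheet data. A row is reused only when both
# predictions are present and not "unknown".
def is_usable_cache(row: dict) -> bool:
    country = (row.get("Predicted Country") or "").strip().lower()
    sample_type = (row.get("Predicted Sample Type") or "").strip().lower()
    return bool(country) and country != "unknown" and bool(sample_type) and sample_type != "unknown"
print(is_usable_cache({"Predicted Country": "Brunei", "Predicted Sample Type": "modern"}))  # True
print(is_usable_cache({"Predicted Country": "unknown", "Predicted Sample Type": ""}))       # False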
611
+
612
+ def hash_user_id(user_input):
613
+ return hashlib.sha256(user_input.encode()).hexdigest()
614
+
615
+ # Load and save usage count
616
+
617
+ # def load_user_usage():
618
+ # if not os.path.exists(USER_USAGE_TRACK_FILE):
619
+ # return {}
620
+
621
+ # try:
622
+ # with open(USER_USAGE_TRACK_FILE, "r") as f:
623
+ # content = f.read().strip()
624
+ # if not content:
625
+ # return {} # file is empty
626
+ # return json.loads(content)
627
+ # except (json.JSONDecodeError, ValueError):
628
+ # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
629
+ # return {} # fallback to empty dict
630
+ # def load_user_usage():
631
+ # try:
632
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
633
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
634
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
635
+ # client = gspread.authorize(creds)
636
+
637
+ # sheet = client.open("user_usage_log").sheet1
638
+ # data = sheet.get_all_records() # Assumes columns: email, usage_count
639
+
640
+ # usage = {}
641
+ # for row in data:
642
+ # email = row.get("email", "").strip().lower()
643
+ # count = int(row.get("usage_count", 0))
644
+ # if email:
645
+ # usage[email] = count
646
+ # return usage
647
+ # except Exception as e:
648
+ # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
649
+ # return {}
650
+ # def load_user_usage():
651
+ # try:
652
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
653
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
654
+
655
+ # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
656
+ # if not found:
657
+ # return {} # not found, start fresh
658
+
659
+ # #file_id = found[0]["id"]
660
+ # file_id = found
661
+ # content = pipeline.download_drive_file_content(file_id)
662
+ # return json.loads(content.strip()) if content.strip() else {}
663
+
664
+ # except Exception as e:
665
+ # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
666
+ # return {}
667
+ def load_user_usage():
668
+ try:
669
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
670
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
671
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
672
+ client = gspread.authorize(creds)
673
+
674
+ sheet = client.open("user_usage_log").sheet1
675
+ data = sheet.get_all_values()
676
+ print("data: ", data)
677
+ print("🧪 Raw header row from sheet:", data[0])
678
+ print("🧪 Character codes in each header:")
679
+ for h in data[0]:
680
+ print([ord(c) for c in h])
681
+
682
+ if not data or len(data) < 2:
683
+ print("⚠️ Sheet is empty or missing rows.")
684
+ return {}, {}
685
+
686
+ headers = [h.strip().lower() for h in data[0]]
687
+ if "email" not in headers or "usage_count" not in headers:
688
+ print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
689
+ return {}, {}
690
+
691
+ permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
692
+ df = pd.DataFrame(data[1:], columns=headers)
693
+
694
+ usage = {}
695
+ permitted = {}
696
+ for _, row in df.iterrows():
697
+ email = row.get("email", "").strip().lower()
698
+ try:
699
+ #count = int(row.get("usage_count", 0))
700
+ try:
701
+ count = int(float(row.get("usage_count", 0)))
702
+ except Exception:
703
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
704
+ count = 0
705
+
706
+ if email:
707
+ usage[email] = count
708
+ if permitted_index is not None:
709
+ try:
710
+ permitted_count = int(float(row.get("permitted_samples", 50)))
711
+ permitted[email] = permitted_count
712
+ except:
713
+ permitted[email] = 50
714
+
715
+ except ValueError:
716
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
717
+ return usage, permitted
718
+
719
+ except Exception as e:
720
+ print(f"❌ Error in load_user_usage: {e}")
721
+ return {}, {}
722
+
723
+
724
+
725
+ # def save_user_usage(usage):
726
+ # with open(USER_USAGE_TRACK_FILE, "w") as f:
727
+ # json.dump(usage, f, indent=2)
728
+
729
+ # def save_user_usage(usage_dict):
730
+ # try:
731
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
732
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
733
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
734
+ # client = gspread.authorize(creds)
735
+
736
+ # sheet = client.open("user_usage_log").sheet1
737
+ # sheet.clear() # clear old contents first
738
+
739
+ # # Write header + rows
740
+ # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
741
+ # sheet.update(rows)
742
+ # except Exception as e:
743
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
744
+ # def save_user_usage(usage_dict):
745
+ # try:
746
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
747
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
748
+
749
+ # import tempfile
750
+ # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
751
+ # print("💾 Saving this usage dict:", usage_dict)
752
+ # with open(tmp_path, "w") as f:
753
+ # json.dump(usage_dict, f, indent=2)
754
+
755
+ # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
756
+
757
+ # except Exception as e:
758
+ # print(f" Failed to save user_usage_log.json to Google Drive: {e}")
759
+ # def save_user_usage(usage_dict):
760
+ # try:
761
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
762
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
763
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
764
+ # client = gspread.authorize(creds)
765
+
766
+ # spreadsheet = client.open("user_usage_log")
767
+ # sheet = spreadsheet.sheet1
768
+
769
+ # # Step 1: Convert new usage to DataFrame
770
+ # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
771
+ # df_new["email"] = df_new["email"].str.strip().str.lower()
772
+
773
+ # # Step 2: Load existing data
774
+ # existing_data = sheet.get_all_values()
775
+ # print("🧪 Sheet existing_data:", existing_data)
776
+
777
+ # # Try to load old data
778
+ # if existing_data and len(existing_data[0]) >= 1:
779
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
780
+
781
+ # # Fix missing columns
782
+ # if "email" not in df_old.columns:
783
+ # df_old["email"] = ""
784
+ # if "usage_count" not in df_old.columns:
785
+ # df_old["usage_count"] = 0
786
+
787
+ # df_old["email"] = df_old["email"].str.strip().str.lower()
788
+ # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
789
+ # else:
790
+ # df_old = pd.DataFrame(columns=["email", "usage_count"])
791
+
792
+ # # Step 3: Merge
793
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True)
794
+ # df_combined = df_combined.groupby("email", as_index=False).sum()
795
+
796
+ # # Step 4: Write back
797
+ # sheet.clear()
798
+ # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
799
+ # print(" Saved user usage to user_usage_log sheet.")
800
+
801
+ # except Exception as e:
802
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
803
+ def save_user_usage(usage_dict):
804
+ try:
805
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
806
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
807
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
808
+ client = gspread.authorize(creds)
809
+
810
+ spreadsheet = client.open("user_usage_log")
811
+ sheet = spreadsheet.sheet1
812
+
813
+ # Build new df
814
+ df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
815
+ df_new["email"] = df_new["email"].str.strip().str.lower()
816
+ df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
817
+
818
+ # Read existing data
819
+ existing_data = sheet.get_all_values()
820
+ if existing_data and len(existing_data[0]) >= 2:
821
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
822
+ df_old["email"] = df_old["email"].str.strip().str.lower()
823
+ df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
824
+ else:
825
+ df_old = pd.DataFrame(columns=["email", "usage_count"])
826
+
827
+ # ✅ Overwrite specific emails only
828
+ df_old = df_old.set_index("email")
829
+ for email, count in usage_dict.items():
830
+ email = email.strip().lower()
831
+ df_old.loc[email, "usage_count"] = count
832
+ df_old = df_old.reset_index()
833
+
834
+ # Save
835
+ sheet.clear()
836
+ sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
837
+ print("✅ Saved user usage to user_usage_log sheet.")
838
+
839
+ except Exception as e:
840
+ print(f"❌ Failed to save user usage to Google Sheets: {e}")
841
+
842
+
843
+
844
+
845
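A minimal round-trip sketch of how the save/load pair above is meant to be used, assuming load_user_usage returns the (usage, permitted) dictionaries seen earlier and that GCP_CREDS_JSON plus the "user_usage_log" sheet are available; the address is a placeholder, not a real user:

usage, permitted = load_user_usage()          # ({}, {}) on failure, per the except branch above
email = "[email protected]"                  # placeholder email
usage[email] = usage.get(email, 0) + 1        # bump one user's count
save_user_usage(usage)                        # clears and rewrites the sheet; only listed emails change value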
+ # def increment_usage(user_id, num_samples=1):
846
+ # usage = load_user_usage()
847
+ # if user_id not in usage:
848
+ # usage[user_id] = 0
849
+ # usage[user_id] += num_samples
850
+ # save_user_usage(usage)
851
+ # return usage[user_id]
852
+ # def increment_usage(email: str, count: int):
853
+ # usage = load_user_usage()
854
+ # email_key = email.strip().lower()
855
+ # usage[email_key] = usage.get(email_key, 0) + count
856
+ # save_user_usage(usage)
857
+ # return usage[email_key]
858
+ def increment_usage(email: str, count: int = 1):
859
+ usage, permitted = load_user_usage()
860
+ email_key = email.strip().lower()
861
+ #usage[email_key] = usage.get(email_key, 0) + count
862
+ current = usage.get(email_key, 0)
863
+ new_value = current + count
864
+ max_allowed = permitted.get(email_key) or 50
865
+ usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
866
+ print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
867
+ print("max allow is: ", max_allowed)
868
+ save_user_usage(usage)
869
+ return usage[email_key], max_allowed
870
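A short sketch of how a caller might consume the (new_total, max_allowed) pair returned above to enforce the per-user limit; the email is a placeholder and 50 mirrors the default cap used when no permitted entry exists:

total, max_allowed = increment_usage("[email protected]", count=3)   # placeholder address
if total >= max_allowed:
    print(f"Quota reached: {total}/{max_allowed} lookups used")
else:
    print(f"{max_allowed - total} lookups remaining")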
+
871
+
872
+ # run the batch
873
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
874
+ stop_flag=None, output_file_path=None,
875
+ limited_acc=50, yield_callback=None):
876
+ if user_email:
877
+ limited_acc += 10
878
+ accessions, error = extract_accessions_from_input(file, raw_text)
879
+ if error:
880
+ #return [], "", "", f"Error: {error}"
881
+ return [], f"Error: {error}", 0, "", ""
882
+ if resume_file:
883
+ accessions = get_incomplete_accessions(resume_file)
884
+ tmp_dir = tempfile.mkdtemp()
885
+ if not output_file_path:
886
+ if resume_file:
887
+ output_file_path = os.path.join(tmp_dir, resume_file)
888
+ else:
889
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
890
+
891
+ all_rows = []
892
+ # all_summaries = []
893
+ # all_flags = []
894
+ progress_lines = []
895
+ warning = ""
896
+ if len(accessions) > limited_acc:
897
+ accessions = accessions[:limited_acc]
898
+ warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
899
+ for i, acc in enumerate(accessions):
900
+ if stop_flag and stop_flag.value:
901
+ line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
902
+ progress_lines.append(line)
903
+ if yield_callback:
904
+ yield_callback(line)
905
+ print("🛑 User requested stop.")
906
+ break
907
+ print(f"[{i+1}/{len(accessions)}] Processing {acc}")
908
+ try:
909
+ # rows, summary, label, explain = summarize_results(acc)
910
+ rows = summarize_results(acc)
911
+ all_rows.extend(rows)
912
+ # all_summaries.append(f"**{acc}**\n{summary}")
913
+ # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
914
+ #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
915
+ save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
916
+ line = f" Processed {acc} ({i+1}/{len(accessions)})"
917
+ progress_lines.append(line)
918
+ if yield_callback:
919
+ yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
920
+ except Exception as e:
921
+ print(f"❌ Failed to process {acc}: {e}")
922
+ continue
923
+ #all_summaries.append(f"**{acc}**: Failed - {e}")
924
+ #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
925
+ limited_acc -= 1
926
+ """for row in all_rows:
927
+ source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
928
+
929
+ if source_column.startswith("http"): # Check if the source is a URL
930
+ # Wrap it with HTML anchor tags to make it clickable
931
+ row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
932
+ if not warning:
933
+ warning = f"You only have {limited_acc} left"
934
+ if user_email.strip():
935
+ user_hash = hash_user_id(user_email)
936
+ total_queries, max_allowed = increment_usage(user_hash, len(all_rows))  # increment_usage returns (total, cap)
937
+ else:
938
+ total_queries = 0
939
+ yield_callback("✅ Finished!")
940
+
941
+ # summary_text = "\n\n---\n\n".join(all_summaries)
942
+ # flag_text = "\n\n---\n\n".join(all_flags)
943
+ #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
944
+ #return all_rows, gr.update(visible=True), gr.update(visible=False)
945
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
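As a usage sketch only (not part of the commit), summarize_batch could be driven roughly like this, assuming extract_accessions_from_input accepts whitespace-separated accession text and that the stop flag is any object exposing a boolean .value attribute:

class StopFlag:
    def __init__(self):
        self.value = False          # set True from the UI thread to interrupt the loop

rows, out_path, used, progress, warning = summarize_batch(
    raw_text="ON792208",             # example accession from the classifier docstring
    user_email="[email protected]",  # placeholder; an empty string skips usage tracking
    stop_flag=StopFlag(),
    yield_callback=print,            # streams the per-accession progress lines
)
print(warning or f"Saved {len(rows)} rows to {out_path}")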
mtdna_classifier.py CHANGED
@@ -1,764 +1,764 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- #import streamlit as st
5
- import subprocess
6
- import re
7
- from Bio import Entrez
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
-
20
- # Set your email (required by NCBI Entrez)
21
- #Entrez.email = "[email protected]"
22
- import nltk
23
-
24
- nltk.download("stopwords")
25
- nltk.download("punkt")
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
- from Bio import Entrez, Medline
29
- import re
30
-
31
- Entrez.email = "[email protected]"
32
-
33
- # --- Helper Functions (Re-organized and Upgraded) ---
34
-
35
- def fetch_ncbi_metadata(accession_number):
36
- """
37
- Fetches metadata directly from NCBI GenBank using Entrez.
38
- Includes robust error handling and improved field extraction.
39
- Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
- Also attempts to extract ethnicity and sample_type (ancient/modern).
41
-
42
- Args:
43
- accession_number (str): The NCBI accession number (e.g., "ON792208").
44
-
45
- Returns:
46
- dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
- 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
- """
49
- Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
-
51
- country = "unknown"
52
- specific_location = "unknown"
53
- ethnicity = "unknown"
54
- sample_type = "unknown"
55
- collection_date = "unknown"
56
- isolate = "unknown"
57
- title = "unknown"
58
- doi = "unknown"
59
- pubmed_id = None
60
- all_feature = "unknown"
61
-
62
- KNOWN_COUNTRIES = [
63
- "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
- "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
- "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
- "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
- "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
- "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
- "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
- "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
- "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
- "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
- "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
- "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
- "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
- "Yemen", "Zambia", "Zimbabwe"
77
- ]
78
- COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
-
80
- try:
81
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
- record = Entrez.read(handle)
83
- handle.close()
84
-
85
- gb_seq = None
86
- # Validate record structure: It should be a list with at least one element (a dict)
87
- if isinstance(record, list) and len(record) > 0:
88
- if isinstance(record[0], dict):
89
- gb_seq = record[0]
90
- else:
91
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
- else:
93
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
-
95
- # If gb_seq is still None, return defaults
96
- if gb_seq is None:
97
- return {"country": "unknown",
98
- "specific_location": "unknown",
99
- "ethnicity": "unknown",
100
- "sample_type": "unknown",
101
- "collection_date": "unknown",
102
- "isolate": "unknown",
103
- "title": "unknown",
104
- "doi": "unknown",
105
- "pubmed_id": None,
106
- "all_features": "unknown"}
107
-
108
-
109
- # If gb_seq is valid, proceed with extraction
110
- collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
-
112
- references = gb_seq.get("GBSeq_references", [])
113
- for ref in references:
114
- if not pubmed_id:
115
- pubmed_id = ref.get("GBReference_pubmed",None)
116
- if title == "unknown":
117
- title = ref.get("GBReference_title","unknown")
118
- for xref in ref.get("GBReference_xref", []):
119
- if xref.get("GBXref_dbname") == "doi":
120
- doi = xref.get("GBXref_id")
121
- break
122
-
123
- features = gb_seq.get("GBSeq_feature-table", [])
124
-
125
- context_for_flagging = "" # Accumulate text for ancient/modern detection
126
- features_context = ""
127
- for feature in features:
128
- if feature.get("GBFeature_key") == "source":
129
- feature_context = ""
130
- qualifiers = feature.get("GBFeature_quals", [])
131
- found_country = "unknown"
132
- found_specific_location = "unknown"
133
- found_ethnicity = "unknown"
134
-
135
- temp_geo_loc_name = "unknown"
136
- temp_note_origin_locality = "unknown"
137
- temp_country_qual = "unknown"
138
- temp_locality_qual = "unknown"
139
- temp_collection_location_qual = "unknown"
140
- temp_isolation_source_qual = "unknown"
141
- temp_env_sample_qual = "unknown"
142
- temp_pop_qual = "unknown"
143
- temp_organism_qual = "unknown"
144
- temp_specimen_qual = "unknown"
145
- temp_strain_qual = "unknown"
146
-
147
- for qual in qualifiers:
148
- qual_name = qual.get("GBQualifier_name")
149
- qual_value = qual.get("GBQualifier_value")
150
- feature_context += qual_name + ": " + qual_value +"\n"
151
- if qual_name == "collection_date":
152
- collection_date = qual_value
153
- elif qual_name == "isolate":
154
- isolate = qual_value
155
- elif qual_name == "population":
156
- temp_pop_qual = qual_value
157
- elif qual_name == "organism":
158
- temp_organism_qual = qual_value
159
- elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
- temp_specimen_qual = qual_value
161
- elif qual_name == "strain":
162
- temp_strain_qual = qual_value
163
- elif qual_name == "isolation_source":
164
- temp_isolation_source_qual = qual_value
165
- elif qual_name == "environmental_sample":
166
- temp_env_sample_qual = qual_value
167
-
168
- if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
- elif qual_name == "note":
170
- if qual_value.startswith("origin_locality:"):
171
- temp_note_origin_locality = qual_value
172
- context_for_flagging += qual_value + " " # Capture all notes for flagging
173
- elif qual_name == "country": temp_country_qual = qual_value
174
- elif qual_name == "locality": temp_locality_qual = qual_value
175
- elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
-
177
-
178
- # --- Aggregate all relevant info into context_for_flagging ---
179
- context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
- context_for_flagging = context_for_flagging.strip()
181
-
182
- # --- Determine final country and specific_location based on priority ---
183
- if temp_geo_loc_name != "unknown":
184
- parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
- if len(parts) > 1:
186
- found_specific_location = parts[-1]; found_country = parts[0]
187
- else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
- elif temp_note_origin_locality != "unknown":
189
- match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
- if match:
191
- location_string = match.group(1).strip()
192
- parts = [p.strip() for p in location_string.split(':')]
193
- if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
194
- else: found_country = location_string; found_specific_location = "unknown"
195
- elif temp_locality_qual != "unknown":
196
- found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
197
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
198
- else: found_specific_location = temp_locality_qual; found_country = "unknown"
199
- elif temp_collection_location_qual != "unknown":
200
- found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
201
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
202
- else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
203
- elif temp_isolation_source_qual != "unknown":
204
- found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
205
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
206
- else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
207
- elif temp_env_sample_qual != "unknown":
208
- found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
209
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
210
- else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
211
- if found_country == "unknown" and temp_country_qual != "unknown":
212
- found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
213
- if found_country_match: found_country = found_country_match.group(1)
214
-
215
- country = found_country
216
- specific_location = found_specific_location
217
- # --- Determine final ethnicity ---
218
- if temp_pop_qual != "unknown":
219
- found_ethnicity = temp_pop_qual
220
- elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
221
- found_ethnicity = isolate
222
- elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
223
- eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
224
- if eth_match:
225
- found_ethnicity = eth_match.group(1).strip()
226
-
227
- ethnicity = found_ethnicity
228
-
229
- # --- Determine sample_type (ancient/modern) ---
230
- if context_for_flagging:
231
- sample_type, explain = detect_ancient_flag(context_for_flagging)
232
- features_context += feature_context + "\n"
233
- break
234
-
235
- if specific_location != "unknown" and specific_location.lower() == country.lower():
236
- specific_location = "unknown"
237
- if not features_context: features_context = "unknown"
238
- return {"country": country.lower(),
239
- "specific_location": specific_location.lower(),
240
- "ethnicity": ethnicity.lower(),
241
- "sample_type": sample_type.lower(),
242
- "collection_date": collection_date,
243
- "isolate": isolate,
244
- "title": title,
245
- "doi": doi,
246
- "pubmed_id": pubmed_id,
247
- "all_features": features_context}
248
-
249
- except:
250
- print(f"Error fetching NCBI data for {accession_number}")
251
- return {"country": "unknown",
252
- "specific_location": "unknown",
253
- "ethnicity": "unknown",
254
- "sample_type": "unknown",
255
- "collection_date": "unknown",
256
- "isolate": "unknown",
257
- "title": "unknown",
258
- "doi": "unknown",
259
- "pubmed_id": None,
260
- "all_features": "unknown"}
261
-
262
- # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
- _country_keywords = {
264
- "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
265
- "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
266
- "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
267
- "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
268
- "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
269
- "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
270
- "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
271
- "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
272
- "central india": "India", "east india": "India", "northeast india": "India",
273
- "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
274
- "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
275
- }
276
-
277
- def get_country_from_text(text):
278
- text_lower = text.lower()
279
- for keyword, country in _country_keywords.items():
280
- if keyword in text_lower:
281
- return country
282
- return "unknown"
283
- # The result will be seen as manualLink for the function get_paper_text
284
- # def search_google_custom(query, max_results=3):
285
- # # query should be the title from ncbi or paper/source title
286
- # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
- # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
- # endpoint = os.environ["SEARCH_ENDPOINT"]
289
- # params = {
290
- # "key": GOOGLE_CSE_API_KEY,
291
- # "cx": GOOGLE_CSE_CX,
292
- # "q": query,
293
- # "num": max_results
294
- # }
295
- # try:
296
- # response = requests.get(endpoint, params=params)
297
- # if response.status_code == 429:
298
- # print("Rate limit hit. Try again later.")
299
- # return []
300
- # response.raise_for_status()
301
- # data = response.json().get("items", [])
302
- # return [item.get("link") for item in data if item.get("link")]
303
- # except Exception as e:
304
- # print("Google CSE error:", e)
305
- # return []
306
-
307
- def search_google_custom(query, max_results=3):
308
- # query should be the title from ncbi or paper/source title
309
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
310
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
311
- endpoint = os.environ["SEARCH_ENDPOINT"]
312
- params = {
313
- "key": GOOGLE_CSE_API_KEY,
314
- "cx": GOOGLE_CSE_CX,
315
- "q": query,
316
- "num": max_results
317
- }
318
- try:
319
- response = requests.get(endpoint, params=params)
320
- if response.status_code == 429:
321
- print("Rate limit hit. Try again later.")
322
- print("try with back up account")
323
- try:
324
- return search_google_custom_backup(query, max_results)
325
- except:
326
- return []
327
- response.raise_for_status()
328
- data = response.json().get("items", [])
329
- return [item.get("link") for item in data if item.get("link")]
330
- except Exception as e:
331
- print("Google CSE error:", e)
332
- return []
333
-
334
- def search_google_custom_backup(query, max_results=3):
335
- # query should be the title from ncbi or paper/source title
336
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY_BACKUP"]
337
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX_BACKUP"]
338
- endpoint = os.environ["SEARCH_ENDPOINT"]
339
- params = {
340
- "key": GOOGLE_CSE_API_KEY,
341
- "cx": GOOGLE_CSE_CX,
342
- "q": query,
343
- "num": max_results
344
- }
345
- try:
346
- response = requests.get(endpoint, params=params)
347
- if response.status_code == 429:
348
- print("Rate limit hit. Try again later.")
349
- return []
350
- response.raise_for_status()
351
- data = response.json().get("items", [])
352
- return [item.get("link") for item in data if item.get("link")]
353
- except Exception as e:
354
- print("Google CSE error:", e)
355
- return []
356
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
357
- # Step 3.1: Extract Text
358
- # sub: download excel file
359
- def download_excel_file(url, save_path="temp.xlsx"):
360
- if "view.officeapps.live.com" in url:
361
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
362
- real_url = urllib.parse.unquote(parsed_url["src"][0])
363
- response = requests.get(real_url)
364
- with open(save_path, "wb") as f:
365
- f.write(response.content)
366
- return save_path
367
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
368
- response = requests.get(url)
369
- response.raise_for_status() # Raises error if download fails
370
- with open(save_path, "wb") as f:
371
- f.write(response.content)
372
- return save_path
373
- else:
374
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
375
- return url
376
- def get_paper_text(doi,id,manualLinks=None):
377
- # create the temporary folder to contain the texts
378
- folder_path = Path("data/"+str(id))
379
- if not folder_path.exists():
380
- cmd = f'mkdir data/{id}'
381
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
382
- print("data/"+str(id) +" created.")
383
- else:
384
- print("data/"+str(id) +" already exists.")
385
- saveLinkFolder = "data/"+id
386
-
387
- link = 'https://doi.org/' + doi
388
- '''textsToExtract = { "doiLink":"paperText"
389
- "file1.pdf":"text1",
390
- "file2.doc":"text2",
391
- "file3.xlsx":excelText3'''
392
- textsToExtract = {}
393
- # get the file to create listOfFile for each id
394
- html = extractHTML.HTML("",link)
395
- jsonSM = html.getSupMaterial()
396
- text = ""
397
- links = [link] + sum((jsonSM[key] for key in jsonSM),[])
398
- if manualLinks != None:
399
- links += manualLinks
400
- for l in links:
401
- # get the main paper
402
- name = l.split("/")[-1]
403
- file_path = folder_path / name
404
- if l == link:
405
- text = html.getListSection()
406
- textsToExtract[link] = text
407
- elif l.endswith(".pdf"):
408
- if file_path.is_file():
409
- l = saveLinkFolder + "/" + name
410
- print("File exists.")
411
- p = pdf.PDF(l,saveLinkFolder,doi)
412
- f = p.openPDFFile()
413
- pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
414
- doc = fitz.open(pdf_path)
415
- text = "\n".join([page.get_text() for page in doc])
416
- textsToExtract[l] = text
417
- elif l.endswith(".doc") or l.endswith(".docx"):
418
- d = wordDoc.wordDoc(l,saveLinkFolder)
419
- text = d.extractTextByPage()
420
- textsToExtract[l] = text
421
- elif l.split(".")[-1].lower() in "xlsx":
422
- wc = word2vec.word2Vec()
423
- # download excel file if it not downloaded yet
424
- savePath = saveLinkFolder +"/"+ l.split("/")[-1]
425
- excelPath = download_excel_file(l, savePath)
426
- corpus = wc.tableTransformToCorpusText([],excelPath)
427
- text = ''
428
- for c in corpus:
429
- para = corpus[c]
430
- for words in para:
431
- text += " ".join(words)
432
- textsToExtract[l] = text
433
- # delete folder after finishing getting text
434
- #cmd = f'rm -r data/{id}'
435
- #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
436
- return textsToExtract
437
- # Step 3.2: Extract context
438
- def extract_context(text, keyword, window=500):
439
- # firstly try accession number
440
- idx = text.find(keyword)
441
- if idx == -1:
442
- return "Sample ID not found."
443
- return text[max(0, idx-window): idx+window]
444
- def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
445
- if keep_if is None:
446
- keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
447
-
448
- outputs = ""
449
- text = text.lower()
450
-
451
- # If isolate is provided, prioritize paragraphs that mention it
452
- # If isolate is provided, prioritize paragraphs that mention it
453
- if accession and accession.lower() in text:
454
- if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
455
- outputs += extract_context(text, accession.lower(), window=700)
456
- if isolate and isolate.lower() in text:
457
- if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
458
- outputs += extract_context(text, isolate.lower(), window=700)
459
- for keyword in keep_if:
460
- para = extract_context(text, keyword)
461
- if para and para not in outputs:
462
- outputs += para + "\n"
463
- return outputs
464
- # Step 4: Classification for now (demo purposes)
465
- # 4.1: Using a HuggingFace model (question-answering)
466
- def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
467
- try:
468
- qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
469
- result = qa({"context": context, "question": question})
470
- return result.get("answer", "Unknown")
471
- except Exception as e:
472
- return f"Error: {str(e)}"
473
-
474
- # 4.2: Infer from haplogroup
475
- # Load pre-trained spaCy model for NER
476
- try:
477
- nlp = spacy.load("en_core_web_sm")
478
- except OSError:
479
- download("en_core_web_sm")
480
- nlp = spacy.load("en_core_web_sm")
481
-
482
- # Define the haplogroup-to-region mapping (simple rule-based)
483
- import csv
484
-
485
- def load_haplogroup_mapping(csv_path):
486
- mapping = {}
487
- with open(csv_path) as f:
488
- reader = csv.DictReader(f)
489
- for row in reader:
490
- mapping[row["haplogroup"]] = [row["region"],row["source"]]
491
- return mapping
492
-
493
- # Function to extract haplogroup from the text
494
- def extract_haplogroup(text):
495
- match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
496
- if match:
497
- submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
498
- if submatch:
499
- return submatch.group(0)
500
- else:
501
- return match.group(1) # fallback
502
- fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
503
- if fallback:
504
- return fallback.group(1)
505
- return None
506
-
507
-
508
- # Function to extract location based on NER
509
- def extract_location(text):
510
- doc = nlp(text)
511
- locations = []
512
- for ent in doc.ents:
513
- if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
514
- locations.append(ent.text)
515
- return locations
516
-
517
- # Function to infer location from haplogroup
518
- def infer_location_from_haplogroup(haplogroup):
519
- haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
520
- return haplo_map.get(haplogroup, ["Unknown","Unknown"])
521
-
522
- # Function to classify the mtDNA sample
523
- def classify_mtDNA_sample_from_haplo(text):
524
- # Extract haplogroup
525
- haplogroup = extract_haplogroup(text)
526
- # Extract location based on NER
527
- locations = extract_location(text)
528
- # Infer location based on haplogroup
529
- inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
530
- return {
531
- "source":sourceHaplo,
532
- "locations_found_in_context": locations,
533
- "haplogroup": haplogroup,
534
- "inferred_location": inferred_location
535
-
536
- }
537
- # 4.3 Get from available NCBI
538
- def infer_location_fromNCBI(accession):
539
- try:
540
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
541
- text = handle.read()
542
- handle.close()
543
- match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
544
- if match:
545
- return match.group(2), match.group(0) # This is the value like "Brunei"
546
- return "Not found", "Not found"
547
-
548
- except Exception as e:
549
- print("❌ Entrez error:", e)
550
- return "Not found", "Not found"
551
-
552
- ### ANCIENT/MODERN FLAG
553
- from Bio import Entrez
554
- import re
555
-
556
- def flag_ancient_modern(accession, textsToExtract, isolate=None):
557
- """
558
- Try to classify a sample as Ancient or Modern using:
559
- 1. NCBI accession (if available)
560
- 2. Supplementary text or context fallback
561
- """
562
- context = ""
563
- label, explain = "", ""
564
-
565
- try:
566
- # Check if we can fetch metadata from NCBI using the accession
567
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
568
- text = handle.read()
569
- handle.close()
570
-
571
- isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
572
- if isolate_source:
573
- context += isolate_source.group(0) + " "
574
-
575
- specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
576
- if specimen:
577
- context += specimen.group(0) + " "
578
-
579
- if context.strip():
580
- label, explain = detect_ancient_flag(context)
581
- if label!="Unknown":
582
- return label, explain + " from NCBI\n(" + context + ")"
583
-
584
- # If no useful NCBI metadata, check supplementary texts
585
- if textsToExtract:
586
- labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
587
-
588
- for source in textsToExtract:
589
- text_block = textsToExtract[source]
590
- context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
591
- label, explain = detect_ancient_flag(context)
592
-
593
- if label == "Ancient":
594
- labels["ancient"][0] += 1
595
- labels["ancient"][1] += f"{source}:\n{explain}\n\n"
596
- elif label == "Modern":
597
- labels["modern"][0] += 1
598
- labels["modern"][1] += f"{source}:\n{explain}\n\n"
599
- else:
600
- labels["unknown"] += 1
601
-
602
- if max(labels["modern"][0],labels["ancient"][0]) > 0:
603
- if labels["modern"][0] > labels["ancient"][0]:
604
- return "Modern", labels["modern"][1]
605
- else:
606
- return "Ancient", labels["ancient"][1]
607
- else:
608
- return "Unknown", "No strong keywords detected"
609
- else:
610
- print("No DOI or PubMed ID available for inference.")
611
- return "", ""
612
-
613
- except Exception as e:
614
- print("Error:", e)
615
- return "", ""
616
-
617
-
618
- def detect_ancient_flag(context_snippet):
619
- context = context_snippet.lower()
620
-
621
- ancient_keywords = [
622
- "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
623
- "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
624
- "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
625
- "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
626
- ]
627
-
628
- modern_keywords = [
629
- "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
630
- "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
631
- "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
632
- "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
633
- "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
634
- ]
635
-
636
- ancient_hits = [k for k in ancient_keywords if k in context]
637
- modern_hits = [k for k in modern_keywords if k in context]
638
-
639
- if ancient_hits and not modern_hits:
640
- return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
641
- elif modern_hits and not ancient_hits:
642
- return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
643
- elif ancient_hits and modern_hits:
644
- if len(ancient_hits) >= len(modern_hits):
645
- return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
646
- else:
647
- return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
648
-
649
- # Fallback to QA
650
- answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
651
- if answer.startswith("Error"):
652
- return "Unknown", answer
653
- if "ancient" in answer.lower():
654
- return "Ancient", f"Leaning ancient based on QA: {answer}"
655
- elif "modern" in answer.lower():
656
- return "Modern", f"Leaning modern based on QA: {answer}"
657
- else:
658
- return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
659
-
660
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
661
- def classify_sample_location(accession):
662
- outputs = {}
663
- keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
664
- # Step 1: get pubmed id and isolate
665
- pubmedID, isolate = get_info_from_accession(accession)
666
- '''if not pubmedID:
667
- return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
668
- if not isolate:
669
- isolate = "UNKNOWN_ISOLATE"
670
- # Step 2: get doi
671
- doi = get_doi_from_pubmed_id(pubmedID)
672
- '''if not doi:
673
- return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
674
- # Step 3: get text
675
- '''textsToExtract = { "doiLink":"paperText"
676
- "file1.pdf":"text1",
677
- "file2.doc":"text2",
678
- "file3.xlsx":excelText3'''
679
- if doi and pubmedID:
680
- textsToExtract = get_paper_text(doi,pubmedID)
681
- else: textsToExtract = {}
682
- '''if not textsToExtract:
683
- return {"error": f"No texts extracted for DOI {doi}"}'''
684
- if isolate not in [None, "UNKNOWN_ISOLATE"]:
685
- label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
686
- else:
687
- label, explain = flag_ancient_modern(accession,textsToExtract)
688
- # Step 4: prediction
689
- outputs[accession] = {}
690
- outputs[isolate] = {}
691
- # 4.0 Infer from NCBI
692
- location, outputNCBI = infer_location_fromNCBI(accession)
693
- NCBI_result = {
694
- "source": "NCBI",
695
- "sample_id": accession,
696
- "predicted_location": location,
697
- "context_snippet": outputNCBI}
698
- outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
699
- if textsToExtract:
700
- long_text = ""
701
- for key in textsToExtract:
702
- text = textsToExtract[key]
703
- # try accession number first
704
- outputs[accession][key] = {}
705
- keyword = accession
706
- context = extract_context(text, keyword, window=500)
707
- # 4.1: Using a HuggingFace model (question-answering)
708
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
709
- qa_result = {
710
- "source": key,
711
- "sample_id": keyword,
712
- "predicted_location": location,
713
- "context_snippet": context
714
- }
715
- outputs[keyword][key]["QAModel"] = qa_result
716
- # 4.2: Infer from haplogroup
717
- haplo_result = classify_mtDNA_sample_from_haplo(context)
718
- outputs[keyword][key]["haplogroup"] = haplo_result
719
- # try isolate
720
- keyword = isolate
721
- outputs[isolate][key] = {}
722
- context = extract_context(text, keyword, window=500)
723
- # 4.1.1: Using a HuggingFace model (question-answering)
724
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
725
- qa_result = {
726
- "source": key,
727
- "sample_id": keyword,
728
- "predicted_location": location,
729
- "context_snippet": context
730
- }
731
- outputs[keyword][key]["QAModel"] = qa_result
732
- # 4.2.1: Infer from haplogroup
733
- haplo_result = classify_mtDNA_sample_from_haplo(context)
734
- outputs[keyword][key]["haplogroup"] = haplo_result
735
- # add long text
736
- long_text += text + ". \n"
737
- # 4.3: UpgradeClassify
738
- # try sample_id as accession number
739
- sample_id = accession
740
- if sample_id:
741
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
742
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
743
- if locations!="No clear location found in top matches":
744
- outputs[sample_id]["upgradeClassifier"] = {}
745
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
746
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
747
- "sample_id": sample_id,
748
- "predicted_location": ", ".join(locations),
749
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
750
- }
751
- # try sample_id as isolate name
752
- sample_id = isolate
753
- if sample_id:
754
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
755
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
756
- if locations!="No clear location found in top matches":
757
- outputs[sample_id]["upgradeClassifier"] = {}
758
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
759
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
760
- "sample_id": sample_id,
761
- "predicted_location": ", ".join(locations),
762
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
763
- }
764
  return outputs, label, explain
 
1
+ # mtDNA Location Classifier MVP (Google Colab)
2
+ # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
+ import os
4
+ #import streamlit as st
5
+ import subprocess
6
+ import re
7
+ from Bio import Entrez
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
+
20
+ # Set your email (required by NCBI Entrez)
21
+ #Entrez.email = "[email protected]"
22
+ import nltk
23
+
24
+ nltk.download("stopwords")
25
+ nltk.download("punkt")
26
+ nltk.download('punkt_tab')
27
+ # Step 1: Get PubMed ID from Accession using EDirect
28
+ from Bio import Entrez, Medline
29
+ import re
30
+
31
+ Entrez.email = "[email protected]"
32
+
33
+ # --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
+ def fetch_ncbi_metadata(accession_number):
36
+ """
37
+ Fetches metadata directly from NCBI GenBank using Entrez.
38
+ Includes robust error handling and improved field extraction.
39
+ Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
+ Also attempts to extract ethnicity and sample_type (ancient/modern).
41
+
42
+ Args:
43
+ accession_number (str): The NCBI accession number (e.g., "ON792208").
44
+
45
+ Returns:
46
+ dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
+ 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
+ """
49
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
+
51
+ country = "unknown"
52
+ specific_location = "unknown"
53
+ ethnicity = "unknown"
54
+ sample_type = "unknown"
55
+ collection_date = "unknown"
56
+ isolate = "unknown"
57
+ title = "unknown"
58
+ doi = "unknown"
59
+ pubmed_id = None
60
+ all_feature = "unknown"
61
+
62
+ KNOWN_COUNTRIES = [
63
+ "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
+ "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
+ "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
+ "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
+ "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
+ "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
+ "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
+ "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
+ "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
+ "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
+ "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
+ "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
+ "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
+ "Yemen", "Zambia", "Zimbabwe"
77
+ ]
78
+ COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
+
80
+ try:
81
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
+ record = Entrez.read(handle)
83
+ handle.close()
84
+
85
+ gb_seq = None
86
+ # Validate record structure: It should be a list with at least one element (a dict)
87
+ if isinstance(record, list) and len(record) > 0:
88
+ if isinstance(record[0], dict):
89
+ gb_seq = record[0]
90
+ else:
91
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
+ else:
93
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
+
95
+ # If gb_seq is still None, return defaults
96
+ if gb_seq is None:
97
+ return {"country": "unknown",
98
+ "specific_location": "unknown",
99
+ "ethnicity": "unknown",
100
+ "sample_type": "unknown",
101
+ "collection_date": "unknown",
102
+ "isolate": "unknown",
103
+ "title": "unknown",
104
+ "doi": "unknown",
105
+ "pubmed_id": None,
106
+ "all_features": "unknown"}
107
+
108
+
109
+ # If gb_seq is valid, proceed with extraction
110
+ collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
+
112
+ references = gb_seq.get("GBSeq_references", [])
113
+ for ref in references:
114
+ if not pubmed_id:
115
+ pubmed_id = ref.get("GBReference_pubmed",None)
116
+ if title == "unknown":
117
+ title = ref.get("GBReference_title","unknown")
118
+ for xref in ref.get("GBReference_xref", []):
119
+ if xref.get("GBXref_dbname") == "doi":
120
+ doi = xref.get("GBXref_id")
121
+ break
122
+
123
+ features = gb_seq.get("GBSeq_feature-table", [])
124
+
125
+ context_for_flagging = "" # Accumulate text for ancient/modern detection
126
+ features_context = ""
127
+ for feature in features:
128
+ if feature.get("GBFeature_key") == "source":
129
+ feature_context = ""
130
+ qualifiers = feature.get("GBFeature_quals", [])
131
+ found_country = "unknown"
132
+ found_specific_location = "unknown"
133
+ found_ethnicity = "unknown"
134
+
135
+ temp_geo_loc_name = "unknown"
136
+ temp_note_origin_locality = "unknown"
137
+ temp_country_qual = "unknown"
138
+ temp_locality_qual = "unknown"
139
+ temp_collection_location_qual = "unknown"
140
+ temp_isolation_source_qual = "unknown"
141
+ temp_env_sample_qual = "unknown"
142
+ temp_pop_qual = "unknown"
143
+ temp_organism_qual = "unknown"
144
+ temp_specimen_qual = "unknown"
145
+ temp_strain_qual = "unknown"
146
+
147
+ for qual in qualifiers:
148
+ qual_name = qual.get("GBQualifier_name")
149
+ qual_value = qual.get("GBQualifier_value")
150
+ feature_context += qual_name + ": " + qual_value +"\n"
151
+ if qual_name == "collection_date":
152
+ collection_date = qual_value
153
+ elif qual_name == "isolate":
154
+ isolate = qual_value
155
+ elif qual_name == "population":
156
+ temp_pop_qual = qual_value
157
+ elif qual_name == "organism":
158
+ temp_organism_qual = qual_value
159
+ elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
+ temp_specimen_qual = qual_value
161
+ elif qual_name == "strain":
162
+ temp_strain_qual = qual_value
163
+ elif qual_name == "isolation_source":
164
+ temp_isolation_source_qual = qual_value
165
+ elif qual_name == "environmental_sample":
166
+ temp_env_sample_qual = qual_value
167
+
168
+ if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
+ elif qual_name == "note":
170
+ if qual_value.startswith("origin_locality:"):
171
+ temp_note_origin_locality = qual_value
172
+ context_for_flagging += qual_value + " " # Capture all notes for flagging
173
+ elif qual_name == "country": temp_country_qual = qual_value
174
+ elif qual_name == "locality": temp_locality_qual = qual_value
175
+ elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
+
177
+
178
+ # --- Aggregate all relevant info into context_for_flagging ---
179
+ context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
+ context_for_flagging = context_for_flagging.strip()
181
+
182
+ # --- Determine final country and specific_location based on priority ---
183
+ if temp_geo_loc_name != "unknown":
184
+ parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
+ if len(parts) > 1:
186
+ found_specific_location = parts[-1]; found_country = parts[0]
187
+ else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
+ elif temp_note_origin_locality != "unknown":
189
+ match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
+ if match:
191
+ location_string = match.group(1).strip()
192
+ parts = [p.strip() for p in location_string.split(':')]
193
+ if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
194
+ else: found_country = location_string; found_specific_location = "unknown"
195
+ elif temp_locality_qual != "unknown":
196
+ found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
197
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
198
+ else: found_specific_location = temp_locality_qual; found_country = "unknown"
199
+ elif temp_collection_location_qual != "unknown":
200
+ found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
201
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
202
+ else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
203
+ elif temp_isolation_source_qual != "unknown":
204
+ found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
205
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
206
+ else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
207
+ elif temp_env_sample_qual != "unknown":
208
+ found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
209
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
210
+ else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
211
+ if found_country == "unknown" and temp_country_qual != "unknown":
212
+ found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
213
+ if found_country_match: found_country = found_country_match.group(1)
214
+
215
+ country = found_country
216
+ specific_location = found_specific_location
217
+ # --- Determine final ethnicity ---
218
+ if temp_pop_qual != "unknown":
219
+ found_ethnicity = temp_pop_qual
220
+ elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
221
+ found_ethnicity = isolate
222
+ elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
223
+ eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
224
+ if eth_match:
225
+ found_ethnicity = eth_match.group(1).strip()
226
+
227
+ ethnicity = found_ethnicity
228
+
229
+ # --- Determine sample_type (ancient/modern) ---
230
+ if context_for_flagging:
231
+ sample_type, explain = detect_ancient_flag(context_for_flagging)
232
+ features_context += feature_context + "\n"
233
+ break
234
+
235
+ if specific_location != "unknown" and specific_location.lower() == country.lower():
236
+ specific_location = "unknown"
237
+ if not features_context: features_context = "unknown"
238
+ return {"country": country.lower(),
239
+ "specific_location": specific_location.lower(),
240
+ "ethnicity": ethnicity.lower(),
241
+ "sample_type": sample_type.lower(),
242
+ "collection_date": collection_date,
243
+ "isolate": isolate,
244
+ "title": title,
245
+ "doi": doi,
246
+ "pubmed_id": pubmed_id,
247
+ "all_features": features_context}
248
+
249
+ except:
250
+ print(f"Error fetching NCBI data for {accession_number}")
251
+ return {"country": "unknown",
252
+ "specific_location": "unknown",
253
+ "ethnicity": "unknown",
254
+ "sample_type": "unknown",
255
+ "collection_date": "unknown",
256
+ "isolate": "unknown",
257
+ "title": "unknown",
258
+ "doi": "unknown",
259
+ "pubmed_id": None,
260
+ "all_features": "unknown"}
261
+
262
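A brief sketch of how the metadata dictionary returned by fetch_ncbi_metadata might be consumed downstream; ON792208 is the example accession from the docstring:

meta = fetch_ncbi_metadata("ON792208")
if meta["country"] != "unknown":
    print(f"{meta['isolate']}: {meta['country']} / {meta['specific_location']}, "
          f"{meta['sample_type']} sample, collected {meta['collection_date']}")
else:
    print("No usable location qualifier; fall back to the paper-text pipeline")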
+ # --- Helper function for country matching (re-defined from main code to be self-contained) ---
263
+ _country_keywords = {
264
+ "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
265
+ "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
266
+ "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
267
+ "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
268
+ "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
269
+ "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
270
+ "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
271
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
272
+ "central india": "India", "east india": "India", "northeast india": "India",
273
+ "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
274
+ "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
275
+ }
276
+
277
+ def get_country_from_text(text):
278
+ text_lower = text.lower()
279
+ for keyword, country in _country_keywords.items():
280
+ if keyword in text_lower:
281
+ return country
282
+ return "unknown"
283
+ # The result will be seen as manualLink for the function get_paper_text
284
+ # def search_google_custom(query, max_results=3):
285
+ # # query should be the title from ncbi or paper/source title
286
+ # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
287
+ # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
288
+ # endpoint = os.environ["SEARCH_ENDPOINT"]
289
+ # params = {
290
+ # "key": GOOGLE_CSE_API_KEY,
291
+ # "cx": GOOGLE_CSE_CX,
292
+ # "q": query,
293
+ # "num": max_results
294
+ # }
295
+ # try:
296
+ # response = requests.get(endpoint, params=params)
297
+ # if response.status_code == 429:
298
+ # print("Rate limit hit. Try again later.")
299
+ # return []
300
+ # response.raise_for_status()
301
+ # data = response.json().get("items", [])
302
+ # return [item.get("link") for item in data if item.get("link")]
303
+ # except Exception as e:
304
+ # print("Google CSE error:", e)
305
+ # return []
306
+
307
+ def search_google_custom(query, max_results=3):
308
+ # query should be the title from ncbi or paper/source title
309
+ GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
310
+ GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
311
+ endpoint = os.environ["SEARCH_ENDPOINT"]
312
+ params = {
313
+ "key": GOOGLE_CSE_API_KEY,
314
+ "cx": GOOGLE_CSE_CX,
315
+ "q": query,
316
+ "num": max_results
317
+ }
318
+ try:
319
+ response = requests.get(endpoint, params=params)
320
+ if response.status_code == 429:
321
+ print("Rate limit hit. Try again later.")
322
+ print("try with back up account")
323
+ try:
324
+ return search_google_custom_backup(query, max_results)
325
+ except:
326
+ return []
327
+ response.raise_for_status()
328
+ data = response.json().get("items", [])
329
+ return [item.get("link") for item in data if item.get("link")]
330
+ except Exception as e:
331
+ print("Google CSE error:", e)
332
+ return []
333
+
334
+ def search_google_custom_backup(query, max_results=3):
335
+ # query should be the title from ncbi or paper/source title
336
+ GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY_BACKUP"]
337
+ GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX_BACKUP"]
338
+ endpoint = os.environ["SEARCH_ENDPOINT"]
339
+ params = {
340
+ "key": GOOGLE_CSE_API_KEY,
341
+ "cx": GOOGLE_CSE_CX,
342
+ "q": query,
343
+ "num": max_results
344
+ }
345
+ try:
346
+ response = requests.get(endpoint, params=params)
347
+ if response.status_code == 429:
348
+ print("Rate limit hit. Try again later.")
349
+ return []
350
+ response.raise_for_status()
351
+ data = response.json().get("items", [])
352
+ return [item.get("link") for item in data if item.get("link")]
353
+ except Exception as e:
354
+ print("Google CSE error:", e)
355
+ return []
356
+ # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
357
+ # Step 3.1: Extract Text
358
+ # sub: download excel file
359
+ def download_excel_file(url, save_path="temp.xlsx"):
360
+ if "view.officeapps.live.com" in url:
361
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
362
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
363
+ response = requests.get(real_url)
364
+ with open(save_path, "wb") as f:
365
+ f.write(response.content)
366
+ return save_path
367
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
368
+ response = requests.get(url)
369
+ response.raise_for_status() # Raises error if download fails
370
+ with open(save_path, "wb") as f:
371
+ f.write(response.content)
372
+ return save_path
373
+ else:
374
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
375
+ return url
376
+ def get_paper_text(doi,id,manualLinks=None):
377
+ # create the temporary folder to contain the texts
378
+ folder_path = Path("data/"+str(id))
379
+ if not folder_path.exists():
380
+ cmd = f'mkdir data/{id}'
381
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
382
+ print("data/"+str(id) +" created.")
383
+ else:
384
+ print("data/"+str(id) +" already exists.")
385
+ saveLinkFolder = "data/"+id
386
+
387
+ link = 'https://doi.org/' + doi
388
+ '''textsToExtract = { "doiLink":"paperText"
389
+ "file1.pdf":"text1",
390
+ "file2.doc":"text2",
391
+ "file3.xlsx":excelText3'''
392
+ textsToExtract = {}
393
+ # get the file to create listOfFile for each id
394
+ html = extractHTML.HTML("",link)
395
+ jsonSM = html.getSupMaterial()
396
+ text = ""
397
+ links = [link] + sum((jsonSM[key] for key in jsonSM),[])
398
+ if manualLinks != None:
399
+ links += manualLinks
400
+ for l in links:
401
+ # get the main paper
402
+ name = l.split("/")[-1]
403
+ file_path = folder_path / name
404
+ if l == link:
405
+ text = html.getListSection()
406
+ textsToExtract[link] = text
407
+ elif l.endswith(".pdf"):
408
+ if file_path.is_file():
409
+ l = saveLinkFolder + "/" + name
410
+ print("File exists.")
411
+ p = pdf.PDF(l,saveLinkFolder,doi)
412
+ f = p.openPDFFile()
413
+ pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
414
+ doc = fitz.open(pdf_path)
415
+ text = "\n".join([page.get_text() for page in doc])
416
+ textsToExtract[l] = text
417
+ elif l.endswith(".doc") or l.endswith(".docx"):
418
+ d = wordDoc.wordDoc(l,saveLinkFolder)
419
+ text = d.extractTextByPage()
420
+ textsToExtract[l] = text
421
+ elif l.split(".")[-1].lower() in "xlsx":
422
+ wc = word2vec.word2Vec()
423
+ # download excel file if it not downloaded yet
424
+ savePath = saveLinkFolder +"/"+ l.split("/")[-1]
425
+ excelPath = download_excel_file(l, savePath)
426
+ corpus = wc.tableTransformToCorpusText([],excelPath)
427
+ text = ''
428
+ for c in corpus:
429
+ para = corpus[c]
430
+ for words in para:
431
+ text += " ".join(words)
432
+ textsToExtract[l] = text
433
+ # delete folder after finishing getting text
434
+ #cmd = f'rm -r data/{id}'
435
+ #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
436
+ return textsToExtract
437
+ # Step 3.2: Extract context
438
+ def extract_context(text, keyword, window=500):
439
+ # firstly try accession number
440
+ idx = text.find(keyword)
441
+ if idx == -1:
442
+ return "Sample ID not found."
443
+ return text[max(0, idx-window): idx+window]
444
+ def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
445
+ if keep_if is None:
446
+ keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
447
+
448
+ outputs = ""
449
+ text = text.lower()
450
+
451
+ # If accession or isolate is provided, prioritize paragraphs that mention them
453
+ if accession and accession.lower() in text:
454
+ if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
455
+ outputs += extract_context(text, accession.lower(), window=700)
456
+ if isolate and isolate.lower() in text:
457
+ if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
458
+ outputs += extract_context(text, isolate.lower(), window=700)
459
+ for keyword in keep_if:
460
+ para = extract_context(text, keyword)
461
+ if para and para not in outputs:
462
+ outputs += para + "\n"
463
+ return outputs
464
+ # Step 4: Classification for now (demo purposes)
465
+ # 4.1: Using a HuggingFace model (question-answering)
466
+ def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
467
+ try:
468
+ qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
469
+ result = qa({"context": context, "question": question})
470
+ return result.get("answer", "Unknown")
471
+ except Exception as e:
472
+ return f"Error: {str(e)}"
473
+
474
+ # 4.2: Infer from haplogroup
475
+ # Load pre-trained spaCy model for NER
476
+ try:
477
+ nlp = spacy.load("en_core_web_sm")
478
+ except OSError:
479
+ download("en_core_web_sm")
480
+ nlp = spacy.load("en_core_web_sm")
481
+
482
+ # Define the haplogroup-to-region mapping (simple rule-based)
483
+ import csv
484
+
485
+ def load_haplogroup_mapping(csv_path):
486
+ mapping = {}
487
+ with open(csv_path) as f:
488
+ reader = csv.DictReader(f)
489
+ for row in reader:
490
+ mapping[row["haplogroup"]] = [row["region"],row["source"]]
491
+ return mapping
492
+
493
+ # Function to extract haplogroup from the text
494
+ def extract_haplogroup(text):
495
+ match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
496
+ if match:
497
+ submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
498
+ if submatch:
499
+ return submatch.group(0)
500
+ else:
501
+ return match.group(1) # fallback
502
+ fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
503
+ if fallback:
504
+ return fallback.group(1)
505
+ return None
506
+
507
+
508
+ # Function to extract location based on NER
509
+ def extract_location(text):
510
+ doc = nlp(text)
511
+ locations = []
512
+ for ent in doc.ents:
513
+ if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
514
+ locations.append(ent.text)
515
+ return locations
516
+
517
+ # Function to infer location from haplogroup
518
+ def infer_location_from_haplogroup(haplogroup):
519
+ haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
520
+ return haplo_map.get(haplogroup, ["Unknown","Unknown"])
521
+
522
+ # Function to classify the mtDNA sample
523
+ def classify_mtDNA_sample_from_haplo(text):
524
+ # Extract haplogroup
525
+ haplogroup = extract_haplogroup(text)
526
+ # Extract location based on NER
527
+ locations = extract_location(text)
528
+ # Infer location based on haplogroup
529
+ inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
530
+ return {
531
+ "source":sourceHaplo,
532
+ "locations_found_in_context": locations,
533
+ "haplogroup": haplogroup,
534
+ "inferred_location": inferred_location
535
+
536
+ }
537
+ # 4.3 Get from available NCBI
538
+ def infer_location_fromNCBI(accession):
539
+ try:
540
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
541
+ text = handle.read()
542
+ handle.close()
543
+ match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
544
+ if match:
545
+ return match.group(2), match.group(0) # This is the value like "Brunei"
546
+ return "Not found", "Not found"
547
+
548
+ except Exception as e:
549
+ print("❌ Entrez error:", e)
550
+ return "Not found", "Not found"
551
+
552
+ ### ANCIENT/MODERN FLAG
553
+ from Bio import Entrez
554
+ import re
555
+
556
+ def flag_ancient_modern(accession, textsToExtract, isolate=None):
557
+ """
558
+ Try to classify a sample as Ancient or Modern using:
559
+ 1. NCBI accession (if available)
560
+ 2. Supplementary text or context fallback
561
+ """
562
+ context = ""
563
+ label, explain = "", ""
564
+
565
+ try:
566
+ # Check if we can fetch metadata from NCBI using the accession
567
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
568
+ text = handle.read()
569
+ handle.close()
570
+
571
+ isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
572
+ if isolate_source:
573
+ context += isolate_source.group(0) + " "
574
+
575
+ specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
576
+ if specimen:
577
+ context += specimen.group(0) + " "
578
+
579
+ if context.strip():
580
+ label, explain = detect_ancient_flag(context)
581
+ if label!="Unknown":
582
+ return label, explain + " from NCBI\n(" + context + ")"
583
+
584
+ # If no useful NCBI metadata, check supplementary texts
585
+ if textsToExtract:
586
+ labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
587
+
588
+ for source in textsToExtract:
589
+ text_block = textsToExtract[source]
590
+ context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
591
+ label, explain = detect_ancient_flag(context)
592
+
593
+ if label == "Ancient":
594
+ labels["ancient"][0] += 1
595
+ labels["ancient"][1] += f"{source}:\n{explain}\n\n"
596
+ elif label == "Modern":
597
+ labels["modern"][0] += 1
598
+ labels["modern"][1] += f"{source}:\n{explain}\n\n"
599
+ else:
600
+ labels["unknown"] += 1
601
+
602
+ if max(labels["modern"][0],labels["ancient"][0]) > 0:
603
+ if labels["modern"][0] > labels["ancient"][0]:
604
+ return "Modern", labels["modern"][1]
605
+ else:
606
+ return "Ancient", labels["ancient"][1]
607
+ else:
608
+ return "Unknown", "No strong keywords detected"
609
+ else:
610
+ print("No DOI or PubMed ID available for inference.")
611
+ return "", ""
612
+
613
+ except Exception as e:
614
+ print("Error:", e)
615
+ return "", ""
616
+
617
+
618
+ def detect_ancient_flag(context_snippet):
619
+ context = context_snippet.lower()
620
+
621
+ ancient_keywords = [
622
+ "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
623
+ "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
624
+ "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
625
+ "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
626
+ ]
627
+
628
+ modern_keywords = [
629
+ "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
630
+ "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
631
+ "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
632
+ "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
633
+ "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
634
+ ]
635
+
636
+ ancient_hits = [k for k in ancient_keywords if k in context]
637
+ modern_hits = [k for k in modern_keywords if k in context]
638
+
639
+ if ancient_hits and not modern_hits:
640
+ return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
641
+ elif modern_hits and not ancient_hits:
642
+ return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
643
+ elif ancient_hits and modern_hits:
644
+ if len(ancient_hits) >= len(modern_hits):
645
+ return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
646
+ else:
647
+ return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
648
+
649
+ # Fallback to QA
650
+ answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
651
+ if answer.startswith("Error"):
652
+ return "Unknown", answer
653
+ if "ancient" in answer.lower():
654
+ return "Ancient", f"Leaning ancient based on QA: {answer}"
655
+ elif "modern" in answer.lower():
656
+ return "Modern", f"Leaning modern based on QA: {answer}"
657
+ else:
658
+ return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
659
+
660
+ # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
661
+ def classify_sample_location(accession):
662
+ outputs = {}
663
+ keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
664
+ # Step 1: get pubmed id and isolate
665
+ pubmedID, isolate = get_info_from_accession(accession)
666
+ '''if not pubmedID:
667
+ return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
668
+ if not isolate:
669
+ isolate = "UNKNOWN_ISOLATE"
670
+ # Step 2: get doi
671
+ doi = get_doi_from_pubmed_id(pubmedID)
672
+ '''if not doi:
673
+ return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
674
+ # Step 3: get text
675
+ '''textsToExtract = { "doiLink":"paperText"
676
+ "file1.pdf":"text1",
677
+ "file2.doc":"text2",
678
+ "file3.xlsx":excelText3'''
679
+ if doi and pubmedID:
680
+ textsToExtract = get_paper_text(doi,pubmedID)
681
+ else: textsToExtract = {}
682
+ '''if not textsToExtract:
683
+ return {"error": f"No texts extracted for DOI {doi}"}'''
684
+ if isolate not in [None, "UNKNOWN_ISOLATE"]:
685
+ label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
686
+ else:
687
+ label, explain = flag_ancient_modern(accession,textsToExtract)
688
+ # Step 4: prediction
689
+ outputs[accession] = {}
690
+ outputs[isolate] = {}
691
+ # 4.0 Infer from NCBI
692
+ location, outputNCBI = infer_location_fromNCBI(accession)
693
+ NCBI_result = {
694
+ "source": "NCBI",
695
+ "sample_id": accession,
696
+ "predicted_location": location,
697
+ "context_snippet": outputNCBI}
698
+ outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
699
+ if textsToExtract:
700
+ long_text = ""
701
+ for key in textsToExtract:
702
+ text = textsToExtract[key]
703
+ # try accession number first
704
+ outputs[accession][key] = {}
705
+ keyword = accession
706
+ context = extract_context(text, keyword, window=500)
707
+ # 4.1: Using a HuggingFace model (question-answering)
708
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
709
+ qa_result = {
710
+ "source": key,
711
+ "sample_id": keyword,
712
+ "predicted_location": location,
713
+ "context_snippet": context
714
+ }
715
+ outputs[keyword][key]["QAModel"] = qa_result
716
+ # 4.2: Infer from haplogroup
717
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
718
+ outputs[keyword][key]["haplogroup"] = haplo_result
719
+ # try isolate
720
+ keyword = isolate
721
+ outputs[isolate][key] = {}
722
+ context = extract_context(text, keyword, window=500)
723
+ # 4.1.1: Using a HuggingFace model (question-answering)
724
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
725
+ qa_result = {
726
+ "source": key,
727
+ "sample_id": keyword,
728
+ "predicted_location": location,
729
+ "context_snippet": context
730
+ }
731
+ outputs[keyword][key]["QAModel"] = qa_result
732
+ # 4.2.1: Infer from haplogroup
733
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
734
+ outputs[keyword][key]["haplogroup"] = haplo_result
735
+ # add long text
736
+ long_text += text + ". \n"
737
+ # 4.3: UpgradeClassify
738
+ # try sample_id as accession number
739
+ sample_id = accession
740
+ if sample_id:
741
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
742
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
743
+ if locations!="No clear location found in top matches":
744
+ outputs[sample_id]["upgradeClassifier"] = {}
745
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
746
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
747
+ "sample_id": sample_id,
748
+ "predicted_location": ", ".join(locations),
749
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
750
+ }
751
+ # try sample_id as isolate name
752
+ sample_id = isolate
753
+ if sample_id:
754
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
755
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
756
+ if locations!="No clear location found in top matches":
757
+ outputs[sample_id]["upgradeClassifier"] = {}
758
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
759
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
760
+ "sample_id": sample_id,
761
+ "predicted_location": ", ".join(locations),
762
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
763
+ }
764
  return outputs, label, explain
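
For reference, the whole pipeline above can be driven from a single accession. The snippet below is a minimal usage sketch rather than part of this commit: it assumes the file is importable as mtdna_classifier, that Entrez.email has been configured, and that the accession shown is purely illustrative.

import json
import mtdna_classifier

accession = "MT123456"  # hypothetical accession, for illustration only
# classify_sample_location returns (per-source predictions, ancient/modern label, explanation)
outputs, label, explain = mtdna_classifier.classify_sample_location(accession)

print(f"{accession}: flagged as {label or 'Unknown'}")
print(explain)
# outputs is keyed by accession and isolate; each entry maps a source (NCBI, the
# paper link, supplementary files) to its QAModel / haplogroup / upgradeClassifier result.
print(json.dumps(outputs, indent=2, default=str))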
mtdna_tool_explainer_updated.html CHANGED
@@ -1,135 +1,135 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>mtDNA Tool – System Overview</title>
7
-
8
- <style>
9
- .custom-container {
10
- background-color: #ffffff !important;
11
- color: #222222 !important;
12
- font-family: Arial, sans-serif !important;
13
- line-height: 1.6 !important;
14
- padding: 2rem !important;
15
- max-width: 900px !important;
16
- margin: auto !important;
17
- }
18
-
19
- .custom-container h1,
20
- .custom-container h2,
21
- .custom-container h3,
22
- .custom-container strong,
23
- .custom-container b,
24
- .custom-container p,
25
- .custom-container li,
26
- .custom-container ol,
27
- .custom-container ul,
28
- .custom-container span {
29
- color: #222222 !important;
30
- font-weight: normal !important;
31
- }
32
-
33
- .custom-container h1,
34
- .custom-container h2 {
35
- font-weight: bold !important;
36
- }
37
-
38
- .custom-container img {
39
- max-width: 100%;
40
- border: 1px solid #ccc;
41
- padding: 5px;
42
- background: #fff;
43
- }
44
-
45
- .custom-container code {
46
- background: none !important;
47
- color: #222 !important;
48
- font-family: inherit !important;
49
- font-size: inherit !important;
50
- padding: 0 !important;
51
- border-radius: 0 !important;
52
- }
53
-
54
-
55
- .custom-container .highlight {
56
- background: #ffffcc;
57
- padding: 4px 8px;
58
- border-left: 4px solid #ffcc00;
59
- margin: 1rem 0;
60
- color: #333 !important;
61
- }
62
- </style>
63
- </head>
64
-
65
- <body>
66
- <div class="custom-container">
67
-
68
- <h1>mtDNA Location Classifier – Brief System Pipeline and Usage Guide</h1>
69
-
70
- <p>The <strong>mtDNA Tool</strong> is a lightweight pipeline designed to help researchers extract metadata such as geographic origin, sample type (ancient/modern), and optional niche labels (e.g., ethnicity, specific location) from mtDNA GenBank accession numbers. It supports batch input and produces structured Excel summaries.</p>
71
-
72
- <h2>System Overview Diagram</h2>
73
- <p>The figure below shows the core execution flow—from input accession to final output.</p>
74
- <img src="https://huggingface.co/spaces/VyLala/mtDNALocation/resolve/main/flowchart.png" alt="mtDNA Pipeline Flowchart">
75
-
76
-
77
- <h2>Key Steps</h2>
78
- <ol>
79
- <li><strong>Input</strong>: One or more GenBank accession numbers are submitted (e.g., via UI, CSV, or text).</li>
80
-
81
- <li><strong>Metadata Collection</strong>: Using <code>fetch_ncbi_metadata</code>, the pipeline retrieves metadata like country, isolate, collection date, and reference title. If available, supplementary material and full-text articles are parsed using DOI, PubMed, or Google Custom Search.</li>
82
-
83
- <li><strong>Text Extraction & Preprocessing</strong>:
84
- <ul>
85
- <li>All available documents are parsed and cleaned (tables, paragraphs, overlapping sections).</li>
86
- <li>Text is merged into two formats: a smaller <code>chunk</code> and a full <code>all_output</code>.</li>
87
- </ul>
88
- </li>
89
-
90
- <li><strong>LLM-based Inference (Gemini + RAG)</strong>:
91
- <ul>
92
- <li>Chunks are embedded with FAISS and stored for reuse.</li>
93
- <li>The Gemini model answers specific queries like predicted country, sample type, and any niche label requested by the user.</li>
94
- </ul>
95
- </li>
96
-
97
- <li><strong>Result Structuring</strong>:
98
- <ul>
99
- <li>Each output includes predicted fields + explanation text (methods used, quotes, sources).</li>
100
- <li>Summarized and saved using <code>save_to_excel</code>.</li>
101
- </ul>
102
- </li>
103
- </ol>
104
-
105
- <h2>Output Format</h2>
106
- <p>The final output is an Excel file with the following fields:</p>
107
- <ul>
108
- <li><code>Sample ID</code></li>
109
- <li><code>Predicted Country</code> and <code>Country Explanation</code></li>
110
- <li><code>Predicted Sample Type</code> and <code>Sample Type Explanation</code></li>
111
- <li><code>Sources</code> (links to articles)</li>
112
- <li><code>Time Cost</code></li>
113
- </ul>
114
-
115
- <h2>System Highlights</h2>
116
- <ul>
117
- <li>RAG + Gemini integration for improved explanation and transparency</li>
118
- <li>Excel export for structured research use</li>
119
- <li>Optional ethnic/location/language inference using isolate names</li>
120
- <li>Quality check (e.g., fallback on short explanations, low token count)</li>
121
- <li>Report Button – After results are displayed, users can submit errors or mismatches using the report text box below the output table</li>
122
- </ul>
123
-
124
- <h2>Citation</h2>
125
- <div class="highlight">
126
- Phung, V. (2025). mtDNA Location Classifier. HuggingFace Spaces. https://huggingface.co/spaces/VyLala/mtDNALocation
127
- </div>
128
-
129
- <h2>Contact</h2>
130
- <p>If you are a researcher working with historical mtDNA data or edge-case accessions and need scalable inference or logging, reach out through the HuggingFace space or email provided in the repo README.</p>
131
-
132
- </div>
133
- </body>
134
- </html>
135
-
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>mtDNA Tool – System Overview</title>
7
+
8
+ <style>
9
+ .custom-container {
10
+ background-color: #ffffff !important;
11
+ color: #222222 !important;
12
+ font-family: Arial, sans-serif !important;
13
+ line-height: 1.6 !important;
14
+ padding: 2rem !important;
15
+ max-width: 900px !important;
16
+ margin: auto !important;
17
+ }
18
+
19
+ .custom-container h1,
20
+ .custom-container h2,
21
+ .custom-container h3,
22
+ .custom-container strong,
23
+ .custom-container b,
24
+ .custom-container p,
25
+ .custom-container li,
26
+ .custom-container ol,
27
+ .custom-container ul,
28
+ .custom-container span {
29
+ color: #222222 !important;
30
+ font-weight: normal !important;
31
+ }
32
+
33
+ .custom-container h1,
34
+ .custom-container h2 {
35
+ font-weight: bold !important;
36
+ }
37
+
38
+ .custom-container img {
39
+ max-width: 100%;
40
+ border: 1px solid #ccc;
41
+ padding: 5px;
42
+ background: #fff;
43
+ }
44
+
45
+ .custom-container code {
46
+ background: none !important;
47
+ color: #222 !important;
48
+ font-family: inherit !important;
49
+ font-size: inherit !important;
50
+ padding: 0 !important;
51
+ border-radius: 0 !important;
52
+ }
53
+
54
+
55
+ .custom-container .highlight {
56
+ background: #ffffcc;
57
+ padding: 4px 8px;
58
+ border-left: 4px solid #ffcc00;
59
+ margin: 1rem 0;
60
+ color: #333 !important;
61
+ }
62
+ </style>
63
+ </head>
64
+
65
+ <body>
66
+ <div class="custom-container">
67
+
68
+ <h1>mtDNA Location Classifier – Brief System Pipeline and Usage Guide</h1>
69
+
70
+ <p>The <strong>mtDNA Tool</strong> is a lightweight pipeline designed to help researchers extract metadata such as geographic origin, sample type (ancient/modern), and optional niche labels (e.g., ethnicity, specific location) from mtDNA GenBank accession numbers. It supports batch input and produces structured Excel summaries.</p>
71
+
72
+ <h2>System Overview Diagram</h2>
73
+ <p>The figure below shows the core execution flow—from input accession to final output.</p>
74
+ <img src="https://huggingface.co/spaces/VyLala/mtDNALocation/resolve/main/flowchart.png" alt="mtDNA Pipeline Flowchart">
75
+
76
+
77
+ <h2>Key Steps</h2>
78
+ <ol>
79
+ <li><strong>Input</strong>: One or more GenBank accession numbers are submitted (e.g., via UI, CSV, or text).</li>
80
+
81
+ <li><strong>Metadata Collection</strong>: Using <code>fetch_ncbi_metadata</code>, the pipeline retrieves metadata like country, isolate, collection date, and reference title. If available, supplementary material and full-text articles are parsed using DOI, PubMed, or Google Custom Search.</li>
82
+
83
+ <li><strong>Text Extraction & Preprocessing</strong>:
84
+ <ul>
85
+ <li>All available documents are parsed and cleaned (tables, paragraphs, overlapping sections).</li>
86
+ <li>Text is merged into two formats: a smaller <code>chunk</code> and a full <code>all_output</code>.</li>
87
+ </ul>
88
+ </li>
89
+
90
+ <li><strong>LLM-based Inference (Gemini + RAG)</strong>:
91
+ <ul>
92
+ <li>Chunks are embedded with FAISS and stored for reuse.</li>
93
+ <li>The Gemini model answers specific queries like predicted country, sample type, and any niche label requested by the user.</li>
94
+ </ul>
95
+ </li>
96
+
97
+ <li><strong>Result Structuring</strong>:
98
+ <ul>
99
+ <li>Each output includes predicted fields + explanation text (methods used, quotes, sources).</li>
100
+ <li>Summarized and saved using <code>save_to_excel</code>.</li>
101
+ </ul>
102
+ </li>
103
+ </ol>
104
+
105
+ <h2>Output Format</h2>
106
+ <p>The final output is an Excel file with the following fields:</p>
107
+ <ul>
108
+ <li><code>Sample ID</code></li>
109
+ <li><code>Predicted Country</code> and <code>Country Explanation</code></li>
110
+ <li><code>Predicted Sample Type</code> and <code>Sample Type Explanation</code></li>
111
+ <li><code>Sources</code> (links to articles)</li>
112
+ <li><code>Time Cost</code></li>
113
+ </ul>
114
+
115
+ <h2>System Highlights</h2>
116
+ <ul>
117
+ <li>RAG + Gemini integration for improved explanation and transparency</li>
118
+ <li>Excel export for structured research use</li>
119
+ <li>Optional ethnic/location/language inference using isolate names</li>
120
+ <li>Quality check (e.g., fallback on short explanations, low token count)</li>
121
+ <li>Report Button – After results are displayed, users can submit errors or mismatches using the report text box below the output table</li>
122
+ </ul>
123
+
124
+ <h2>Citation</h2>
125
+ <div class="highlight">
126
+ Phung, V. (2025). mtDNA Location Classifier. HuggingFace Spaces. https://huggingface.co/spaces/VyLala/mtDNALocation
127
+ </div>
128
+
129
+ <h2>Contact</h2>
130
+ <p>If you are a researcher working with historical mtDNA data or edge-case accessions and need scalable inference or logging, reach out through the HuggingFace space or email provided in the repo README.</p>
131
+
132
+ </div>
133
+ </body>
134
+ </html>
135
+
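
The "LLM-based Inference (Gemini + RAG)" step above is described only at a high level; the implementation lives in pipeline.py, whose diff is too large to render below. As a rough illustration of that flow (chunk, embed with FAISS, retrieve, ask Gemini), the sketch below uses sentence-transformers, faiss, and google-generativeai; the function name, model names, and environment variable are assumptions for illustration, not the pipeline's actual API.

import os
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import google.generativeai as genai

def answer_with_rag(chunks, question, k=3):
    # Embed the text chunks and build a flat L2 index for retrieval
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    vectors = np.asarray(embedder.encode(chunks), dtype="float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

    # Retrieve the k chunks closest to the question
    query = np.asarray(embedder.encode([question]), dtype="float32")
    _, idx = index.search(query, k)
    context = "\n\n".join(chunks[i] for i in idx[0])

    # Ask Gemini to answer from the retrieved context only
    genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # assumed env var name
    model = genai.GenerativeModel("gemini-1.5-flash")       # assumed model name
    prompt = (f"Context:\n{context}\n\n"
              f"Question: {question}\n"
              "Answer with the predicted value and a short supporting quote.")
    return model.generate_content(prompt).text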
offer.html CHANGED
@@ -1,81 +1,81 @@
1
-
2
- <div style="font-family: sans-serif; line-height: 1.6;">
3
- <h1>mtDNA Location Classifier</h1>
4
-
5
- <h2> The purpose of the tool </h2>
6
- <p><strong>Make biological data reusable by labeling it better</strong></p>
7
- <p>Are you dealing with <strong>incomplete mtDNA metadata</strong> (like country, ethnicity, sample type)?<br>
8
- This tool helps researchers like you generate <strong>clean, structured labels</strong> — ready to use for your paper. <br>
9
- To help researchers reuse existing biological data more effectively by labeling missing or messy metadata with accuracy, clarity, and speed.
10
- </p>
11
-
12
- <hr>
13
-
14
- <h2> What You’ll Get:</h2>
15
- <ul>
16
- <li>Inference from sequence ID alone</li>
17
- <li>Handles hard edge cases</li>
18
- <li>Clear sample type, country, and more (ethnicity, phenotype, etc.)</li>
19
- <li>Excel export with citations</li>
20
- <li>Feedback-based refund policy</li>
21
- </ul>
22
-
23
- <hr>
24
-
25
- <h2>Free Tier</h2>
26
- <ul>
27
- <li>30 free samples — no email needed</li>
28
- <li>+20 bonus samples + Excel file when you enter your email</li>
29
- <li>Don’t like the result? Tell us why on the report — we won’t count the bad ones (email required)</li>
30
- </ul>
31
-
32
- <hr>
33
-
34
- <h2>Pricing — Pay As You Go (DIY)</h2>
35
- <table border="1" cellpadding="6" cellspacing="0">
36
- <thead>
37
- <tr>
38
- <th>Case Type</th>
39
- <th>Price/Sample</th>
40
- <th>Output</th>
41
- <th>Description</th>
42
- </tr>
43
- </thead>
44
- <tbody>
45
- <tr><td>Normal</td><td>$0.10</td><td>Sample Type + Country</td><td>Sample has known publication (e.g., PubMed ID) but no clear label</td></tr>
46
- <tr><td>Edge</td><td>$1.00</td><td>Sample Type + Country</td><td>Direct submissions, no DOI/article/PubMed ID</td></tr>
47
- <tr><td>Niche</td><td>$2.00</td><td>Sample Type + Country + 1 Custom Label</td><td>Adds output for phenotype, ethnicity, or city/province</td></tr>
48
- </tbody>
49
- </table>
50
-
51
- <hr>
52
-
53
- <h2>Batch Discount (1000+ Samples)</h2>
54
- <ul>
55
- <li><strong>Normal Output (Sample Type, Country)</strong> → $100 total ($0.10/sample)<br>Unsatisfied samples? We’ll reduce the cost ($0.10/sample deducted)</li>
56
- <li><strong>Niche Output (Sample Type, Country + 1 Niche Label)</strong> → $500 total ($0.50/sample)<br>Unsatisfied outputs will also be refunded proportionally</li>
57
- </ul>
58
-
59
- <hr>
60
-
61
- <h2>Early User Bonus (Limited!)</h2>
62
- <p>Are you one of our <strong>first 10 paying users</strong>?<br>
63
- Just type <code>early_user</code> in your email.</p>
64
- <p>You'll get <strong>20% lifetime discount</strong> on every plan — forever.<br>
65
- We’ll apply this automatically so you don’t have to calculate anything.</p>
66
-
67
- <hr>
68
-
69
- <h2>Our Mission</h2>
70
- <p>Rebuilding trust in genomic metadata — one mtDNA sample at a time.</p>
71
-
72
-
73
- <hr>
74
-
75
- <h2>Try It Now</h2>
76
- <p>Paste your sequence ID on our demo:<br>
77
- <a href="https://huggingface.co/spaces/VyLala/mtDNALocation" target="_blank">Try the Classifier</a></p>
78
- <p>Need help or bulk analysis?<br>
79
- <a href="mailto:[email protected]" target="_blank">Contact Us</a></p>
80
-
81
- </div>
 
1
+
2
+ <div style="font-family: sans-serif; line-height: 1.6;">
3
+ <h1>mtDNA Location Classifier</h1>
4
+
5
+ <h2> The purpose of the tool </h2>
6
+ <p><strong>Make biological data reusable by labeling it better</strong></p>
7
+ <p>Are you dealing with <strong>incomplete mtDNA metadata</strong> (like country, ethnicity, sample type)?<br>
8
+ This tool helps researchers like you generate <strong>clean, structured labels</strong> — ready to use for your paper. <br>
9
+ To help researchers reuse existing biological data more effectively by labeling missing or messy metadata with accuracy, clarity, and speed.
10
+ </p>
11
+
12
+ <hr>
13
+
14
+ <h2> What You’ll Get:</h2>
15
+ <ul>
16
+ <li>Inference from sequence ID alone</li>
17
+ <li>Handles hard edge cases</li>
18
+ <li>Clear sample type, country, and more (ethnicity, phenotype, etc.)</li>
19
+ <li>Excel export with citations</li>
20
+ <li>Feedback-based refund policy</li>
21
+ </ul>
22
+
23
+ <hr>
24
+
25
+ <h2>Free Tier</h2>
26
+ <ul>
27
+ <li>30 free samples — no email needed</li>
28
+ <li>+20 bonus samples + Excel file when you enter your email</li>
29
+ <li>Don’t like the result? Tell us why on the report — we won’t count the bad ones (email required)</li>
30
+ </ul>
31
+
32
+ <hr>
33
+
34
+ <h2>Pricing — Pay As You Go (DIY)</h2>
35
+ <table border="1" cellpadding="6" cellspacing="0">
36
+ <thead>
37
+ <tr>
38
+ <th>Case Type</th>
39
+ <th>Price/Sample</th>
40
+ <th>Output</th>
41
+ <th>Description</th>
42
+ </tr>
43
+ </thead>
44
+ <tbody>
45
+ <tr><td>Normal</td><td>$0.10</td><td>Sample Type + Country</td><td>Sample has known publication (e.g., PubMed ID) but no clear label</td></tr>
46
+ <tr><td>Edge</td><td>$1.00</td><td>Sample Type + Country</td><td>Direct submissions, no DOI/article/PubMed ID</td></tr>
47
+ <tr><td>Niche</td><td>$2.00</td><td>Sample Type + Country + 1 Custom Label</td><td>Adds output for phenotype, ethnicity, or city/province</td></tr>
48
+ </tbody>
49
+ </table>
50
+
51
+ <hr>
52
+
53
+ <h2>Batch Discount (1000+ Samples)</h2>
54
+ <ul>
55
+ <li><strong>Normal Output (Sample Type, Country)</strong> → $100 total ($0.10/sample)<br>Unsatisfied samples? We’ll reduce the cost ($0.10/sample deducted)</li>
56
+ <li><strong>Niche Output (Sample Type, Country + 1 Niche Label)</strong> → $500 total ($0.50/sample)<br>Unsatisfied outputs will also be refunded proportionally</li>
57
+ </ul>
58
+
59
+ <hr>
60
+
61
+ <h2>Early User Bonus (Limited!)</h2>
62
+ <p>Are you one of our <strong>first 10 paying users</strong>?<br>
63
+ Just type <code>early_user</code> in your email.</p>
64
+ <p>You'll get <strong>20% lifetime discount</strong> on every plan — forever.<br>
65
+ We’ll apply this automatically so you don’t have to calculate anything.</p>
66
+
67
+ <hr>
68
+
69
+ <h2>Our Mission</h2>
70
+ <p>Rebuilding trust in genomic metadata — one mtDNA sample at a time.</p>
71
+
72
+
73
+ <hr>
74
+
75
+ <h2>Try It Now</h2>
76
+ <p>Paste your sequence ID on our demo:<br>
77
+ <a href="https://huggingface.co/spaces/VyLala/mtDNALocation" target="_blank">Try the Classifier</a></p>
78
+ <p>Need help or bulk analysis?<br>
79
+ <a href="mailto:[email protected]" target="_blank">Contact Us</a></p>
80
+
81
+ </div>
pipeline.py CHANGED
The diff for this file is too large to render. See raw diff
 
smart_fallback.py CHANGED
@@ -1,259 +1,342 @@
1
- from Bio import Entrez, Medline
2
- #import model
3
- import mtdna_classifier
4
- from NER.html import extractHTML
5
- import data_preprocess
6
- import pipeline
7
- # Setup
8
- def fetch_ncbi(accession_number):
9
- try:
10
- Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
11
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
12
- record = Entrez.read(handle)
13
- handle.close()
14
- outputs = {"authors":"unknown",
15
- "institution":"unknown",
16
- "isolate":"unknown",
17
- "definition":"unknown",
18
- "title":"unknown",
19
- "seq_comment":"unknown",
20
- "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
21
- gb_seq = None
22
- # Validate record structure: It should be a list with at least one element (a dict)
23
- if isinstance(record, list) and len(record) > 0:
24
- if isinstance(record[0], dict):
25
- gb_seq = record[0]
26
- else:
27
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
28
- # extract collection date
29
- if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
30
- outputs["collection_date"] = gb_seq["GBSeq_create-date"]
31
- else:
32
- if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
33
- outputs["collection_date"] = gb_seq["GBSeq_update-date"]
34
- # extract definition
35
- if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
36
- outputs["definition"] = gb_seq["GBSeq_definition"]
37
- # extract related-reference things
38
- if "GBSeq_references" in gb_seq:
39
- for ref in gb_seq["GBSeq_references"]:
40
- # extract authors
41
- if "GBReference_authors" in ref and outputs["authors"]=="unknown":
42
- outputs["authors"] = "and ".join(ref["GBReference_authors"])
43
- # extract title
44
- if "GBReference_title" in ref and outputs["title"]=="unknown":
45
- outputs["title"] = ref["GBReference_title"]
46
- # extract submitted journal
47
- if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
48
- outputs["institution"] = ref['GBReference_journal']
49
- # extract seq_comment
50
- if 'GBSeq_comment'in gb_seq and outputs["seq_comment"]=="unknown":
51
- outputs["seq_comment"] = gb_seq["GBSeq_comment"]
52
- # extract isolate
53
- if "GBSeq_feature-table" in gb_seq:
54
- if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
55
- for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
56
- if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
57
- outputs["isolate"] = ref["GBQualifier_value"]
58
- else:
59
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
60
-
61
- # If gb_seq is still None, return defaults
62
- if gb_seq is None:
63
- return {"authors":"unknown",
64
- "institution":"unknown",
65
- "isolate":"unknown",
66
- "definition":"unknown",
67
- "title":"unknown",
68
- "seq_comment":"unknown",
69
- "collection_date":"unknown" }
70
- return outputs
71
- except:
72
- print("error in fetching ncbi data")
73
- return {"authors":"unknown",
74
- "institution":"unknown",
75
- "isolate":"unknown",
76
- "definition":"unknown",
77
- "title":"unknown",
78
- "seq_comment":"unknown",
79
- "collection_date":"unknown" }
80
- # Fallback if NCBI crashed or cannot find accession on NBCI
81
- def google_accession_search(accession_id):
82
- """
83
- Search for metadata by accession ID using Google Custom Search.
84
- Falls back to known biological databases and archives.
85
- """
86
- queries = [
87
- f"{accession_id}",
88
- f"{accession_id} site:ncbi.nlm.nih.gov",
89
- f"{accession_id} site:pubmed.ncbi.nlm.nih.gov",
90
- f"{accession_id} site:europepmc.org",
91
- f"{accession_id} site:researchgate.net",
92
- f"{accession_id} mtDNA",
93
- f"{accession_id} mitochondrial DNA"
94
- ]
95
-
96
- links = []
97
- for query in queries:
98
- search_results = mtdna_classifier.search_google_custom(query, 2)
99
- for link in search_results:
100
- if link not in links:
101
- links.append(link)
102
- return links
103
-
104
- # Method 1: Smarter Google
105
- def smart_google_queries(metadata: dict):
106
- queries = []
107
-
108
- # Extract useful fields
109
- isolate = metadata.get("isolate")
110
- author = metadata.get("authors")
111
- institution = metadata.get("institution")
112
- title = metadata.get("title")
113
- combined = []
114
- # Construct queries
115
- if isolate and isolate!="unknown" and isolate!="Unpublished":
116
- queries.append(f'"{isolate}" mitochondrial DNA')
117
- queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
118
-
119
- if author and author!="unknown" and author!="Unpublished":
120
- # try:
121
- # author_name = ".".join(author.split(' ')[0].split(".")[:-1]) # Use last name only
122
- # except:
123
- # try:
124
- # author_name = author.split(',')[0] # Use last name only
125
- # except:
126
- # author_name = author
127
- try:
128
- author_name = author.split(',')[0] # Use last name only
129
- except:
130
- author_name = author
131
- queries.append(f'"{author_name}" mitochondrial DNA')
132
- queries.append(f'"{author_name}" mtDNA site:researchgate.net')
133
-
134
- if institution and institution!="unknown" and institution!="Unpublished":
135
- try:
136
- short_inst = ",".join(institution.split(',')[:2]) # Take first part of institution
137
- except:
138
- try:
139
- short_inst = institution.split(',')[0]
140
- except:
141
- short_inst = institution
142
- queries.append(f'"{short_inst}" mtDNA sequence')
143
- #queries.append(f'"{short_inst}" isolate site:nature.com')
144
- if title and title!='unknown' and title!="Unpublished":
145
- if title!="Direct Submission":
146
- queries.append(title)
147
-
148
- return queries
149
-
150
- def filter_links_by_metadata(search_results, saveLinkFolder, accession=None, stop_flag=None):
151
- TRUSTED_DOMAINS = [
152
- "ncbi.nlm.nih.gov",
153
- "pubmed.ncbi.nlm.nih.gov",
154
- "pmc.ncbi.nlm.nih.gov",
155
- "biorxiv.org",
156
- "researchgate.net",
157
- "nature.com",
158
- "sciencedirect.com"
159
- ]
160
- if stop_flag is not None and stop_flag.value:
161
- print(f"🛑 Stop detected {accession}, aborting early...")
162
- return []
163
- def is_trusted_link(link):
164
- for domain in TRUSTED_DOMAINS:
165
- if domain in link:
166
- return True
167
- return False
168
- def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
169
- output = []
170
- keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
171
- if accession:
172
- keywords = [accession] + keywords
173
- title_snippet = link.lower()
174
- print("save link folder inside this filter function: ", saveLinkFolder)
175
- success_process, output_process = pipeline.run_with_timeout(data_preprocess.extract_text,args=(link,saveLinkFolder),timeout=60)
176
- if stop_flag is not None and stop_flag.value:
177
- print(f"🛑 Stop detected {accession}, aborting early...")
178
- return []
179
- if success_process:
180
- article_text = output_process
181
- print("yes succeed for getting article text")
182
- else:
183
- print("no suceed, fallback to no link")
184
- article_text = ""
185
- #article_text = data_preprocess.extract_text(link,saveLinkFolder)
186
- print("article text")
187
- #print(article_text)
188
- if stop_flag is not None and stop_flag.value:
189
- print(f"🛑 Stop detected {accession}, aborting early...")
190
- return []
191
- try:
192
- ext = link.split(".")[-1].lower()
193
- if ext not in ["pdf", "docx", "xlsx"]:
194
- html = extractHTML.HTML("", link)
195
- if stop_flag is not None and stop_flag.value:
196
- print(f"🛑 Stop detected {accession}, aborting early...")
197
- return []
198
- jsonSM = html.getSupMaterial()
199
- if jsonSM:
200
- output += sum((jsonSM[key] for key in jsonSM), [])
201
- except Exception:
202
- pass # continue silently
203
- for keyword in keywords:
204
- if keyword.lower() in article_text.lower():
205
- if link not in output:
206
- output.append([link,keyword.lower()])
207
- print("link and keyword for article text: ", link, keyword)
208
- return output
209
- if keyword.lower() in title_snippet.lower():
210
- if link not in output:
211
- output.append([link,keyword.lower()])
212
- print("link and keyword for title: ", link, keyword)
213
- return output
214
- return output
215
-
216
- filtered = []
217
- better_filter = []
218
- if len(search_results) > 0:
219
- for link in search_results:
220
- # if is_trusted_link(link):
221
- # if link not in filtered:
222
- # filtered.append(link)
223
- # else:
224
- print(link)
225
- if stop_flag is not None and stop_flag.value:
226
- print(f"🛑 Stop detected {accession}, aborting early...")
227
- return []
228
- if link:
229
- output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
230
- print("output link: ")
231
- print(output_link)
232
- for out_link in output_link:
233
- if isinstance(out_link,list) and len(out_link) > 1:
234
- print(out_link)
235
- kw = out_link[1]
236
- print("kw and acc: ", kw, accession.lower())
237
- if accession and kw == accession.lower():
238
- better_filter.append(out_link[0])
239
- filtered.append(out_link[0])
240
- else: filtered.append(out_link)
241
- print("done with link and here is filter: ",filtered)
242
- if better_filter:
243
- filtered = better_filter
244
- return filtered
245
-
246
- def smart_google_search(metadata):
247
- queries = smart_google_queries(metadata)
248
- links = []
249
- for q in queries:
250
- #print("\n🔍 Query:", q)
251
- results = mtdna_classifier.search_google_custom(q,2)
252
- for link in results:
253
- #print(f"- {link}")
254
- if link not in links:
255
- links.append(link)
256
- #filter_links = filter_links_by_metadata(links)
257
- return links
258
- # Method 2: Prompt LLM better or better ai search api with all
 
 
 
 
 
 
 
 
 
 
259
  # the total information from even ncbi and all search
 
1
+ from Bio import Entrez, Medline
2
+ #import model
3
+ import mtdna_classifier
4
+ from NER.html import extractHTML
5
+ import data_preprocess
6
+ import pipeline
7
+ # Setup
8
+ def fetch_ncbi(accession_number):
9
+ try:
10
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
11
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
12
+ record = Entrez.read(handle)
13
+ handle.close()
14
+ outputs = {"authors":"unknown",
15
+ "institution":"unknown",
16
+ "isolate":"unknown",
17
+ "definition":"unknown",
18
+ "title":"unknown",
19
+ "seq_comment":"unknown",
20
+ "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
21
+ gb_seq = None
22
+ # Validate record structure: It should be a list with at least one element (a dict)
23
+ if isinstance(record, list) and len(record) > 0:
24
+ if isinstance(record[0], dict):
25
+ gb_seq = record[0]
26
+ else:
27
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
28
+ # extract collection date
29
+ if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
30
+ outputs["collection_date"] = gb_seq["GBSeq_create-date"]
31
+ else:
32
+ if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
33
+ outputs["collection_date"] = gb_seq["GBSeq_update-date"]
34
+ # extract definition
35
+ if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
36
+ outputs["definition"] = gb_seq["GBSeq_definition"]
37
+ # extract related-reference things
38
+ if "GBSeq_references" in gb_seq:
39
+ for ref in gb_seq["GBSeq_references"]:
40
+ # extract authors
41
+ if "GBReference_authors" in ref and outputs["authors"]=="unknown":
42
+ outputs["authors"] = "and ".join(ref["GBReference_authors"])
43
+ # extract title
44
+ if "GBReference_title" in ref and outputs["title"]=="unknown":
45
+ outputs["title"] = ref["GBReference_title"]
46
+ # extract submitted journal
47
+ if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
48
+ outputs["institution"] = ref['GBReference_journal']
49
+ # extract seq_comment
50
+ if 'GBSeq_comment'in gb_seq and outputs["seq_comment"]=="unknown":
51
+ outputs["seq_comment"] = gb_seq["GBSeq_comment"]
52
+ # extract isolate
53
+ if "GBSeq_feature-table" in gb_seq:
54
+ if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
55
+ for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
56
+ if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
57
+ outputs["isolate"] = ref["GBQualifier_value"]
58
+ else:
59
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
60
+
61
+ # If gb_seq is still None, return defaults
62
+ if gb_seq is None:
63
+ return {"authors":"unknown",
64
+ "institution":"unknown",
65
+ "isolate":"unknown",
66
+ "definition":"unknown",
67
+ "title":"unknown",
68
+ "seq_comment":"unknown",
69
+ "collection_date":"unknown" }
70
+ return outputs
71
+ except:
72
+ print("error in fetching ncbi data")
73
+ return {"authors":"unknown",
74
+ "institution":"unknown",
75
+ "isolate":"unknown",
76
+ "definition":"unknown",
77
+ "title":"unknown",
78
+ "seq_comment":"unknown",
79
+ "collection_date":"unknown" }
80
+ # Fallback if NCBI crashed or cannot find the accession on NCBI
81
+ def google_accession_search(accession_id):
82
+ """
83
+ Search for metadata by accession ID using Google Custom Search.
84
+ Falls back to known biological databases and archives.
85
+ """
86
+ queries = [
87
+ f"{accession_id}",
88
+ f"{accession_id} site:ncbi.nlm.nih.gov",
89
+ f"{accession_id} site:pubmed.ncbi.nlm.nih.gov",
90
+ f"{accession_id} site:europepmc.org",
91
+ f"{accession_id} site:researchgate.net",
92
+ f"{accession_id} mtDNA",
93
+ f"{accession_id} mitochondrial DNA"
94
+ ]
95
+
96
+ links = []
97
+ for query in queries:
98
+ search_results = mtdna_classifier.search_google_custom(query, 2)
99
+ for link in search_results:
100
+ if link not in links:
101
+ links.append(link)
102
+ return links
103
+
104
+ # Method 1: Smarter Google
105
+ def smart_google_queries(metadata: dict):
106
+ queries = []
107
+
108
+ # Extract useful fields
109
+ isolate = metadata.get("isolate")
110
+ author = metadata.get("authors")
111
+ institution = metadata.get("institution")
112
+ title = metadata.get("title")
113
+ combined = []
114
+ # Construct queries
115
+ if isolate and isolate!="unknown" and isolate!="Unpublished":
116
+ queries.append(f'"{isolate}" mitochondrial DNA')
117
+ queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
118
+
119
+ if author and author!="unknown" and author!="Unpublished":
120
+ # try:
121
+ # author_name = ".".join(author.split(' ')[0].split(".")[:-1]) # Use last name only
122
+ # except:
123
+ # try:
124
+ # author_name = author.split(',')[0] # Use last name only
125
+ # except:
126
+ # author_name = author
127
+ try:
128
+ author_name = author.split(',')[0] # Use last name only
129
+ except:
130
+ author_name = author
131
+ queries.append(f'"{author_name}" mitochondrial DNA')
132
+ queries.append(f'"{author_name}" mtDNA site:researchgate.net')
133
+
134
+ if institution and institution!="unknown" and institution!="Unpublished":
135
+ try:
136
+ short_inst = ",".join(institution.split(',')[:2]) # Take first part of institution
137
+ except:
138
+ try:
139
+ short_inst = institution.split(',')[0]
140
+ except:
141
+ short_inst = institution
142
+ queries.append(f'"{short_inst}" mtDNA sequence')
143
+ #queries.append(f'"{short_inst}" isolate site:nature.com')
144
+ if title and title!='unknown' and title!="Unpublished":
145
+ if title!="Direct Submission":
146
+ queries.append(title)
147
+
148
+ return queries
149
+
150
+ # def filter_links_by_metadata(search_results, saveLinkFolder, accession=None, stop_flag=None):
151
+ # TRUSTED_DOMAINS = [
152
+ # "ncbi.nlm.nih.gov",
153
+ # "pubmed.ncbi.nlm.nih.gov",
154
+ # "pmc.ncbi.nlm.nih.gov",
155
+ # "biorxiv.org",
156
+ # "researchgate.net",
157
+ # "nature.com",
158
+ # "sciencedirect.com"
159
+ # ]
160
+ # if stop_flag is not None and stop_flag.value:
161
+ # print(f"🛑 Stop detected {accession}, aborting early...")
162
+ # return []
163
+ # def is_trusted_link(link):
164
+ # for domain in TRUSTED_DOMAINS:
165
+ # if domain in link:
166
+ # return True
167
+ # return False
168
+ # def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
169
+ # output = []
170
+ # keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
171
+ # if accession:
172
+ # keywords = [accession] + keywords
173
+ # title_snippet = link.lower()
174
+ # print("save link folder inside this filter function: ", saveLinkFolder)
175
+ # success_process, output_process = pipeline.run_with_timeout(data_preprocess.extract_text,args=(link,saveLinkFolder),timeout=60)
176
+ # if stop_flag is not None and stop_flag.value:
177
+ # print(f"🛑 Stop detected {accession}, aborting early...")
178
+ # return []
179
+ # if success_process:
180
+ # article_text = output_process
181
+ # print("yes succeed for getting article text")
182
+ # else:
183
+ # print("no suceed, fallback to no link")
184
+ # article_text = ""
185
+ # #article_text = data_preprocess.extract_text(link,saveLinkFolder)
186
+ # print("article text")
187
+ # #print(article_text)
188
+ # if stop_flag is not None and stop_flag.value:
189
+ # print(f"🛑 Stop detected {accession}, aborting early...")
190
+ # return []
191
+ # try:
192
+ # ext = link.split(".")[-1].lower()
193
+ # if ext not in ["pdf", "docx", "xlsx"]:
194
+ # html = extractHTML.HTML("", link)
+ # if stop_flag is not None and stop_flag.value:
+ # print(f"🛑 Stop detected {accession}, aborting early...")
+ # return []
+ # jsonSM = html.getSupMaterial()
+ # if jsonSM:
+ # output += sum((jsonSM[key] for key in jsonSM), [])
+ # except Exception:
+ # pass # continue silently
+ # for keyword in keywords:
+ # if keyword.lower() in article_text.lower():
+ # if link not in output:
+ # output.append([link,keyword.lower()])
+ # print("link and keyword for article text: ", link, keyword)
+ # return output
+ # if keyword.lower() in title_snippet.lower():
+ # if link not in output:
+ # output.append([link,keyword.lower()])
+ # print("link and keyword for title: ", link, keyword)
+ # return output
+ # return output
+
+ # filtered = []
+ # better_filter = []
+ # if len(search_results) > 0:
+ # for link in search_results:
+ # # if is_trusted_link(link):
+ # # if link not in filtered:
+ # # filtered.append(link)
+ # # else:
+ # print(link)
+ # if stop_flag is not None and stop_flag.value:
+ # print(f"🛑 Stop detected {accession}, aborting early...")
+ # return []
+ # if link:
+ # output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
+ # print("output link: ")
+ # print(output_link)
+ # for out_link in output_link:
+ # if isinstance(out_link,list) and len(out_link) > 1:
+ # print(out_link)
+ # kw = out_link[1]
+ # print("kw and acc: ", kw, accession.lower())
+ # if accession and kw == accession.lower():
+ # better_filter.append(out_link[0])
+ # filtered.append(out_link[0])
+ # else: filtered.append(out_link)
+ # print("done with link and here is filter: ",filtered)
+ # if better_filter:
+ # filtered = better_filter
+ # return filtered
+
+ def filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
+ TRUSTED_DOMAINS = [
+ "ncbi.nlm.nih.gov",
+ "pubmed.ncbi.nlm.nih.gov",
+ "pmc.ncbi.nlm.nih.gov",
+ "biorxiv.org",
+ "researchgate.net",
+ "nature.com",
+ "sciencedirect.com"
+ ]
+ def is_trusted_link(link):
+ for domain in TRUSTED_DOMAINS:
+ if domain in link:
+ return True
+ return False
+ def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
+ output = []
+ keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
+ #keywords = ["mtDNA", "mitochondrial"]
+ if accession:
+ keywords = [accession] + keywords
+ title_snippet = link.lower()
+ #print("save link folder inside this filter function: ", saveLinkFolder)
+ article_text = data_preprocess.extract_text(link,saveLinkFolder)
+ print("article text done")
+ #print(article_text)
+ try:
+ ext = link.split(".")[-1].lower()
+ if ext not in ["pdf", "docx", "xlsx"]:
+ html = extractHTML.HTML("", link)
+ jsonSM = html.getSupMaterial()
+ if jsonSM:
+ output += sum((jsonSM[key] for key in jsonSM), [])
+ except Exception:
+ pass # continue silently
+ for keyword in keywords:
+ if article_text:
+ if keyword.lower() in article_text.lower():
+ if link not in output:
+ output.append([link,keyword.lower(), article_text])
+ return output
+ if keyword.lower() in title_snippet.lower():
+ if link not in output:
+ output.append([link,keyword.lower()])
+ print("link and keyword for title: ", link, keyword)
+ return output
+ return output
+
+ filtered = {}
+ better_filter = {}
+ if len(search_results) > 0:
+ print(search_results)
+ for link in search_results:
+ # if is_trusted_link(link):
+ # if link not in filtered:
+ # filtered.append(link)
+ # else:
+ print(link)
+ if link:
+ output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
+ print("output link: ")
+ print(output_link)
+ for out_link in output_link:
+ if isinstance(out_link,list) and len(out_link) > 1:
+ print(out_link)
+ kw = out_link[1]
+ if accession and kw == accession.lower():
+ if len(out_link) == 2:
+ better_filter[out_link[0]] = ""
+ elif len(out_link) == 3:
+ # save article
+ better_filter[out_link[0]] = out_link[2]
+ if len(out_link) == 2:
+ filtered[out_link[0]] = ""
+ elif len(out_link) == 3:
+ # save article
+ filtered[out_link[0]] = out_link[2]
+ else: filtered[out_link] = ""
+ print("done with link and here is filter: ",filtered)
325
+ if better_filter:
326
+ filtered = better_filter
327
+ return filtered
328
+
329
+ def smart_google_search(metadata):
330
+ queries = smart_google_queries(metadata)
331
+ links = []
332
+ for q in queries:
333
+ #print("\n🔍 Query:", q)
334
+ results = mtdna_classifier.search_google_custom(q,2)
335
+ for link in results:
336
+ #print(f"- {link}")
337
+ if link not in links:
338
+ links.append(link)
339
+ #filter_links = filter_links_by_metadata(links)
340
+ return links
341
+ # Method 2: Prompt LLM better or better ai search api with all
342
  # the total information from even ncbi and all search
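For orientation: in the new code above, smart_google_search expands the record metadata into query strings (via smart_google_queries), collects candidate URLs through mtdna_classifier.search_google_custom, and filter_links_by_metadata then keeps only links whose extracted text or URL mentions the accession or an mtDNA-related keyword, returning a dict of link -> cached article text. A minimal wiring sketch under those assumptions follows; the metadata, accession, and folder values are placeholders, not values from this commit.

# Hypothetical usage of the helpers added above (placeholder inputs).
metadata = {"isolate": "XYZ-01", "country": "Vietnam"}   # placeholder metadata dict
accession = "KX000001"                                   # placeholder accession ID
candidate_links = smart_google_search(metadata)          # list of unique result URLs
kept = filter_links_by_metadata(candidate_links, "/tmp/link_cache", accession=accession)
for url, article_text in kept.items():
    # value is "" when only the URL/title matched, otherwise the cached article text
    print(url, len(article_text))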
standardize_location.py CHANGED
@@ -1,91 +1,91 @@
- import requests
- import re
- import os
- import model
- # Normalize input
- def normalize_key(text):
- return re.sub(r"[^a-z0-9]", "", text.strip().lower())
-
- # Search for city/place (normal flow)
- def get_country_from_geonames(city_name):
- url = os.environ["URL_SEARCHJSON"]
- username = os.environ["USERNAME_GEO"]
- print("geoname: ", cityname)
- params = {
- "q": city_name,
- "maxRows": 1,
- "username": username
- }
- try:
- r = requests.get(url, params=params, timeout=5)
- data = r.json()
- if data.get("geonames"):
- return data["geonames"][0]["countryName"]
- except Exception as e:
- print("GeoNames searchJSON error:", e)
- return None
-
- # Search for country info using alpha-2/3 codes or name
- def get_country_from_countryinfo(input_code):
- url = os.environ["URL_COUNTRYJSON"]
- username = os.environ["USERNAME_GEO"]
- print("countryINFO: ", input_code)
- params = {
- "username": username
- }
- try:
- r = requests.get(url, params=params, timeout=5)
- data = r.json()
- if data.get("geonames"):
- input_code = input_code.strip().upper()
- for country in data["geonames"]:
- # Match against country name, country code (alpha-2), iso alpha-3
- if input_code in [
- country.get("countryName", "").upper(),
- country.get("countryCode", "").upper(),
- country.get("isoAlpha3", "").upper()
- ]:
- return country["countryName"]
- except Exception as e:
- print("GeoNames countryInfoJSON error:", e)
- return None
-
- # Combined smart lookup
- def smart_country_lookup(user_input):
- try:
- raw_input = user_input.strip()
- normalized = re.sub(r"[^a-zA-Z0-9]", "", user_input).upper() # normalize for codes (no strip spaces!)
- print("raw input for smart country lookup: ",raw_input, ". Normalized country: ", normalized)
- # Special case: if user writes "UK: London" → split and take main country part
- if ":" in raw_input:
- raw_input = raw_input.split(":")[0].strip() # only take "UK"
- # First try as country code (if 2-3 letters or common abbreviation)
- if len(normalized) <= 3:
- if normalized.upper() in ["UK","U.K","U.K."]:
- country = get_country_from_geonames(normalized.upper())
- print("get_country_from_geonames(normalized.upper()) ", country)
- if country:
- return country
- else:
- country = get_country_from_countryinfo(raw_input)
- print("get_country_from_countryinfo(raw_input) ", country)
- if country:
- return country
- print(raw_input)
- country = get_country_from_countryinfo(raw_input) # try full names
- print("get_country_from_countryinfo(raw_input) ", country)
- if country:
- return country
- # Otherwise, treat as city/place
- country = get_country_from_geonames(raw_input)
- print("get_country_from_geonames(raw_input) ", country)
- if country:
- return country
-
- return "Not found"
- except:
- country = model.get_country_from_text(user_input)
- if country.lower() !="unknown":
- return country
- else:
  return "Not found"
 
+ import requests
+ import re
+ import os
+ import model
+ # Normalize input
+ def normalize_key(text):
+ return re.sub(r"[^a-z0-9]", "", text.strip().lower())
+
+ # Search for city/place (normal flow)
+ def get_country_from_geonames(city_name):
+ url = os.environ["URL_SEARCHJSON"]
+ username = os.environ["USERNAME_GEO"]
+ print("geoname: ", cityname)
14
+ params = {
+ "q": city_name,
+ "maxRows": 1,
+ "username": username
+ }
+ try:
+ r = requests.get(url, params=params, timeout=5)
+ data = r.json()
+ if data.get("geonames"):
+ return data["geonames"][0]["countryName"]
+ except Exception as e:
+ print("GeoNames searchJSON error:", e)
+ return None
+
+ # Search for country info using alpha-2/3 codes or name
+ def get_country_from_countryinfo(input_code):
+ url = os.environ["URL_COUNTRYJSON"]
+ username = os.environ["USERNAME_GEO"]
+ print("countryINFO: ", input_code)
+ params = {
+ "username": username
+ }
+ try:
+ r = requests.get(url, params=params, timeout=5)
+ data = r.json()
+ if data.get("geonames"):
+ input_code = input_code.strip().upper()
+ for country in data["geonames"]:
+ # Match against country name, country code (alpha-2), iso alpha-3
+ if input_code in [
+ country.get("countryName", "").upper(),
+ country.get("countryCode", "").upper(),
+ country.get("isoAlpha3", "").upper()
+ ]:
+ return country["countryName"]
+ except Exception as e:
+ print("GeoNames countryInfoJSON error:", e)
+ return None
+
+ # Combined smart lookup
+ def smart_country_lookup(user_input):
+ try:
+ raw_input = user_input.strip()
+ normalized = re.sub(r"[^a-zA-Z0-9]", "", user_input).upper() # normalize for codes (no strip spaces!)
+ print("raw input for smart country lookup: ",raw_input, ". Normalized country: ", normalized)
+ # Special case: if user writes "UK: London" → split and take main country part
+ if ":" in raw_input:
+ raw_input = raw_input.split(":")[0].strip() # only take "UK"
+ # First try as country code (if 2-3 letters or common abbreviation)
+ if len(normalized) <= 3:
+ if normalized.upper() in ["UK","U.K","U.K."]:
+ country = get_country_from_geonames(normalized.upper())
+ print("get_country_from_geonames(normalized.upper()) ", country)
+ if country:
+ return country
+ else:
+ country = get_country_from_countryinfo(raw_input)
+ print("get_country_from_countryinfo(raw_input) ", country)
+ if country:
+ return country
+ print(raw_input)
+ country = get_country_from_countryinfo(raw_input) # try full names
+ print("get_country_from_countryinfo(raw_input) ", country)
+ if country:
+ return country
+ # Otherwise, treat as city/place
+ country = get_country_from_geonames(raw_input)
+ print("get_country_from_geonames(raw_input) ", country)
+ if country:
+ return country
+
+ return "Not found"
+ except:
+ country = model.get_country_from_text(user_input)
+ if country.lower() !="unknown":
+ return country
+ else:
  return "Not found"