VyLala committed on
Commit 6be4ec1 · verified · 1 Parent(s): c63e748

Upload 8 files

Files changed (7)
  1. app.py +422 -120
  2. data_preprocess.py +625 -0
  3. model.py +1255 -0
  4. mtdna_backend.py +364 -106
  5. mtdna_classifier.py +707 -524
  6. pipeline.py +347 -0
  7. requirements.txt +16 -3
app.py CHANGED
@@ -1,50 +1,62 @@
1
  import gradio as gr
2
  import mtdna_backend
3
  import json
4
  # Gradio UI
5
  with gr.Blocks() as interface:
6
  gr.Markdown("# 🧬 mtDNA Location Classifier (MVP)")
7
 
8
- inputMode = gr.Radio(choices=["Single Accession", "Batch Input"], value="Single Accession", label="Choose Input Mode")
9
-
10
- with gr.Group() as single_input_group:
11
- single_accession = gr.Textbox(label="Enter Single Accession (e.g., KU131308)")
12
 
13
- with gr.Group(visible=False) as batch_input_group:
14
- raw_text = gr.Textbox(label="🧬 Paste Accession Numbers (e.g., MF362736.1,MF362738.1,KU131308,MW291678)")
15
- gr.HTML("""<a href="https://drive.google.com/file/d/1t-TFeIsGVu5Jh3CUZS-VE9jQWzNFCs_c/view?usp=sharing" download target="_blank">Download Example CSV Format</a>""")
16
- gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing&ouid=112390323314156876153&rtpof=true&sd=true" download target="_blank">Download Example Excel Format</a>""")
17
- file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True, elem_id="file-upload-box")
18
-
19
20
 
21
  with gr.Row():
22
  run_button = gr.Button("🔍 Submit and Classify")
 
23
  reset_button = gr.Button("🔄 Reset")
24
 
25
  status = gr.Markdown(visible=False)
26
-
27
  with gr.Group(visible=False) as results_group:
28
- with gr.Accordion("Open to See the Result", open=False) as results:
29
- with gr.Row():
30
- output_summary = gr.Markdown(elem_id="output-summary")
31
- output_flag = gr.Markdown(elem_id="output-flag")
32
 
33
- gr.Markdown("---")
34
 
35
  with gr.Accordion("Open to See the Output Table", open=False) as table_accordion:
36
- """output_table = gr.Dataframe(
37
- headers=["Sample ID", "Technique", "Source", "Predicted Location", "Haplogroup", "Inferred Region", "Context Snippet"],
38
- interactive=False,
39
- row_count=(5, "dynamic")
40
- )"""
41
  output_table = gr.HTML(render=True)
42
 
43
-
44
  with gr.Row():
45
  output_type = gr.Dropdown(choices=["Excel", "JSON", "TXT"], label="Select Output Format", value="Excel")
46
  download_button = gr.Button("⬇️ Download Output")
47
- download_file = gr.File(label="Download File Here",visible=False)
 
 
48
 
49
  gr.Markdown("---")
50
 
@@ -56,47 +68,221 @@ with gr.Blocks() as interface:
56
  feedback_status = gr.Markdown()
57
 
58
  # Functions
59
-
60
- def toggle_input_mode(mode):
61
- if mode == "Single Accession":
62
- return gr.update(visible=True), gr.update(visible=False)
63
- else:
64
- return gr.update(visible=False), gr.update(visible=True)
65
 
66
  def classify_with_loading():
67
  return gr.update(value="⏳ Please wait... processing...",visible=True) # Show processing message
68
 
69
- def classify_dynamic(single_accession, file, text, mode):
70
- if mode == "Single Accession":
71
- return classify_main(single_accession) + (gr.update(visible=False),)
72
- else:
73
- #return summarize_batch(file, text) + (gr.update(visible=False),) # Hide processing message
74
- return classify_mulAcc(file, text) + (gr.update(visible=False),) # Hide processing message
75
 
76
  # for single accession
77
- def classify_main(accession):
78
- table, summary, labelAncient_Modern, explain_label = mtdna_backend.summarize_results(accession)
79
- flag_output = f"### 🏺 Ancient/Modern Flag\n**{labelAncient_Modern}**\n\n_Explanation:_ {explain_label}"
80
- return (
81
- #table,
82
- make_html_table(table),
83
- summary,
84
- flag_output,
85
- gr.update(visible=True),
86
- gr.update(visible=False)
87
- )
88
- # for batch accessions
89
- def classify_mulAcc(file, text):
90
- table, summary, flag_output, gr1, gr2 = mtdna_backend.summarize_batch(file, text)
91
- #flag_output = f"### 🏺 Ancient/Modern Flag\n**{labelAncient_Modern}**\n\n_Explanation:_ {explain_label}"
92
- return (
93
- #table,
94
- make_html_table(table),
95
- summary,
96
- flag_output,
97
- gr.update(visible=True),
98
- gr.update(visible=False)
99
- )
100
 
101
  def make_html_table(rows):
102
  html = """
@@ -106,7 +292,7 @@ with gr.Blocks() as interface:
106
  <thead style='position: sticky; top: 0; background-color: #2c2c2c; z-index: 1;'>
107
  <tr>
108
  """
109
- headers = ["Sample ID", "Technique", "Source", "Predicted Location", "Haplogroup", "Inferred Region", "Context Snippet"]
110
  html += "".join(
111
  f"<th style='padding: 10px; border: 1px solid #555; text-align: left; white-space: nowrap;'>{h}</th>"
112
  for h in headers
@@ -120,11 +306,31 @@ with gr.Blocks() as interface:
120
  style = "padding: 10px; border: 1px solid #555; vertical-align: top;"
121
 
122
  # For specific columns like Haplogroup, force nowrap
123
- if header in ["Haplogroup", "Sample ID", "Technique"]:
 
 
124
  style += " white-space: nowrap; text-overflow: ellipsis; max-width: 200px; overflow: hidden;"
125
 
126
- if header == "Source" and isinstance(col, str) and col.strip().lower().startswith("http"):
127
- col = f"<a href='{col}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{col}</a>"
128
 
129
  html += f"<td style='{style}'>{col}</td>"
130
  html += "</tr>"
@@ -133,78 +339,174 @@ with gr.Blocks() as interface:
133
  return html
134
 
135
 
136
  def reset_fields():
137
- return (
138
- gr.update(value=""), # single_accession
139
- gr.update(value=""), # raw_text
140
- gr.update(value=None), # file_upload
141
- gr.update(value="Single Accession"), # inputMode
142
- gr.update(value=[], visible=True), # output_table
143
- gr.update(value="", visible=True), # output_summary
144
- gr.update(value="", visible=True), # output_flag
145
- gr.update(visible=False), # status
146
- gr.update(visible=False) # results_group
147
- )
148
-
149
- inputMode.change(fn=toggle_input_mode, inputs=inputMode, outputs=[single_input_group, batch_input_group])
150
- run_button.click(fn=classify_with_loading, inputs=[], outputs=[status])
151
  run_button.click(
152
- fn=classify_dynamic,
153
- inputs=[single_accession, file_upload, raw_text, inputMode],
154
- outputs=[output_table, output_summary, output_flag, results_group, status]
155
  )
156
  reset_button.click(
157
  fn=reset_fields,
158
  inputs=[],
159
- outputs=[
160
- single_accession, raw_text, file_upload, inputMode,
161
- output_table, output_summary, output_flag,
162
- status, results_group
163
- ]
164
- )
165
 
166
  download_button.click(
167
  fn=mtdna_backend.save_batch_output,
168
- inputs=[output_table, output_summary, output_flag, output_type],
 
169
  outputs=[download_file])
170
 
171
  submit_feedback.click(
172
- fn=mtdna_backend.store_feedback_to_google_sheets, inputs=[single_accession, q1, q2, contact], outputs=feedback_status
 
 
173
  )
174
- # Custom CSS styles
175
- gr.HTML("""
176
- <style>
177
- /* Ensures both sections are equally spaced with the same background size */
178
- #output-summary, #output-flag {
179
- background-color: #f0f4f8; /* Light Grey for both */
180
- padding: 20px;
181
- border-radius: 10px;
182
- margin-top: 10px;
183
- width: 100%; /* Ensure full width */
184
- min-height: 150px; /* Ensures both have a minimum height */
185
- box-sizing: border-box; /* Prevents padding from increasing size */
186
- display: flex;
187
- flex-direction: column;
188
- justify-content: space-between;
189
- }
190
 
191
- /* Specific background colors */
192
- #output-summary {
193
- background-color: #434a4b;
194
- }
195
-
196
- #output-flag {
197
- background-color: #141616;
198
- }
199
-
200
- /* Ensuring they are in a row and evenly spaced */
201
- .gradio-row {
202
- display: flex;
203
- justify-content: space-between;
204
- width: 100%;
205
- }
206
- </style>
207
- """)
208
 
209
 
210
  interface.launch(share=True,debug=True)
 
1
  import gradio as gr
2
  import mtdna_backend
3
  import json
4
+ from iterate3 import data_preprocess, model, pipeline
5
+ import os
6
+ import hashlib
7
+ import threading
8
  # Gradio UI
9
+ #stop_flag = gr.State(value=False)
10
+ class StopFlag:
11
+ def __init__(self):
12
+ self.value = False
13
+ global_stop_flag = StopFlag() # Shared between run + stop
14
+
15
  with gr.Blocks() as interface:
16
  gr.Markdown("# 🧬 mtDNA Location Classifier (MVP)")
17
 
18
+ #inputMode = gr.Radio(choices=["Single Accession", "Batch Input"], value="Single Accession", label="Choose Input Mode")
19
+ user_email = gr.Textbox(label="📧 Your email (used to track free quota)")
20
+ usage_display = gr.Markdown("", visible=False)
 
21
 
22
+ # with gr.Group() as single_input_group:
23
+ # single_accession = gr.Textbox(label="Enter Single Accession (e.g., KU131308)")
24
 
25
+ # with gr.Group(visible=False) as batch_input_group:
26
+ # raw_text = gr.Textbox(label="🧬 Paste Accession Numbers (e.g., MF362736.1,MF362738.1,KU131308,MW291678)")
27
+ # resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
28
+ # gr.HTML("""<a href="https://drive.google.com/file/d/1t-TFeIsGVu5Jh3CUZS-VE9jQWzNFCs_c/view?usp=sharing" download target="_blank">Download Example CSV Format</a>""")
29
+ # gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing&ouid=112390323314156876153&rtpof=true&sd=true" download target="_blank">Download Example Excel Format</a>""")
30
+ # file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True, elem_id="file-upload-box")
31
+ raw_text = gr.Textbox(label="🧬 Input Accession Number(s): single (e.g., KU131308) or comma-separated (e.g., MF362736.1,MF362738.1,KU131308,MW291678)")
32
+ #resume_file = gr.File(label="🗃️ Previously saved Excel output (optional)", file_types=[".xlsx"], interactive=True)
33
+ gr.HTML("""<a href="https://docs.google.com/spreadsheets/d/1lKqPp17EfHsshJGZRWEpcNOZlGo3F5qU/edit?usp=sharing" download target="_blank">Download Example Excel Format</a>""")
34
+ file_upload = gr.File(label="📁 Or Upload CSV/Excel File", file_types=[".csv", ".xlsx"], interactive=True)
35
 
36
  with gr.Row():
37
  run_button = gr.Button("🔍 Submit and Classify")
38
+ stop_button = gr.Button("❌ Stop Batch", visible=True)
39
  reset_button = gr.Button("🔄 Reset")
40
 
41
  status = gr.Markdown(visible=False)
42
+
43
  with gr.Group(visible=False) as results_group:
44
+ # with gr.Accordion("Open to See the Result", open=False) as results:
45
+ # with gr.Row():
46
+ # output_summary = gr.Markdown(elem_id="output-summary")
47
+ # output_flag = gr.Markdown(elem_id="output-flag")
48
 
49
+ # gr.Markdown("---")
50
 
51
  with gr.Accordion("Open to See the Output Table", open=False) as table_accordion:
52
  output_table = gr.HTML(render=True)
53
 
 
54
  with gr.Row():
55
  output_type = gr.Dropdown(choices=["Excel", "JSON", "TXT"], label="Select Output Format", value="Excel")
56
  download_button = gr.Button("⬇️ Download Output")
57
+ #download_file = gr.File(label="Download File Here",visible=False)
58
+ download_file = gr.File(label="Download File Here", visible=False, interactive=True)
59
+ progress_box = gr.Textbox(label="Live Processing Log", lines=20, interactive=False)
60
 
61
  gr.Markdown("---")
62
 
 
68
  feedback_status = gr.Markdown()
69
 
70
  # Functions
71
+ # def toggle_input_mode(mode):
72
+ # if mode == "Single Accession":
73
+ # return gr.update(visible=True), gr.update(visible=False)
74
+ # else:
75
+ # return gr.update(visible=False), gr.update(visible=True)
 
76
 
77
  def classify_with_loading():
78
  return gr.update(value="⏳ Please wait... processing...",visible=True) # Show processing message
79
 
80
+ # def classify_dynamic(single_accession, file, text, resume, email, mode):
81
+ # if mode == "Single Accession":
82
+ # return classify_main(single_accession) + (gr.update(visible=False),)
83
+ # else:
84
+ # #return summarize_batch(file, text) + (gr.update(visible=False),) # Hide processing message
85
+ # return classify_mulAcc(file, text, resume) + (gr.update(visible=False),) # Hide processing message
86
+ # Logging helpers defined early to avoid NameError
87
+
88
+
89
+ # def classify_dynamic(single_accession, file, text, resume, email, mode):
90
+ # if mode == "Single Accession":
91
+ # return classify_main(single_accession) + (gr.update(value="", visible=False),)
92
+ # else:
93
+ # return classify_mulAcc(file, text, resume, email, log_callback=real_time_logger, log_collector=log_collector)
94
 
95
  # for single accession
96
+ # def classify_main(accession):
97
+ # #table, summary, labelAncient_Modern, explain_label = mtdna_backend.summarize_results(accession)
98
+ # table = mtdna_backend.summarize_results(accession)
99
+ # #flag_output = f"### 🏺 Ancient/Modern Flag\n**{labelAncient_Modern}**\n\n_Explanation:_ {explain_label}"
100
+ # return (
101
+ # #table,
102
+ # make_html_table(table),
103
+ # # summary,
104
+ # # flag_output,
105
+ # gr.update(visible=True),
106
+ # gr.update(visible=False),
107
+ # gr.update(visible=False)
108
+ # )
109
+
110
+ #stop_flag = gr.State(value=False)
111
+ #stop_flag = StopFlag()
112
+
113
+ # def stop_batch(stop_flag):
114
+ # stop_flag.value = True
115
+ # return gr.update(value="❌ Stopping...", visible=True), stop_flag
116
+ def stop_batch():
117
+ global_stop_flag.value = True
118
+ return gr.update(value="❌ Stopping...", visible=True)
119
+
120
+ # def threaded_batch_runner(file, text, email):
121
+ # global_stop_flag.value = False
122
+ # log_lines = []
123
+
124
+ # def update_log(line):
125
+ # log_lines.append(line)
126
+ # yield (
127
+ # gr.update(visible=False), # output_table (not yet)
128
+ # gr.update(visible=False), # results_group
129
+ # gr.update(visible=False), # download_file
130
+ # gr.update(visible=False), # usage_display
131
+ # gr.update(value="⏳ Still processing...", visible=True), # status
132
+ # gr.update(value="\n".join(log_lines)) # progress_box
133
+ # )
134
+
135
+ # # Start a dummy update to say "Starting..."
136
+ # yield from update_log("🚀 Starting batch processing...")
137
+
138
+ # rows, file_path, count, final_log, warning = mtdna_backend.summarize_batch(
139
+ # file=file,
140
+ # raw_text=text,
141
+ # resume_file=None,
142
+ # user_email=email,
143
+ # stop_flag=global_stop_flag,
144
+ # yield_callback=lambda line: (yield from update_log(line))
145
+ # )
146
+
147
+ # html = make_html_table(rows)
148
+ # file_update = gr.update(value=file_path, visible=True) if os.path.exists(file_path) else gr.update(visible=False)
149
+ # usage_or_warning_text = f"**{count}** samples used by this email." if email.strip() else warning
150
+
151
+ # yield (
152
+ # html,
153
+ # gr.update(visible=True), # results_group
154
+ # file_update, # download_file
155
+ # gr.update(value=usage_or_warning_text, visible=True),
156
+ # gr.update(value="✅ Done", visible=True),
157
+ # gr.update(value=final_log)
158
+ # )
159
+
160
+ def threaded_batch_runner(file=None, text="", email=""):
161
+ print("📧 EMAIL RECEIVED:", email)
162
+ import tempfile
163
+ from mtdna_backend import (
164
+ extract_accessions_from_input,
165
+ summarize_results,
166
+ save_to_excel,
167
+ hash_user_id,
168
+ increment_usage,
169
+ )
170
+ import os
171
+
172
+ global_stop_flag.value = False # reset stop flag
173
+
174
+ tmp_dir = tempfile.mkdtemp()
175
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
176
+ limited_acc = 50 + (10 if email.strip() else 0)
177
+
178
+ # Step 1: Parse input
179
+ accessions, error = extract_accessions_from_input(file, text)
180
+ if error:
181
+ yield (
182
+ "", # output_table
183
+ gr.update(visible=False), # results_group
184
+ gr.update(visible=False), # download_file
185
+ "", # usage_display
186
+ "❌ Error", # status
187
+ str(error) # progress_box
188
+ )
189
+ return
190
+
191
+ total = len(accessions)
192
+ if total > limited_acc:
193
+ accessions = accessions[:limited_acc]
194
+ warning = f"⚠️ Only processing the first {limited_acc} accessions."
195
+ else:
196
+ warning = f"✅ All {total} accessions will be processed."
197
+
198
+ all_rows = []
199
+ log_lines = []
200
+
201
+ # Step 2: Loop through accessions
202
+ for i, acc in enumerate(accessions):
203
+ if global_stop_flag.value:
204
+ log_lines.append(f"🛑 Stopped at {acc} ({i+1}/{total})")
205
+ usage_text = ""
206
+ if email.strip():
207
+ # user_hash = hash_user_id(email)
208
+ # usage_count = increment_usage(user_hash, len(all_rows))
209
+ usage_count = increment_usage(email, len(all_rows))
210
+ usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
211
+ else:
212
+ usage_text = f"The limited accession is 50. The user has used {len(all_rows)}, and only {50-len(all_rows)} left."
213
+ yield (
214
+ make_html_table(all_rows),
215
+ gr.update(visible=True),
216
+ gr.update(value=output_file_path, visible=True),
217
+ gr.update(value=usage_text, visible=True),
218
+ "🛑 Stopped",
219
+ "\n".join(log_lines)
220
+ )
221
+ return
222
+
223
+ log_lines.append(f"[{i+1}/{total}] Processing {acc}")
224
+ yield (
225
+ make_html_table(all_rows),
226
+ gr.update(visible=True),
227
+ gr.update(visible=False),
228
+ "",
229
+ "⏳ Processing...",
230
+ "\n".join(log_lines)
231
+ )
232
+
233
+ try:
234
+ rows = summarize_results(acc)
235
+ all_rows.extend(rows)
236
+ save_to_excel(all_rows, "", "", output_file_path, is_resume=False)
237
+ log_lines.append(f"✅ Processed {acc} ({i+1}/{total})")
238
+ except Exception as e:
239
+ log_lines.append(f"❌ Failed to process {acc}: {e}")
240
+
241
+ yield (
242
+ make_html_table(all_rows),
243
+ gr.update(visible=True),
244
+ gr.update(visible=False),
245
+ "",
246
+ "⏳ Processing...",
247
+ "\n".join(log_lines)
248
+ )
249
+
250
+ # Final update
251
+ usage_text = ""
252
+ if email.strip():
253
+ # user_hash = hash_user_id(email)
254
+ # usage_count = increment_usage(user_hash, len(all_rows))
255
+ usage_count = increment_usage(email, len(all_rows))
256
+ usage_text = f"**{usage_count}** samples used by this email. Ten more samples are added first (you now have 60 limited accessions), then wait we will contact you via this email."
257
+ else:
258
+ usage_text = f"The limited accession is 50. The user has used {len(all_rows)}, and only {50-len(all_rows)} left."
259
+ yield (
260
+ make_html_table(all_rows),
261
+ gr.update(visible=True),
262
+ gr.update(value=output_file_path, visible=True),
263
+ gr.update(value=usage_text, visible=True),
264
+ "✅ Done",
265
+ "\n".join(log_lines)
266
+ )
267
+
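# Sketch (not part of the original commit): threaded_batch_runner above relies on Gradio's
# generator-streaming pattern: the handler yields one value per output component, and the
# queued app pushes each yield to the UI. The hypothetical _demo_streaming() below shows
# the same wiring in miniature with invented component names.
def _demo_streaming():
    import time
    def runner(items):
        lines = []
        for i, item in enumerate(items.split(","), start=1):
            lines.append(f"[{i}] processed {item.strip()}")
            time.sleep(0.1)                      # stand-in for real work
            yield gr.update(value="⏳ Processing..."), "\n".join(lines)
        yield gr.update(value="✅ Done"), "\n".join(lines)
    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Comma-separated items")
        state_md = gr.Markdown()
        log_box = gr.Textbox(lines=10)
        gr.Button("Run").click(fn=runner, inputs=[inp], outputs=[state_md, log_box])
        demo.queue()
    return demo   # the caller would .launch() it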
268
+ # def threaded_batch_runner(file=None, text="", email=""):
269
+ # global_stop_flag.value = False
270
+
271
+ # # Dummy test output that matches expected schema
272
+ # return (
273
+ # "<div>✅ Dummy output table</div>", # HTML string
274
+ # gr.update(visible=True), # Group visibility
275
+ # gr.update(visible=False), # Download file
276
+ # "**0** samples used.", # Markdown
277
+ # "✅ Done", # Status string
278
+ # "Processing finished." # Progress string
279
+ # )
280
+
281
+
282
+ # def classify_mulAcc(file, text, resume, email, log_callback=None, log_collector=None):
283
+ # stop_flag.value = False
284
+ # return threaded_batch_runner(file, text, resume, email, status, stop_flag, log_callback=log_callback, log_collector=log_collector)
285
+
286
 
287
  def make_html_table(rows):
288
  html = """
 
292
  <thead style='position: sticky; top: 0; background-color: #2c2c2c; z-index: 1;'>
293
  <tr>
294
  """
295
+ headers = ["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"]
296
  html += "".join(
297
  f"<th style='padding: 10px; border: 1px solid #555; text-align: left; white-space: nowrap;'>{h}</th>"
298
  for h in headers
 
306
  style = "padding: 10px; border: 1px solid #555; vertical-align: top;"
307
 
308
  # For specific columns like Haplogroup, force nowrap
309
+ if header in ["Country Explanation", "Sample Type Explanation"]:
310
+ style += " max-width: 400px; word-wrap: break-word; white-space: normal;"
311
+ elif header in ["Sample ID", "Predicted Country", "Predicted Sample Type", "Time cost"]:
312
  style += " white-space: nowrap; text-overflow: ellipsis; max-width: 200px; overflow: hidden;"
313
 
314
+ # if header == "Sources" and isinstance(col, str) and col.strip().lower().startswith("http"):
315
+ # col = f"<a href='{col}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{col}</a>"
316
+
317
+ #html += f"<td style='{style}'>{col}</td>"
318
+ if header == "Sources" and isinstance(col, str):
319
+ links = [f"<a href='{url.strip()}' target='_blank' style='color: #4ea1f3; text-decoration: underline;'>{url.strip()}</a>" for url in col.strip().split("\n") if url.strip()]
320
+ col = "- "+"<br>- ".join(links)
321
+ elif isinstance(col, str):
322
+ # lines = []
323
+ # for line in col.split("\n"):
324
+ # line = line.strip()
325
+ # if not line:
326
+ # continue
327
+ # if line.lower().startswith("rag_llm-"):
328
+ # content = line[len("rag_llm-"):].strip()
329
+ # line = f"{content} (Method: RAG_LLM)"
330
+ # lines.append(f"- {line}")
331
+ col = col.replace("\n", "<br>")
332
+ #col = col.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
333
+ #col = "<br>".join(lines)
334
 
335
  html += f"<td style='{style}'>{col}</td>"
336
  html += "</tr>"
 
339
  return html
340
 
341
 
342
+ # def reset_fields():
343
+ # global_stop_flag.value = False # 💡 Add this to reset the flag
344
+ # return (
345
+ # #gr.update(value=""), # single_accession
346
+ # gr.update(value=""), # raw_text
347
+ # gr.update(value=None), # file_upload
348
+ # #gr.update(value=None), # resume_file
349
+ # #gr.update(value="Single Accession"), # inputMode
350
+ # gr.update(value=[], visible=True), # output_table
351
+ # # gr.update(value="", visible=True), # output_summary
352
+ # # gr.update(value="", visible=True), # output_flag
353
+ # gr.update(visible=False), # status
354
+ # gr.update(visible=False), # results_group
355
+ # gr.update(value="", visible=False), # usage_display
356
+ # gr.update(value="", visible=False), # progress_box
357
+ # )
358
  def reset_fields():
359
+ global_stop_flag.value = False # Reset the stop flag
360
+
361
+ return (
362
+ gr.update(value=""), # raw_text
363
+ gr.update(value=None), # file_upload
364
+ gr.update(value=[], visible=True), # output_table
365
+ gr.update(value="", visible=True), # status — reset and make visible again
366
+ gr.update(visible=False), # results_group
367
+ gr.update(value="", visible=True), # usage_display — reset and make visible again
368
+ gr.update(value="", visible=True), # progress_box — reset AND visible!
369
+ )
370
+ #inputMode.change(fn=toggle_input_mode, inputs=inputMode, outputs=[single_input_group, batch_input_group])
371
+ #run_button.click(fn=classify_with_loading, inputs=[], outputs=[status])
372
+ # run_button.click(
373
+ # fn=classify_dynamic,
374
+ # inputs=[single_accession, file_upload, raw_text, resume_file,user_email,inputMode],
375
+ # outputs=[output_table,
376
+ # #output_summary, output_flag,
377
+ # results_group, download_file, usage_display,status, progress_box]
378
+ # )
379
+
380
+ # run_button.click(
381
+ # fn=threaded_batch_runner,
382
+ # #inputs=[file_upload, raw_text, resume_file, user_email],
383
+ # inputs=[file_upload, raw_text, user_email],
384
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box]
385
+ # )
386
+ # run_button.click(
387
+ # fn=threaded_batch_runner,
388
+ # inputs=[file_upload, raw_text, user_email],
389
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
390
+ # every=0.5 # <-- this tells Gradio to expect streaming
391
+ # )
392
+ # output_table = gr.HTML()
393
+ # results_group = gr.Group(visible=False)
394
+ # download_file = gr.File(visible=False)
395
+ # usage_display = gr.Markdown(visible=False)
396
+ # status = gr.Markdown(visible=False)
397
+ # progress_box = gr.Textbox(visible=False)
398
+
399
+ # run_button.click(
400
+ # fn=threaded_batch_runner,
401
+ # inputs=[file_upload, raw_text, user_email],
402
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
403
+ # every=0.5, # streaming enabled
404
+ # show_progress="full"
405
+ # )
406
+ print("🎯 DEBUG COMPONENT TYPES")
407
+ print(type(output_table))
408
+ print(type(results_group))
409
+ print(type(download_file))
410
+ print(type(usage_display))
411
+ print(type(status))
412
+ print(type(progress_box))
413
+
414
+
415
+ # interface.stream(
416
+ # fn=threaded_batch_runner,
417
+ # inputs=[file_upload, raw_text, user_email],
418
+ # outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
419
+ # trigger=run_button,
420
+ # every=0.5,
421
+ # show_progress="full",
422
+ # )
423
+ interface.queue() # No arguments here!
424
+
425
  run_button.click(
426
+ fn=threaded_batch_runner,
427
+ inputs=[file_upload, raw_text, user_email],
428
+ outputs=[output_table, results_group, download_file, usage_display, status, progress_box],
429
+ concurrency_limit=1, # ✅ correct in Gradio 5.x
430
+ queue=True, # ✅ ensure the queue is used
431
+ #every=0.5
432
  )
433
+
434
+
435
+
436
+
437
+ stop_button.click(fn=stop_batch, inputs=[], outputs=[status])
438
+
439
+ # reset_button.click(
440
+ # #fn=reset_fields,
441
+ # fn=lambda: (
442
+ # gr.update(value=""), gr.update(value=""), gr.update(value=None), gr.update(value=None), gr.update(value="Single Accession"),
443
+ # gr.update(value=[], visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="", visible=False), gr.update(value="", visible=False)
444
+ # ),
445
+ # inputs=[],
446
+ # outputs=[
447
+ # single_accession, raw_text, file_upload, resume_file,inputMode,
448
+ # output_table,# output_summary, output_flag,
449
+ # status, results_group, usage_display, progress_box
450
+ # ]
451
+ # )
452
+ #stop_button.click(fn=lambda sf: (gr.update(value="❌ Stopping...", visible=True), setattr(sf, "value", True) or sf), inputs=[gr.State(stop_flag)], outputs=[status, gr.State(stop_flag)])
453
+
454
  reset_button.click(
455
  fn=reset_fields,
456
  inputs=[],
457
+ #outputs=[raw_text, file_upload, resume_file, output_table, status, results_group, usage_display, progress_box]
458
+ outputs=[raw_text, file_upload, output_table, status, results_group, usage_display, progress_box]
459
+ )
 
 
 
460
 
461
  download_button.click(
462
  fn=mtdna_backend.save_batch_output,
463
+ #inputs=[output_table, output_summary, output_flag, output_type],
464
+ inputs=[output_table, output_type],
465
  outputs=[download_file])
466
 
467
+ # submit_feedback.click(
468
+ # fn=mtdna_backend.store_feedback_to_google_sheets,
469
+ # inputs=[single_accession, q1, q2, contact], outputs=feedback_status
470
+ # )
471
  submit_feedback.click(
472
+ fn=mtdna_backend.store_feedback_to_google_sheets,
473
+ inputs=[raw_text, q1, q2, contact],
474
+ outputs=[feedback_status]
475
  )
476
+ # # Custom CSS styles
477
+ # gr.HTML("""
478
+ # <style>
479
+ # /* Ensures both sections are equally spaced with the same background size */
480
+ # #output-summary, #output-flag {
481
+ # background-color: #f0f4f8; /* Light Grey for both */
482
+ # padding: 20px;
483
+ # border-radius: 10px;
484
+ # margin-top: 10px;
485
+ # width: 100%; /* Ensure full width */
486
+ # min-height: 150px; /* Ensures both have a minimum height */
487
+ # box-sizing: border-box; /* Prevents padding from increasing size */
488
+ # display: flex;
489
+ # flex-direction: column;
490
+ # justify-content: space-between;
491
+ # }
492
 
493
+ # /* Specific background colors */
494
+ # #output-summary {
495
+ # background-color: #434a4b;
496
+ # }
497
+
498
+ # #output-flag {
499
+ # background-color: #141616;
500
+ # }
501
+
502
+ # /* Ensuring they are in a row and evenly spaced */
503
+ # .gradio-row {
504
+ # display: flex;
505
+ # justify-content: space-between;
506
+ # width: 100%;
507
+ # }
508
+ # </style>
509
+ # """)
510
 
511
 
512
  interface.launch(share=True,debug=True)
data_preprocess.py ADDED
@@ -0,0 +1,625 @@
1
+ import re
2
+ import os
3
+ import streamlit as st
4
+ import subprocess
5
+ import re
6
+ from Bio import Entrez
7
+ from docx import Document
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ import pandas as pd
19
+ from iterate3 import model
20
+ import nltk
21
+ nltk.download('punkt_tab')
22
+ def download_excel_file(url, save_path="temp.xlsx"):
23
+ if "view.officeapps.live.com" in url:
24
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
25
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
26
+ response = requests.get(real_url)
27
+ with open(save_path, "wb") as f:
28
+ f.write(response.content)
29
+ return save_path
30
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
31
+ response = requests.get(url)
32
+ response.raise_for_status() # Raises error if download fails
33
+ with open(save_path, "wb") as f:
34
+ f.write(response.content)
35
+ print(len(response.content))
36
+ return save_path
37
+ else:
38
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
39
+ return url
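# Sketch (not in the original commit): the Office web-viewer branch above recovers the real
# file URL from the `src` query parameter. The hypothetical _demo_unwrap_viewer_url() uses a
# placeholder URL to show that unwrapping step in isolation.
def _demo_unwrap_viewer_url():
    wrapped = ("https://view.officeapps.live.com/op/view.aspx?"
               "src=https%3A%2F%2Fexample.org%2Fsuppl%2Ftable_S2.xlsx")
    src = urllib.parse.parse_qs(urllib.parse.urlparse(wrapped).query)["src"][0]
    return urllib.parse.unquote(src)   # -> "https://example.org/suppl/table_S2.xlsx"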
40
+ def extract_text(link,saveFolder):
41
+ text = ""
42
+ name = link.split("/")[-1]
43
+ file_path = Path(saveFolder) / name
44
+ # pdf
45
+ if link.endswith(".pdf"):
46
+ if file_path.is_file():
47
+ link = saveFolder + "/" + name
48
+ print("File exists.")
49
+ p = pdf.PDF(link,saveFolder)
50
+ text = p.extractTextWithPDFReader()
51
+ #text_exclude_table = p.extract_text_excluding_tables()
52
+ # worddoc
53
+ elif link.endswith(".doc") or link.endswith(".docx"):
54
+ d = wordDoc.wordDoc(link,saveFolder)
55
+ text = d.extractTextByPage()
56
+ # html
57
+ if link.split(".")[-1].lower() not in "xlsx":
58
+ if "http" in link or "html" in link:
59
+ html = extractHTML.HTML("",link)
60
+ text = html.getListSection() # the text already clean
61
+ return text
62
+ def extract_table(link,saveFolder):
63
+ table = []
64
+ name = link.split("/")[-1]
65
+ file_path = Path(saveFolder) / name
66
+ # pdf
67
+ if link.endswith(".pdf"):
68
+ if file_path.is_file():
69
+ link = saveFolder + "/" + name
70
+ print("File exists.")
71
+ p = pdf.PDF(link,saveFolder)
72
+ table = p.extractTable()
73
+ # worddoc
74
+ elif link.endswith(".doc") or link.endswith(".docx"):
75
+ d = wordDoc.wordDoc(link,saveFolder)
76
+ table = d.extractTableAsList()
77
+ # excel
78
+ elif link.split(".")[-1].lower() in "xlsx":
79
+ # download excel file if it not downloaded yet
80
+ savePath = saveFolder +"/"+ link.split("/")[-1]
81
+ excelPath = download_excel_file(link, savePath)
82
+ try:
83
+ xls = pd.ExcelFile(excelPath)
84
+ table_list = []
85
+ for sheet_name in xls.sheet_names:
86
+ df = pd.read_excel(xls, sheet_name=sheet_name)
87
+ cleaned_table = df.fillna("").astype(str).values.tolist()
88
+ table_list.append(cleaned_table)
89
+ table = table_list
90
+ except Exception as e:
91
+ print("❌ Failed to extract tables from Excel:", e)
92
+ # html
93
+ elif "http" in link or "html" in link:
94
+ html = extractHTML.HTML("",link)
95
+ table = html.extractTable() # table is a list
96
+ table = clean_tables_format(table)
97
+ return table
98
+
99
+ def clean_tables_format(tables):
100
+ """
101
+ Ensures all tables are in consistent format: List[List[List[str]]]
102
+ Cleans by:
103
+ - Removing empty strings and rows
104
+ - Converting all cells to strings
105
+ - Handling DataFrames and list-of-lists
106
+ """
107
+ cleaned = []
108
+ if tables:
109
+ for table in tables:
110
+ standardized = []
111
+
112
+ # Case 1: Pandas DataFrame
113
+ if isinstance(table, pd.DataFrame):
114
+ table = table.fillna("").astype(str).values.tolist()
115
+
116
+ # Case 2: List of Lists
117
+ if isinstance(table, list) and all(isinstance(row, list) for row in table):
118
+ for row in table:
119
+ filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
120
+ if filtered_row:
121
+ standardized.append(filtered_row)
122
+
123
+ if standardized:
124
+ cleaned.append(standardized)
125
+
126
+ return cleaned
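# Sketch (not in the original commit): clean_tables_format accepts a mix of DataFrames and
# list-of-lists and returns List[List[List[str]]] with empty cells and rows dropped. The
# sample values in the hypothetical _demo_clean_tables_format() are invented.
def _demo_clean_tables_format():
    demo = [
        pd.DataFrame({"ID": ["KU131308", None], "Country": ["Thailand", ""]}),
        [["ID", "Region"], ["", ""], ["MW291678", "Asia"]],
    ]
    # -> [[['KU131308', 'Thailand']], [['ID', 'Region'], ['MW291678', 'Asia']]]
    return clean_tables_format(demo)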
127
+
128
+ import json
129
+ import tiktoken # Optional: for OpenAI token counting
130
+ def normalize_text_for_comparison(s: str) -> str:
131
+ """
132
+ Normalizes text for robust comparison by:
133
+ 1. Converting to lowercase.
134
+ 2. Replacing all types of newlines with a single consistent newline (\n).
135
+ 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
136
+ 4. Stripping leading/trailing whitespace from the entire string.
137
+ """
138
+ s = s.lower()
139
+ s = s.replace('\r\n', '\n') # Handle Windows newlines
140
+ s = s.replace('\r', '\n') # Handle Mac classic newlines
141
+
142
+ # Replace sequences of whitespace (including multiple newlines) with a single space
143
+ # This might be too aggressive if you need to preserve paragraph breaks,
144
+ # but good for exact word-sequence matching.
145
+ s = re.sub(r'\s+', ' ', s)
146
+
147
+ return s.strip()
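# Sketch (not in the original commit): merge_text_and_tables below relies on this
# normalization to decide whether a table is already present in the text. The hypothetical
# _demo_normalize() shows two variants collapsing to the same form.
def _demo_normalize():
    a = normalize_text_for_comparison("Chiang  Mai,\r\nThailand")
    b = normalize_text_for_comparison("chiang mai, thailand")
    return a == b   # True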
148
+ def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
149
+ """
150
+ Merge cleaned text and table into one string for LLM input.
151
+ - Avoids duplicating tables already in text
152
+ - Extracts only relevant rows from large tables
153
+ - Skips or saves oversized tables
154
+ """
155
+ import importlib
156
+ json = importlib.import_module("json")
157
+
158
+ def estimate_tokens(text_str):
159
+ try:
160
+ enc = tiktoken.get_encoding(tokenizer)
161
+ return len(enc.encode(text_str))
162
+ except:
163
+ return len(text_str) // 4 # Fallback estimate
164
+
165
+ def is_table_relevant(table, keywords, accession_id=None):
166
+ flat = " ".join(" ".join(row).lower() for row in table)
167
+ if accession_id and accession_id.lower() in flat:
168
+ return True
169
+ return any(kw.lower() in flat for kw in keywords)
170
+ preview, preview1 = "",""
171
+ llm_input = "## Document Text\n" + text.strip() + "\n"
172
+ clean_text = normalize_text_for_comparison(text)
173
+
174
+ if tables:
175
+ for idx, table in enumerate(tables):
176
+ keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
177
+ if accession_id: keywords += [accession_id.lower()]
178
+ if isolate: keywords += [isolate.lower()]
179
+ if is_table_relevant(table, keywords, accession_id):
180
+ if len(table) > 0:
181
+ for tab in table:
182
+ preview = " ".join(tab) if tab else ""
183
+ preview1 = "\n".join(tab) if tab else ""
184
+ clean_preview = normalize_text_for_comparison(preview)
185
+ clean_preview1 = normalize_text_for_comparison(preview1)
186
+ if clean_preview not in clean_text:
187
+ if clean_preview1 not in clean_text:
188
+ table_str = json.dumps([tab], indent=2)
189
+ llm_input += f"## Table {idx+1}\n{table_str}\n"
190
+ return llm_input.strip()
191
+
192
+ def preprocess_document(link, saveFolder, accession=None, isolate=None):
193
+ try:
194
+ text = extract_text(link, saveFolder)
195
+ except: text = ""
196
+ try:
197
+ tables = extract_table(link, saveFolder)
198
+ except: tables = []
199
+ if accession: accession = accession
200
+ if isolate: isolate = isolate
201
+ try:
202
+ final_input = merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
203
+ except: final_input = ""
204
+ return text, tables, final_input
205
+
206
+ def extract_sentences(text):
207
+ sentences = re.split(r'(?<=[.!?])\s+', text)
208
+ return [s.strip() for s in sentences if s.strip()]
209
+
210
+ def is_irrelevant_number_sequence(text):
211
+ if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
212
+ return False
213
+ word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
214
+ number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
215
+ total_tokens = len(re.findall(r'\S+', text))
216
+ if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
217
+ return True
218
+ elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
219
+ return True
220
+ return False
221
+
222
+ def remove_isolated_single_digits(sentence):
223
+ tokens = sentence.split()
224
+ filtered_tokens = []
225
+ for token in tokens:
226
+ if token == '0' or token == '1':
227
+ pass
228
+ else:
229
+ filtered_tokens.append(token)
230
+ return ' '.join(filtered_tokens).strip()
231
+
232
+ def get_contextual_sentences_BFS(text_content, keyword, depth=2):
233
+ def extract_codes(sentence):
234
+ # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
235
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
236
+ sentences = extract_sentences(text_content)
237
+ relevant_sentences = set()
238
+ initial_keywords = set()
239
+
240
+ # Define a regex to capture codes like A1YU101 or KM1
241
+ # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
242
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
243
+
244
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
245
+ keyword_match = code_pattern.search(keyword)
246
+
247
+ keyword_prefix = None
248
+ keyword_num = None
249
+
250
+ if keyword_match:
251
+ keyword_prefix = keyword_match.group(1).lower()
252
+ keyword_num = int(keyword_match.group(2))
253
+
254
+ for sentence in sentences:
255
+ sentence_added = False
256
+
257
+ # 1. Check for exact match of the keyword
258
+ if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
259
+ relevant_sentences.add(sentence.strip())
260
+ initial_keywords.add(keyword.lower())
261
+ sentence_added = True
262
+
263
+ # 2. Check for range patterns (e.g., A1YU101-A1YU137)
264
+ # The range pattern should be broad enough to capture the full code string within the range.
265
+ range_matches = re.finditer(r'([A-Z0-9]+-\d+)', sentence, re.IGNORECASE) # More specific range pattern if needed, or rely on full code pattern below
266
+ range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
267
+
268
+ for r_match in range_matches:
269
+ start_code_str = r_match.group(1)
270
+ end_code_str = r_match.group(2)
271
+
272
+ # CRITICAL FIX: Use code_pattern.search for start_match and end_match
273
+ start_match = code_pattern.search(start_code_str)
274
+ end_match = code_pattern.search(end_code_str)
275
+
276
+ if keyword_prefix and keyword_num is not None and start_match and end_match:
277
+ start_prefix = start_match.group(1).lower()
278
+ end_prefix = end_match.group(1).lower()
279
+ start_num = int(start_match.group(2))
280
+ end_num = int(end_match.group(2))
281
+
282
+ # Check if the keyword's prefix matches and its number is within the range
283
+ if keyword_prefix == start_prefix and \
284
+ keyword_prefix == end_prefix and \
285
+ start_num <= keyword_num <= end_num:
286
+ relevant_sentences.add(sentence.strip())
287
+ initial_keywords.add(start_code_str.lower())
288
+ initial_keywords.add(end_code_str.lower())
289
+ sentence_added = True
290
+ break # Only need to find one matching range per sentence
291
+
292
+ # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
293
+ # to initial_keywords to ensure graph traversal from related terms.
294
+ if sentence_added:
295
+ for word in extract_codes(sentence):
296
+ initial_keywords.add(word.lower())
297
+
298
+
299
+ # Build word_to_sentences mapping for all sentences
300
+ word_to_sentences = {}
301
+ for sent in sentences:
302
+ codes_in_sent = set(extract_codes(sent))
303
+ for code in codes_in_sent:
304
+ word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
305
+
306
+
307
+ # Build the graph
308
+ graph = {}
309
+ for sent in sentences:
310
+ codes = set(extract_codes(sent))
311
+ for word1 in codes:
312
+ word1_lower = word1.lower()
313
+ graph.setdefault(word1_lower, set())
314
+ for word2 in codes:
315
+ word2_lower = word2.lower()
316
+ if word1_lower != word2_lower:
317
+ graph[word1_lower].add(word2_lower)
318
+
319
+
320
+ # Perform BFS/graph traversal
321
+ queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
322
+ visited_words = set(initial_keywords)
323
+
324
+ while queue:
325
+ current_word, level = queue.pop(0)
326
+ if level >= depth:
327
+ continue
328
+
329
+ relevant_sentences.update(word_to_sentences.get(current_word, []))
330
+
331
+ for neighbor in graph.get(current_word, []):
332
+ if neighbor not in visited_words:
333
+ visited_words.add(neighbor)
334
+ queue.append((neighbor, level + 1))
335
+
336
+ final_sentences = set()
337
+ for sentence in relevant_sentences:
338
+ if not is_irrelevant_number_sequence(sentence):
339
+ processed_sentence = remove_isolated_single_digits(sentence)
340
+ if processed_sentence:
341
+ final_sentences.add(processed_sentence)
342
+
343
+ return "\n".join(sorted(list(final_sentences)))
344
+
345
+
346
+
347
+ def get_contextual_sentences_DFS(text_content, keyword, depth=2):
348
+ sentences = extract_sentences(text_content)
349
+
350
+ # Build word-to-sentences mapping
351
+ word_to_sentences = {}
352
+ for sent in sentences:
353
+ words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
354
+ for word in words_in_sent:
355
+ word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
356
+
357
+ # Function to extract codes in a sentence
358
+ def extract_codes(sentence):
359
+ # Only codes like 'KSK1', 'MG272794', not pure numbers
360
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
361
+
362
+ # DFS with priority based on distance to keyword and early stop if country found
363
+ def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
364
+ country = "unknown"
365
+ if current_depth > max_depth:
366
+ return country, False
367
+
368
+ if current_word not in word_to_sentences:
369
+ return country, False
370
+
371
+ for sentence in word_to_sentences[current_word]:
372
+ if sentence == parent_sentence:
373
+ continue # avoid reusing the same sentence
374
+
375
+ collected_sentences.add(sentence)
376
+
377
+ #print("current_word:", current_word)
378
+ small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
379
+ #print(small_sen)
380
+ country = model.get_country_from_text(small_sen)
381
+ #print("small context country:", country)
382
+ if country.lower() != "unknown":
383
+ return country, True
384
+ else:
385
+ country = model.get_country_from_text(sentence)
386
+ #print("full sentence country:", country)
387
+ if country.lower() != "unknown":
388
+ return country, True
389
+
390
+ codes_in_sentence = extract_codes(sentence)
391
+ idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
392
+ if idx is None:
393
+ continue
394
+
395
+ sorted_children = sorted(
396
+ [code for code in codes_in_sentence if code.lower() not in visited_words],
397
+ key=lambda x: (abs(codes_in_sentence.index(x) - idx),
398
+ 0 if codes_in_sentence.index(x) > idx else 1)
399
+ )
400
+
401
+ #print("sorted_children:", sorted_children)
402
+ for child in sorted_children:
403
+ child_lower = child.lower()
404
+ if child_lower not in visited_words:
405
+ visited_words.add(child_lower)
406
+ country, should_stop = dfs_traverse(
407
+ child_lower, current_depth + 1, max_depth,
408
+ visited_words, collected_sentences, parent_sentence=sentence
409
+ )
410
+ if should_stop:
411
+ return country, True
412
+
413
+ return country, False
414
+
415
+ # Begin DFS
416
+ collected_sentences = set()
417
+ visited_words = set([keyword.lower()])
418
+ country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
419
+
420
+ # Filter irrelevant sentences
421
+ final_sentences = set()
422
+ for sentence in collected_sentences:
423
+ if not is_irrelevant_number_sequence(sentence):
424
+ processed = remove_isolated_single_digits(sentence)
425
+ if processed:
426
+ final_sentences.add(processed)
427
+ if not final_sentences:
428
+ return country, text_content
429
+ return country, "\n".join(sorted(list(final_sentences)))
430
+
431
+ # Helper function for normalizing text for overlap comparison
432
+ def normalize_for_overlap(s: str) -> str:
433
+ s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
434
+ s = re.sub(r'\s+', ' ', s).strip()
435
+ return s
436
+
437
+ def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
438
+ if not text1: return text2
439
+ if not text2: return text1
440
+
441
+ # Case 1: text2 is fully contained in text1 or vice-versa
442
+ if text2 in text1:
443
+ return text1
444
+ if text1 in text2:
445
+ return text2
446
+
447
+ # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
448
+ # This is what your function was primarily designed for.
449
+ # It looks for the overlap at the "junction" of text1 and text2.
450
+
451
+ max_junction_overlap = 0
452
+ for i in range(min(len(text1), len(text2)), 0, -1):
453
+ suffix1 = text1[-i:]
454
+ prefix2 = text2[:i]
455
+ # Prioritize exact match, then normalized match
456
+ if suffix1 == prefix2:
457
+ max_junction_overlap = i
458
+ break
459
+ elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
460
+ max_junction_overlap = i
461
+ break # Take the first (longest) normalized match
462
+
463
+ if max_junction_overlap > 0:
464
+ merged_text = text1 + text2[max_junction_overlap:]
465
+ return re.sub(r'\s+', ' ', merged_text).strip()
466
+
467
+ # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
468
+ # This addresses your specific test case where the overlap is at the very beginning of both strings.
469
+ # This is often used when trying to deduplicate content that shares a common start.
470
+
471
+ longest_common_prefix_len = 0
472
+ min_len = min(len(text1), len(text2))
473
+ for i in range(min_len):
474
+ if text1[i] == text2[i]:
475
+ longest_common_prefix_len = i + 1
476
+ else:
477
+ break
478
+
479
+ # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
480
+ # AND the remaining parts are distinct, then apply this merge.
481
+ # This is a heuristic and might need fine-tuning.
482
+ if longest_common_prefix_len > 0 and \
483
+ text1[longest_common_prefix_len:].strip() and \
484
+ text2[longest_common_prefix_len:].strip():
485
+
486
+ # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
487
+ # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
488
+ # common prefix is "Hi, I am Vy."
489
+ # Remaining text1: " Nice to meet you."
490
+ # Remaining text2: " Goodbye Vy."
491
+ # So we merge common_prefix + remaining_text1 + remaining_text2
492
+
493
+ common_prefix_str = text1[:longest_common_prefix_len]
494
+ remainder_text1 = text1[longest_common_prefix_len:]
495
+ remainder_text2 = text2[longest_common_prefix_len:]
496
+
497
+ merged_text = common_prefix_str + remainder_text1 + remainder_text2
498
+ return re.sub(r'\s+', ' ', merged_text).strip()
499
+
500
+
501
+ # If neither specific overlap type is found, just concatenate
502
+ merged_text = text1 + text2
503
+ return re.sub(r'\s+', ' ', merged_text).strip()
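# Sketch (not in the original commit): the common-prefix branch above keeps the shared
# opening once and appends the two distinct remainders. The strings in the hypothetical
# _demo_merge_overlap() are the same toy example used in the comments.
def _demo_merge_overlap():
    merged = merge_texts_skipping_overlap("Hi, I am Vy. Nice to meet you.",
                                          "Hi, I am Vy. Goodbye Vy.")
    return merged   # "Hi, I am Vy." appears only once in the result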
504
+
505
+ def save_text_to_docx(text_content: str, file_path: str):
506
+ """
507
+ Saves a given text string into a .docx file.
508
+
509
+ Args:
510
+ text_content (str): The text string to save.
511
+ file_path (str): The full path including the filename where the .docx file will be saved.
512
+ Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
513
+ """
514
+ try:
515
+ document = Document()
516
+
517
+ # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
518
+ for paragraph_text in text_content.split('\n'):
519
+ document.add_paragraph(paragraph_text)
520
+
521
+ document.save(file_path)
522
+ print(f"Text successfully saved to '{file_path}'")
523
+ except Exception as e:
524
+ print(f"Error saving text to docx file: {e}")
525
+
526
+ '''Two scenarios:
527
+ - quick look finds the keyword and a location can be read directly from the context, then stop
528
+ - quick look finds the keyword but no location, so hold the related words and
529
+ look through the other files iteratively for each related word until a location is found, then stop'''
530
+ def extract_context(text, keyword, window=500):
531
+ # firstly try accession number
532
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
533
+
534
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
535
+ keyword_match = code_pattern.search(keyword)
536
+
537
+ keyword_prefix = None
538
+ keyword_num = None
539
+
540
+ if keyword_match:
541
+ keyword_prefix = keyword_match.group(1).lower()
542
+ keyword_num = int(keyword_match.group(2))
543
+ text = text.lower()
544
+ idx = text.find(keyword.lower())
545
+ if idx == -1:
546
+ if keyword_prefix:
547
+ idx = text.find(keyword_prefix)
548
+ if idx == -1:
549
+ return "Sample ID not found."
550
+ return text[max(0, idx-window): idx+window]
551
+ return text[max(0, idx-window): idx+window]
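# Sketch (not in the original commit): extract_context returns a lowercased window of text
# around the keyword (or around its prefix when the exact ID is absent). The sentence in the
# hypothetical _demo_extract_context() is invented.
def _demo_extract_context():
    return extract_context("Isolate KSK1 was sampled in northern Thailand.",
                           "KSK1", window=20)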
552
+ def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
553
+ cache = {}
554
+ country = "unknown"
555
+ output = ""
556
+ tem_output, small_output = "",""
557
+ keyword_appear = (False,"")
558
+ keywords = []
559
+ if isolate: keywords.append(isolate)
560
+ if accession: keywords.append(accession)
561
+ for f in filePaths:
562
+ # scenario 1: direct location: truncate the context and then use the QA model?
563
+ if keywords:
564
+ for keyword in keywords:
565
+ text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
566
+ if keyword in final_input:
567
+ context = extract_context(final_input, keyword)
568
+ # quick look if country already in context and if yes then return
569
+ country = model.get_country_from_text(context)
570
+ if country != "unknown":
571
+ return country, context, final_input
572
+ else:
573
+ country = model.get_country_from_text(final_input)
574
+ if country != "unknown":
575
+ return country, context, final_input
576
+ else: # might be cross-ref
577
+ keyword_appear = (True, f)
578
+ cache[f] = context
579
+ small_output = merge_texts_skipping_overlap(output, context) + "\n"
580
+ chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
581
+ countryBFS = model.get_country_from_text(chunkBFS)
582
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
583
+ output = merge_texts_skipping_overlap(output, final_input)
584
+ if countryDFS != "unknown" and countryBFS != "unknown":
585
+ if len(chunkDFS) <= len(chunkBFS):
586
+ return countryDFS, chunkDFS, output
587
+ else:
588
+ return countryBFS, chunkBFS, output
589
+ else:
590
+ if countryDFS != "unknown":
591
+ return countryDFS, chunkDFS, output
592
+ if countryBFS != "unknown":
593
+ return countryBFS, chunkBFS, output
594
+ else:
595
+ # scenario 2:
596
+ '''cross-ref: e.g., the keyword A1YU101 appears in file 2, which mentions KM1, but KM1 is described in file 1;
597
+ if we look at file 1 first, then maybe we can build a lookup dict keyed by country,
598
+ such as Thailand as the key and its re'''
599
+ cache[f] = final_input
600
+ if keyword_appear[0] == True:
601
+ for c in cache:
602
+ if c!=keyword_appear[1]:
603
+ if cache[c].lower() not in output.lower():
604
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
605
+ chunkBFS = get_contextual_sentences_BFS(output, keyword)
606
+ countryBFS = model.get_country_from_text(chunkBFS)
607
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
608
+ if countryDFS != "unknown" and countryBFS != "unknown":
609
+ if len(chunkDFS) <= len(chunkBFS):
610
+ return countryDFS, chunkDFS, output
611
+ else:
612
+ return countryBFS, chunkBFS, output
613
+ else:
614
+ if countryDFS != "unknown":
615
+ return countryDFS, chunkDFS, output
616
+ if countryBFS != "unknown":
617
+ return countryBFS, chunkBFS, output
618
+ else:
619
+ if cache[f].lower() not in output.lower():
620
+ output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
621
+ if len(output) == 0 or keyword_appear[0]==False:
622
+ for c in cache:
623
+ if cache[c].lower() not in output.lower():
624
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
625
+ return country, "", output
model.py ADDED
@@ -0,0 +1,1255 @@
1
+ import re
2
+ import pycountry
3
+ from docx import Document
4
+ import json
5
+ import os
6
+ import numpy as np
7
+ import faiss
8
+ from collections import defaultdict
9
+ import ast # For literal_eval
10
+ import math # For ceiling function
11
+ from iterate3 import data_preprocess
12
+ import mtdna_classifier
13
+ # --- IMPORTANT: UNCOMMENT AND CONFIGURE YOUR REAL API KEY ---
14
+ import google.generativeai as genai
15
+ os.environ["GOOGLE_API_KEY"] = "AIzaSyDi0CNKBgEtnr6YuPaY6YNEuC5wT0cdKhk"
16
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
17
+
18
+ import nltk
19
+ from nltk.corpus import stopwords
20
+ try:
21
+ nltk.data.find('corpora/stopwords')
22
+ except LookupError:
23
+ nltk.download('stopwords')
24
+ nltk.download('punkt_tab')
25
+ # --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
26
+ # Prices are per 1,000 tokens
27
+ PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
28
+ PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
29
+ PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
30
+
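These per-1K rates are applied later by multiplying token counts; a small worked example with invented token counts for a single LLM call:

# Hypothetical token counts (not measured values).
input_tokens, output_tokens = 1200, 150
llm_cost = (input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM \
         + (output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
# 1.2 * 0.000075 + 0.15 * 0.0003 = 0.000135 dollars for this call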
31
+ # --- API Functions (REAL API FUNCTIONS) ---
32
+
33
+ # def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
34
+ # """Generates an embedding for the given text using a Google embedding model."""
35
+ # try:
36
+ # result = genai.embed_content(
37
+ # model="models/text-embedding-004", # Specify the embedding model
38
+ # content=text,
39
+ # task_type=task_type
40
+ # )
41
+ # return np.array(result['embedding']).astype('float32')
42
+ # except Exception as e:
43
+ # print(f"Error getting embedding: {e}")
44
+ # return np.zeros(768, dtype='float32')
45
+ def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
46
+ """Safe Gemini 1.5 embedding call with fallback."""
47
+ import numpy as np
48
+ try:
49
+ if not text or len(text.strip()) == 0:
50
+ raise ValueError("Empty text cannot be embedded.")
51
+ result = genai.embed_content(
52
+ model="models/text-embedding-004",
53
+ content=text,
54
+ task_type=task_type
55
+ )
56
+ return np.array(result['embedding'], dtype='float32')
57
+ except Exception as e:
58
+ print(f"❌ Embedding error: {e}")
59
+ return np.zeros(768, dtype='float32')
60
+
61
+
62
+ def call_llm_api(prompt, model_name='gemini-1.5-flash-latest'):
63
+ """Calls a Google Gemini LLM with the given prompt."""
64
+ try:
65
+ model = genai.GenerativeModel(model_name)
66
+ response = model.generate_content(prompt)
67
+ return response.text, model # Return model instance for token counting
68
+ except Exception as e:
69
+ print(f"Error calling LLM: {e}")
70
+ return "Error: Could not get response from LLM API.", None
71
+
72
+
73
+ # --- Core Document Processing Functions (All previously provided and fixed) ---
74
+
75
+ def read_docx_text(path):
76
+ """
77
+ Reads text and extracts potential table-like strings from a .docx document.
78
+ Separates plain text from structured [ [ ] ] list-like tables.
79
+ Also attempts to extract a document title.
80
+ """
81
+ doc = Document(path)
82
+ plain_text_paragraphs = []
83
+ table_strings = []
84
+ document_title = "Unknown Document Title" # Default
85
+
86
+ # Attempt to extract the document title from the first few paragraphs
87
+ title_paragraphs = [p.text.strip() for p in doc.paragraphs[:5] if p.text.strip()]
88
+ if title_paragraphs:
89
+ # A heuristic to find a title: often the first or second non-empty paragraph
90
+ # or a very long first paragraph if it's the title
91
+ if len(title_paragraphs[0]) > 50 and "Human Genetics" not in title_paragraphs[0]:
92
+ document_title = title_paragraphs[0]
93
+ elif len(title_paragraphs) > 1 and len(title_paragraphs[1]) > 50 and "Human Genetics" not in title_paragraphs[1]:
94
+ document_title = title_paragraphs[1]
95
+ elif any("Complete mitochondrial genomes" in p for p in title_paragraphs):
96
+ # Fallback to a known title phrase if present
97
+ document_title = "Complete mitochondrial genomes of Thai and Lao populations indicate an ancient origin of Austroasiatic groups and demic diffusion in the spread of Tai–Kadai languages"
98
+
99
+ current_table_lines = []
100
+ in_table_parsing_mode = False
101
+
102
+ for p in doc.paragraphs:
103
+ text = p.text.strip()
104
+ if not text:
105
+ continue
106
+
107
+ # Condition to start or continue table parsing
108
+ if text.startswith("## Table "): # Start of a new table section
109
+ if in_table_parsing_mode and current_table_lines:
110
+ table_strings.append("\n".join(current_table_lines))
111
+ current_table_lines = [text] # Include the "## Table X" line
112
+ in_table_parsing_mode = True
113
+ elif in_table_parsing_mode and (text.startswith("[") or text.startswith('"')):
114
+ # Continue collecting lines if we're in table mode and it looks like table data
115
+ # Table data often starts with '[' for lists, or '"' for quoted strings within lists.
116
+ current_table_lines.append(text)
117
+ else:
118
+ # If not in table mode, or if a line doesn't look like table data,
119
+ # then close the current table (if any) and add the line to plain text.
120
+ if in_table_parsing_mode and current_table_lines:
121
+ table_strings.append("\n".join(current_table_lines))
122
+ current_table_lines = []
123
+ in_table_parsing_mode = False
124
+ plain_text_paragraphs.append(text)
125
+
126
+ # After the loop, add any remaining table lines
127
+ if current_table_lines:
128
+ table_strings.append("\n".join(current_table_lines))
129
+
130
+ return "\n".join(plain_text_paragraphs), table_strings, document_title
131
+
132
+ # --- Structured Data Extraction and RAG Functions ---
133
+
134
+ def parse_literal_python_list(table_str):
135
+ list_match = re.search(r'(\[\s*\[\s*(?:.|\n)*?\s*\]\s*\])', table_str)
136
+ #print("Debug: list_match object (before if check):", list_match)
137
+ if not list_match:
138
+        if "table" in table_str.lower(): # then the table string is missing the closing "]]"
139
+ table_str += "]]"
140
+ list_match = re.search(r'(\[\s*\[\s*(?:.|\n)*?\s*\]\s*\])', table_str)
141
+ if list_match:
142
+ try:
143
+ matched_string = list_match.group(1)
144
+ #print("Debug: Matched string for literal_eval:", matched_string)
145
+ return ast.literal_eval(matched_string)
146
+ except (ValueError, SyntaxError) as e:
147
+ print(f"Error evaluating literal: {e}")
148
+ return []
149
+ return []
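For illustration, a made-up table string in the same '## Table' + nested-list shape that read_docx_text collects would parse like this:

sample = '## Table 1\n[["Sample ID", "Population code"], ["KM1", "KM"]]'
parse_literal_python_list(sample)
# -> [['Sample ID', 'Population code'], ['KM1', 'KM']]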
150
+
151
+
152
+ _individual_code_parser = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
153
+ def _parse_individual_code_parts(code_str):
154
+ match = _individual_code_parser.search(code_str)
155
+ if match:
156
+ return match.group(1), match.group(2)
157
+ return None, None
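The split keys on the trailing digits only, which is what the range-expansion logic below relies on; example calls and their expected return values:

_parse_individual_code_parts("KM10")     # -> ("KM", "10")
_parse_individual_code_parts("A1YU101")  # -> ("A1YU", "101")
_parse_individual_code_parts("KM")       # -> (None, None)   no trailing digits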
158
+
159
+
160
+ def parse_sample_id_to_population_code(plain_text_content):
161
+ sample_id_map = {}
162
+ contiguous_ranges_data = defaultdict(list)
163
+
164
+ #section_start_marker = "The sample identification of each population is as follows:"
165
+ section_start_marker = ["The sample identification of each population is as follows:","## table"]
166
+
167
+ for s in section_start_marker:
168
+ relevant_text_search = re.search(
169
+ re.escape(s.lower()) + r"\s*(.*?)(?=\n##|\Z)",
170
+ plain_text_content.lower(),
171
+ re.DOTALL
172
+ )
173
+ if relevant_text_search:
174
+ break
175
+
176
+ if not relevant_text_search:
177
+ print("Warning: 'Sample ID Population Code' section start marker not found or block empty.")
178
+ return sample_id_map, contiguous_ranges_data
179
+
180
+ relevant_text_block = relevant_text_search.group(1).strip()
181
+
182
+ # print(f"\nDEBUG_PARSING: --- Start of relevant_text_block (first 500 chars) ---")
183
+ # print(relevant_text_block[:500])
184
+ # print(f"DEBUG_PARSING: --- End of relevant_text_block (last 500 chars) ---")
185
+ # print(relevant_text_block[-500:])
186
+ # print(f"DEBUG_PARSING: Relevant text block length: {len(relevant_text_block)}")
187
+
188
+ mapping_pattern = re.compile(
189
+ r'\b([A-Z0-9]+\d+)(?:-([A-Z0-9]+\d+))?\s+([A-Z0-9]+)\b', # Changed the last group
190
+ re.IGNORECASE)
191
+
192
+ range_expansion_count = 0
193
+ direct_id_count = 0
194
+ total_matches_found = 0
195
+ for match in mapping_pattern.finditer(relevant_text_block):
196
+ total_matches_found += 1
197
+ id1_full_str, id2_full_str_opt, pop_code = match.groups()
198
+
199
+ #print(f" DEBUG_PARSING: Matched: '{match.group(0)}'")
200
+
201
+ pop_code_upper = pop_code.upper()
202
+
203
+ id1_prefix, id1_num_str = _parse_individual_code_parts(id1_full_str)
204
+ if id1_prefix is None:
205
+ #print(f" DEBUG_PARSING: Failed to parse ID1: {id1_full_str}. Skipping this mapping.")
206
+ continue
207
+
208
+ if id2_full_str_opt:
209
+ id2_prefix_opt, id2_num_str_opt = _parse_individual_code_parts(id2_full_str_opt)
210
+ if id2_prefix_opt is None:
211
+ #print(f" DEBUG_PARSING: Failed to parse ID2: {id2_full_str_opt}. Treating {id1_full_str} as single ID1.")
212
+ sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
213
+ direct_id_count += 1
214
+ continue
215
+
216
+ #print(f" DEBUG_PARSING: Comparing prefixes: '{id1_prefix.lower()}' vs '{id2_prefix_opt.lower()}'")
217
+ if id1_prefix.lower() == id2_prefix_opt.lower():
218
+ #print(f" DEBUG_PARSING: ---> Prefixes MATCH for range expansion! Range: {id1_prefix}{id1_num_str}-{id2_prefix_opt}{id2_num_str_opt}")
219
+ try:
220
+ start_num = int(id1_num_str)
221
+ end_num = int(id2_num_str_opt)
222
+ for num in range(start_num, end_num + 1):
223
+ sample_id = f"{id1_prefix.upper()}{num}"
224
+ sample_id_map[sample_id] = pop_code_upper
225
+ range_expansion_count += 1
226
+ contiguous_ranges_data[id1_prefix.upper()].append(
227
+ (start_num, end_num, pop_code_upper)
228
+ )
229
+ except ValueError:
230
+ print(f" DEBUG_PARSING: ValueError in range conversion for {id1_num_str}-{id2_num_str_opt}. Adding endpoints only.")
231
+ sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
232
+ sample_id_map[f"{id2_prefix_opt.upper()}{id2_num_str_opt}"] = pop_code_upper
233
+ direct_id_count += 2
234
+ else:
235
+ #print(f" DEBUG_PARSING: Prefixes MISMATCH for range: '{id1_prefix}' vs '{id2_prefix_opt}'. Adding endpoints only.")
236
+ sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
237
+ sample_id_map[f"{id2_prefix_opt.upper()}{id2_num_str_opt}"] = pop_code_upper
238
+ direct_id_count += 2
239
+ else:
240
+ sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
241
+ direct_id_count += 1
242
+
243
+ # print(f"DEBUG_PARSING: Total matches found by regex: {total_matches_found}.")
244
+ # print(f"DEBUG_PARSING: Parsed sample IDs: {len(sample_id_map)} total entries.")
245
+ # print(f"DEBUG_PARSING: (including {range_expansion_count} from range expansion and {direct_id_count} direct ID/endpoint entries).")
246
+ return sample_id_map, contiguous_ranges_data
247
+
248
+ country_keywords_regional_overrides = {
249
+ "north thailand": "Thailand", "central thailand": "Thailand",
250
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
251
+ "central india": "India", "east india": "India", "northeast india": "India",
252
+    "south siberia": "Russia", "siberia": "Russia", "yunnan": "China", #"tibet": "China",
253
+ "sumatra": "Indonesia", "borneo": "Indonesia",
254
+ "northern mindanao": "Philippines", "west malaysia": "Malaysia",
255
+ "mongolia": "China",
256
+ "beijing": "China",
257
+ "north laos": "Laos", "central laos": "Laos",
258
+    "west myanmar": "Myanmar"}
259
+
260
+ # Updated get_country_from_text function
261
+ def get_country_from_text(text):
262
+ text_lower = text.lower()
263
+
264
+ # 1. Use pycountry for official country names and common aliases
265
+ for country in pycountry.countries:
266
+ # Check full name match first
267
+ if text_lower == country.name.lower():
268
+ return country.name
269
+
270
+ # Safely check for common_name
271
+ if hasattr(country, 'common_name') and text_lower == country.common_name.lower():
272
+ return country.common_name
273
+
274
+ # Safely check for official_name
275
+ if hasattr(country, 'official_name') and text_lower == country.official_name.lower():
276
+ return country.official_name
277
+
278
+ # Check if country name is part of the text (e.g., 'Thailand' in 'Thailand border')
279
+ if country.name.lower() in text_lower:
280
+ return country.name
281
+
282
+ # Safely check if common_name is part of the text
283
+ if hasattr(country, 'common_name') and country.common_name.lower() in text_lower:
284
+ return country.common_name
285
+ # 2. Prioritize specific regional overrides
286
+ for keyword, country in country_keywords_regional_overrides.items():
287
+ if keyword in text_lower:
288
+ return country
289
+ # 3. Check for broader regions that you want to map to "unknown" or a specific country
290
+ if "north asia" in text_lower or "southeast asia" in text_lower or "east asia" in text_lower:
291
+ return "unknown"
292
+
293
+ return "unknown"
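Two example calls, assuming pycountry's standard ISO country list: the first resolves via the substring check against official country names, the second via the regional-override table above:

get_country_from_text("North Thailand (Chiang Mai province)")  # -> "Thailand"
get_country_from_text("Yunnan province")                       # -> "China"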
294
+
295
+ # Get the list of English stop words from NLTK
296
+ non_meaningful_pop_names = set(stopwords.words('english'))
297
+
298
+ def parse_population_code_to_country(plain_text_content, table_strings):
299
+ pop_code_country_map = {}
300
+ pop_code_ethnicity_map = {} # NEW: To store ethnicity for structured lookup
301
+ pop_code_specific_loc_map = {} # NEW: To store specific location for structured lookup
302
+
303
+ # Regex for parsing population info in structured lists and general text
304
+ # This pattern captures: (Pop Name/Ethnicity) (Pop Code) (Region/Specific Location) (Country) (Linguistic Family)
305
+ # The 'Pop Name/Ethnicity' (Group 1) is often the ethnicity
306
+ pop_info_pattern = re.compile(
307
+ r'([A-Za-z\s]+?)\s+([A-Z]+\d*)\s+' # Pop Name (Group 1), Pop Code (Group 2) - Changed \d+ to \d* for codes like 'SH'
308
+ r'([A-Za-z\s\(\)\-,\/]+?)\s+' # Region/Specific Location (Group 3)
309
+ r'(North+|South+|West+|East+|Thailand|Laos|Cambodia|Myanmar|Philippines|Indonesia|Malaysia|China|India|Taiwan|Vietnam|Russia|Nepal|Japan|South Korea)\b' # Country (Group 4)
310
+ r'(?:.*?([A-Za-z\s\-]+))?\s*' # Optional Linguistic Family (Group 5), made optional with ?, followed by optional space
311
+ r'(\d+(?:\s+\d+\.?\d*)*)?', # Match all the numbers (Group 6) - made optional
312
+ re.IGNORECASE
313
+ )
314
+ for table_str in table_strings:
315
+ table_data = parse_literal_python_list(table_str)
316
+ if table_data:
317
+ is_list_of_lists = bool(table_data) and isinstance(table_data[0], list)
318
+ if is_list_of_lists:
319
+ for row_idx, row in enumerate(table_data):
320
+ row_text = " ".join(map(str, row))
321
+ match = pop_info_pattern.search(row_text)
322
+ if match:
323
+ pop_name = match.group(1).strip()
324
+ pop_code = match.group(2).upper()
325
+ specific_loc_text = match.group(3).strip()
326
+ country_text = match.group(4).strip()
327
+ linguistic_family = match.group(5).strip() if match.group(5) else 'unknown'
328
+
329
+ final_country = get_country_from_text(country_text)
330
+ if final_country == 'unknown': # Try specific loc text for country if direct country is not found
331
+ final_country = get_country_from_text(specific_loc_text)
332
+
333
+ if pop_code:
334
+ pop_code_country_map[pop_code] = final_country
335
+
336
+ # Populate ethnicity map (often Pop Name is ethnicity)
337
+ pop_code_ethnicity_map[pop_code] = pop_name
338
+
339
+ # Populate specific location map
340
+ pop_code_specific_loc_map[pop_code] = specific_loc_text # Store as is from text
341
+ else:
342
+ row_text = " ".join(map(str, table_data))
343
+ match = pop_info_pattern.search(row_text)
344
+ if match:
345
+ pop_name = match.group(1).strip()
346
+ pop_code = match.group(2).upper()
347
+ specific_loc_text = match.group(3).strip()
348
+ country_text = match.group(4).strip()
349
+ linguistic_family = match.group(5).strip() if match.group(5) else 'unknown'
350
+
351
+ final_country = get_country_from_text(country_text)
352
+ if final_country == 'unknown': # Try specific loc text for country if direct country is not found
353
+ final_country = get_country_from_text(specific_loc_text)
354
+
355
+ if pop_code:
356
+ pop_code_country_map[pop_code] = final_country
357
+
358
+ # Populate ethnicity map (often Pop Name is ethnicity)
359
+ pop_code_ethnicity_map[pop_code] = pop_name
360
+
361
+ # Populate specific location map
362
+ pop_code_specific_loc_map[pop_code] = specific_loc_text # Store as is from text
363
+
364
+ # # Special case refinements for ethnicity/location if more specific rules are known from document:
365
+ # if pop_name.lower() == "khon mueang": # and specific conditions if needed
366
+ # pop_code_ethnicity_map[pop_code] = "Khon Mueang"
367
+ # # If Khon Mueang has a specific city/district, add here
368
+ # # e.g., if 'Chiang Mai' is directly linked to KM1 in a specific table
369
+ # # pop_code_specific_loc_map[pop_code] = "Chiang Mai"
370
+ # elif pop_name.lower() == "lawa":
371
+ # pop_code_ethnicity_map[pop_code] = "Lawa"
372
+ # # Add similar specific rules for other populations (e.g., Mon for MO1, MO2, MO3)
373
+ # elif pop_name.lower() == "mon":
374
+ # pop_code_ethnicity_map[pop_code] = "Mon"
375
+ # # For MO2: "West Thailand (Thailand Myanmar border)" -> no city
376
+ # # For MO3: "East Myanmar (Thailand Myanmar border)" -> no city
377
+ # # If the doc gives "Bangkok" for MO4, add it here for MO4's actual specific_location.
378
+ # # etc.
379
+
380
+ # Fallback to parsing general plain text content (sentences)
381
+ sentences = data_preprocess.extract_sentences(plain_text_content)
382
+ for s in sentences: # Still focusing on just this one sentence
383
+ # Use re.finditer to get all matches
384
+ matches = pop_info_pattern.finditer(s)
385
+ pop_name, pop_code, specific_loc_text, country_text = "unknown", "unknown", "unknown", "unknown"
386
+ for match in matches:
387
+ if match.group(1):
388
+ pop_name = match.group(1).strip()
389
+ if match.group(2):
390
+ pop_code = match.group(2).upper()
391
+ if match.group(3):
392
+ specific_loc_text = match.group(3).strip()
393
+ if match.group(4):
394
+ country_text = match.group(4).strip()
395
+ # linguistic_family = match.group(5).strip() if match.group(5) else 'unknown' # Already captured by pop_info_pattern
396
+
397
+ final_country = get_country_from_text(country_text)
398
+ if final_country == 'unknown':
399
+ final_country = get_country_from_text(specific_loc_text)
400
+
401
+ if pop_code.lower() not in non_meaningful_pop_names:
402
+ if final_country.lower() not in non_meaningful_pop_names:
403
+ pop_code_country_map[pop_code] = final_country
404
+ if pop_name.lower() not in non_meaningful_pop_names:
405
+ pop_code_ethnicity_map[pop_code] = pop_name # Default ethnicity from Pop Name
406
+ if specific_loc_text.lower() not in non_meaningful_pop_names:
407
+ pop_code_specific_loc_map[pop_code] = specific_loc_text
408
+
409
+ # Specific rules for ethnicity/location in plain text:
410
+ if pop_name.lower() == "khon mueang":
411
+ pop_code_ethnicity_map[pop_code] = "Khon Mueang"
412
+ elif pop_name.lower() == "lawa":
413
+ pop_code_ethnicity_map[pop_code] = "Lawa"
414
+ elif pop_name.lower() == "mon":
415
+ pop_code_ethnicity_map[pop_code] = "Mon"
416
+ elif pop_name.lower() == "seak": # Added specific rule for Seak
417
+ pop_code_ethnicity_map[pop_code] = "Seak"
418
+ elif pop_name.lower() == "nyaw": # Added specific rule for Nyaw
419
+ pop_code_ethnicity_map[pop_code] = "Nyaw"
420
+ elif pop_name.lower() == "nyahkur": # Added specific rule for Nyahkur
421
+ pop_code_ethnicity_map[pop_code] = "Nyahkur"
422
+ elif pop_name.lower() == "suay": # Added specific rule for Suay
423
+ pop_code_ethnicity_map[pop_code] = "Suay"
424
+ elif pop_name.lower() == "soa": # Added specific rule for Soa
425
+ pop_code_ethnicity_map[pop_code] = "Soa"
426
+ elif pop_name.lower() == "bru": # Added specific rule for Bru
427
+ pop_code_ethnicity_map[pop_code] = "Bru"
428
+ elif pop_name.lower() == "khamu": # Added specific rule for Khamu
429
+ pop_code_ethnicity_map[pop_code] = "Khamu"
430
+
431
+ return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map
432
+
433
+ def general_parse_population_code_to_country(plain_text_content, table_strings):
434
+ pop_code_country_map = {}
435
+ pop_code_ethnicity_map = {}
436
+ pop_code_specific_loc_map = {}
437
+ sample_id_to_pop_code = {}
438
+
439
+ for table_str in table_strings:
440
+ table_data = parse_literal_python_list(table_str)
441
+ if not table_data or not isinstance(table_data[0], list):
442
+ continue
443
+
444
+ header_row = [col.lower() for col in table_data[0]]
445
+ header_map = {col: idx for idx, col in enumerate(header_row)}
446
+
447
+ # MJ17: Direct PopCode → Country
448
+ if 'id' in header_map and 'country' in header_map:
449
+ for row in table_strings[1:]:
450
+ row = parse_literal_python_list(row)[0]
451
+ if len(row) < len(header_row):
452
+ continue
453
+ pop_code = str(row[header_map['id']]).strip()
454
+ country = str(row[header_map['country']]).strip()
455
+ province = row[header_map['province']].strip() if 'province' in header_map else 'unknown'
456
+ pop_group = row[header_map['population group / region']].strip() if 'population group / region' in header_map else 'unknown'
457
+ pop_code_country_map[pop_code] = country
458
+ pop_code_specific_loc_map[pop_code] = province
459
+ pop_code_ethnicity_map[pop_code] = pop_group
460
+
461
+ # A1YU101 or EBK/KSK: SampleID → PopCode
462
+ elif 'sample id' in header_map and 'population code' in header_map:
463
+ for row in table_strings[1:]:
464
+ row = parse_literal_python_list(row)[0]
465
+ if len(row) < 2:
466
+ continue
467
+ sample_id = row[header_map['sample id']].strip().upper()
468
+ pop_code = row[header_map['population code']].strip().upper()
469
+ sample_id_to_pop_code[sample_id] = pop_code
470
+
471
+ # PopCode → Country (A1YU101/EBK mapping)
472
+ elif 'population code' in header_map and 'country' in header_map:
473
+ for row in table_strings[1:]:
474
+ row = parse_literal_python_list(row)[0]
475
+ if len(row) < 2:
476
+ continue
477
+ pop_code = row[header_map['population code']].strip().upper()
478
+ country = row[header_map['country']].strip()
479
+ pop_code_country_map[pop_code] = country
480
+
481
+ return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map, sample_id_to_pop_code
482
+
483
+ def chunk_text(text, chunk_size=500, overlap=50):
484
+ """Splits text into chunks (by words) with overlap."""
485
+ chunks = []
486
+ words = text.split()
487
+ num_words = len(words)
488
+
489
+ start = 0
490
+ while start < num_words:
491
+ end = min(start + chunk_size, num_words)
492
+ chunk = " ".join(words[start:end])
493
+ chunks.append(chunk)
494
+
495
+ if end == num_words:
496
+ break
497
+ start += chunk_size - overlap # Move start by (chunk_size - overlap)
498
+ return chunks
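With this word-based windowing, consecutive chunks share the configured number of overlapping words. For example, a 12-word text with chunk_size=5 and overlap=2 produces chunks starting at words 0, 3, 6 and 9:

text = " ".join(f"w{i}" for i in range(12))
chunk_text(text, chunk_size=5, overlap=2)
# -> ['w0 w1 w2 w3 w4', 'w3 w4 w5 w6 w7', 'w6 w7 w8 w9 w10', 'w9 w10 w11']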
499
+
500
+ def build_vector_index_and_data(doc_path, index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
501
+ """
502
+ Reads document, builds structured lookup, chunks remaining text, embeds chunks,
503
+ and builds/saves a FAISS index.
504
+ """
505
+ print("Step 1: Reading document and extracting structured data...")
506
+ # plain_text_content, table_strings, document_title = read_docx_text(doc_path) # Get document_title here
507
+
508
+ # sample_id_map, contiguous_ranges_data = parse_sample_id_to_population_code(plain_text_content)
509
+ # pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc = parse_population_code_to_country(plain_text_content, table_strings)
510
+
511
+ # master_structured_lookup = {}
512
+ # master_structured_lookup['document_title'] = document_title # Store document title
513
+ # master_structured_lookup['sample_id_map'] = sample_id_map
514
+ # master_structured_lookup['contiguous_ranges'] = dict(contiguous_ranges_data)
515
+ # master_structured_lookup['pop_code_to_country'] = pop_code_to_country
516
+ # master_structured_lookup['pop_code_to_ethnicity'] = pop_code_to_ethnicity # NEW: Store pop_code to ethnicity map
517
+ # master_structured_lookup['pop_code_to_specific_loc'] = pop_code_to_specific_loc # NEW: Store pop_code to specific_loc map
518
+
519
+
520
+ # # Final consolidation: Use sample_id_map to derive full info for queries
521
+ # final_structured_entries = {}
522
+ # for sample_id, pop_code in master_structured_lookup['sample_id_map'].items():
523
+ # country = master_structured_lookup['pop_code_to_country'].get(pop_code, 'unknown')
524
+ # ethnicity = master_structured_lookup['pop_code_to_ethnicity'].get(pop_code, 'unknown') # Retrieve ethnicity
525
+ # specific_location = master_structured_lookup['pop_code_to_specific_loc'].get(pop_code, 'unknown') # Retrieve specific location
526
+
527
+ # final_structured_entries[sample_id] = {
528
+ # 'population_code': pop_code,
529
+ # 'country': country,
530
+ # 'type': 'modern',
531
+ # 'ethnicity': ethnicity, # Store ethnicity
532
+ # 'specific_location': specific_location # Store specific location
533
+ # }
534
+ # master_structured_lookup['final_structured_entries'] = final_structured_entries
535
+ plain_text_content, table_strings, document_title = read_docx_text(doc_path)
536
+ pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc, sample_id_map = general_parse_population_code_to_country(plain_text_content, table_strings)
537
+
538
+ final_structured_entries = {}
539
+ if sample_id_map:
540
+ for sample_id, pop_code in sample_id_map.items():
541
+ country = pop_code_to_country.get(pop_code, 'unknown')
542
+ ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
543
+ specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
544
+ final_structured_entries[sample_id] = {
545
+ 'population_code': pop_code,
546
+ 'country': country,
547
+ 'type': 'modern',
548
+ 'ethnicity': ethnicity,
549
+ 'specific_location': specific_loc
550
+ }
551
+ else:
552
+ for pop_code in pop_code_to_country.keys():
553
+ country = pop_code_to_country.get(pop_code, 'unknown')
554
+ ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
555
+ specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
556
+ final_structured_entries[pop_code] = {
557
+ 'population_code': pop_code,
558
+ 'country': country,
559
+ 'type': 'modern',
560
+ 'ethnicity': ethnicity,
561
+ 'specific_location': specific_loc
562
+ }
563
+ if not final_structured_entries:
564
+ # traditional way of A1YU101
565
+ sample_id_map, contiguous_ranges_data = parse_sample_id_to_population_code(plain_text_content)
566
+ pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc = parse_population_code_to_country(plain_text_content, table_strings)
567
+ if sample_id_map:
568
+ for sample_id, pop_code in sample_id_map.items():
569
+ country = pop_code_to_country.get(pop_code, 'unknown')
570
+ ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
571
+ specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
572
+ final_structured_entries[sample_id] = {
573
+ 'population_code': pop_code,
574
+ 'country': country,
575
+ 'type': 'modern',
576
+ 'ethnicity': ethnicity,
577
+ 'specific_location': specific_loc
578
+ }
579
+ else:
580
+ for pop_code in pop_code_to_country.keys():
581
+ country = pop_code_to_country.get(pop_code, 'unknown')
582
+ ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
583
+ specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
584
+ final_structured_entries[pop_code] = {
585
+ 'population_code': pop_code,
586
+ 'country': country,
587
+ 'type': 'modern',
588
+ 'ethnicity': ethnicity,
589
+ 'specific_location': specific_loc
590
+ }
591
+
592
+ master_lookup = {
593
+ 'document_title': document_title,
594
+ 'pop_code_to_country': pop_code_to_country,
595
+ 'pop_code_to_ethnicity': pop_code_to_ethnicity,
596
+ 'pop_code_to_specific_loc': pop_code_to_specific_loc,
597
+ 'sample_id_map': sample_id_map,
598
+ 'final_structured_entries': final_structured_entries
599
+ }
600
+ print(f"Structured lookup built with {len(final_structured_entries)} entries in 'final_structured_entries'.")
601
+
602
+ with open(structured_path, 'w') as f:
603
+ json.dump(master_lookup, f, indent=4)
604
+ print(f"Structured lookup saved to {structured_path}.")
605
+
606
+ print("Step 2: Chunking document for RAG vector index...")
607
+ # replace the chunk here with the all_output from process_inputToken and fallback to this traditional chunk
608
+ clean_text, clean_table = "", ""
609
+ if plain_text_content:
610
+ clean_text = data_preprocess.normalize_for_overlap(plain_text_content)
611
+ if table_strings:
612
+ clean_table = data_preprocess.normalize_for_overlap(". ".join(table_strings))
613
+ all_clean_chunk = clean_text + clean_table
614
+ document_chunks = chunk_text(all_clean_chunk)
615
+ print(f"Document chunked into {len(document_chunks)} chunks.")
616
+
617
+ print("Step 3: Generating embeddings for chunks (this might take time and cost API calls)...")
618
+
619
+ embedding_model_for_chunks = genai.GenerativeModel('models/text-embedding-004')
620
+
621
+ chunk_embeddings = []
622
+ for i, chunk in enumerate(document_chunks):
623
+ embedding = get_embedding(chunk, task_type="RETRIEVAL_DOCUMENT")
624
+ if embedding is not None and embedding.shape[0] > 0:
625
+ chunk_embeddings.append(embedding)
626
+ else:
627
+ print(f"Warning: Failed to get valid embedding for chunk {i}. Skipping.")
628
+ chunk_embeddings.append(np.zeros(768, dtype='float32'))
629
+
630
+ if not chunk_embeddings:
631
+ raise ValueError("No valid embeddings generated. Check get_embedding function and API.")
632
+
633
+ embedding_dimension = chunk_embeddings[0].shape[0]
634
+ index = faiss.IndexFlatL2(embedding_dimension)
635
+ index.add(np.array(chunk_embeddings))
636
+
637
+ faiss.write_index(index, index_path)
638
+ with open(chunks_path, "w") as f:
639
+ json.dump(document_chunks, f)
640
+
641
+ print(f"FAISS index built and saved to {index_path}.")
642
+ print(f"Document chunks saved to {chunks_path}.")
643
+ return master_lookup, index, document_chunks, all_clean_chunk
644
+
645
+
646
+ def load_rag_assets(index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
647
+ """Loads pre-built RAG assets (FAISS index, chunks, structured lookup)."""
648
+ print("Loading RAG assets...")
649
+ master_structured_lookup = {}
650
+ if os.path.exists(structured_path):
651
+ with open(structured_path, 'r') as f:
652
+ master_structured_lookup = json.load(f)
653
+ print("Structured lookup loaded.")
654
+ else:
655
+ print("Structured lookup file not found. Rebuilding is likely needed.")
656
+
657
+ index = None
658
+ chunks = []
659
+ if os.path.exists(index_path) and os.path.exists(chunks_path):
660
+ try:
661
+ index = faiss.read_index(index_path)
662
+ with open(chunks_path, "r") as f:
663
+ chunks = json.load(f)
664
+ print("FAISS index and chunks loaded.")
665
+ except Exception as e:
666
+ print(f"Error loading FAISS index or chunks: {e}. Will rebuild.")
667
+ index = None
668
+ chunks = []
669
+ else:
670
+ print("FAISS index or chunks files not found.")
671
+
672
+ return master_structured_lookup, index, chunks
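A typical build-or-load pattern around these two functions; this is a sketch only, the .docx path is a placeholder, and rebuilding triggers real embedding API calls:

DOC_PATH = "paper.docx"  # placeholder, not a file from this repository

lookup, index, chunks = load_rag_assets()
if index is None or not chunks or not lookup:
    # First run (or missing/corrupted assets): rebuild from the source document.
    lookup, index, chunks, _ = build_vector_index_and_data(DOC_PATH)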
673
+ # Helper function for query_document_info
674
+ def exactInContext(text, keyword):
675
+ # try keyword_prfix
676
+ # code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
677
+ # # Attempt to parse the keyword into its prefix and numerical part using re.search
678
+ # keyword_match = code_pattern.search(keyword)
679
+ # keyword_prefix = None
680
+ # keyword_num = None
681
+ # if keyword_match:
682
+ # keyword_prefix = keyword_match.group(1).lower()
683
+ # keyword_num = int(keyword_match.group(2))
684
+ text = text.lower()
685
+ idx = text.find(keyword.lower())
686
+ if idx == -1:
687
+ # if keyword_prefix:
688
+ # idx = text.find(keyword_prefix)
689
+ # if idx == -1:
690
+ # return False
691
+ return False
692
+ return True
693
+ def chooseContextLLM(contexts, kw):
694
+ # if kw in context
695
+ for con in contexts:
696
+ context = contexts[con]
697
+ if context:
698
+ if exactInContext(context, kw):
699
+ return con, context
700
+ #if cannot find anything related to kw in context, return all output
701
+ if contexts["all_output"]:
702
+ return "all_output", contexts["all_output"]
703
+ else:
704
+        # if all_output does not exist either,
706
+        # look for the chunk; if that also does not exist, fall back to the document chunk
706
+ if contexts["chunk"]: return "chunk", contexts["chunk"]
707
+ elif contexts["document_chunk"]: return "document_chunk", contexts["document_chunk"]
708
+ else: return None, None
709
+ def clean_llm_output(llm_response_text, output_format_str):
710
+ results = []
711
+ lines = llm_response_text.strip().split('\n')
712
+ output_country, output_type, output_ethnicity, output_specific_location = [],[],[],[]
713
+ for line in lines:
714
+ extracted_country, extracted_type, extracted_ethnicity, extracted_specific_location = "unknown", "unknown", "unknown", "unknown"
715
+ line = line.strip()
716
+ if output_format_str == "ethnicity, specific_location/unknown": # Targeted RAG output
717
+ parsed_output = re.search(r'^\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
718
+ if parsed_output:
719
+ extracted_ethnicity = parsed_output.group(1).strip()
720
+ extracted_specific_location = parsed_output.group(2).strip()
721
+ else:
722
+ print(" DEBUG: LLM did not follow expected 2-field format for targeted RAG. Defaulting to unknown for ethnicity/specific_location.")
723
+ extracted_ethnicity = 'unknown'
724
+ extracted_specific_location = 'unknown'
725
+ elif output_format_str == "modern/ancient/unknown, ethnicity, specific_location/unknown":
726
+ parsed_output = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
727
+ if parsed_output:
728
+ extracted_type = parsed_output.group(1).strip()
729
+ extracted_ethnicity = parsed_output.group(2).strip()
730
+ extracted_specific_location = parsed_output.group(3).strip()
731
+ else:
732
+ # Fallback: check if only 2 fields
733
+ parsed_output_2_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', llm_response_text)
734
+ if parsed_output_2_fields:
735
+ extracted_type = parsed_output_2_fields.group(1).strip()
736
+ extracted_ethnicity = parsed_output_2_fields.group(2).strip()
737
+ extracted_specific_location = 'unknown'
738
+ else:
739
+ # even simpler fallback: 1 field only
740
+ parsed_output_1_field = re.search(r'^\s*([^,]+?)\s*$', llm_response_text)
741
+ if parsed_output_1_field:
742
+ extracted_type = parsed_output_1_field.group(1).strip()
743
+ extracted_ethnicity = 'unknown'
744
+ extracted_specific_location = 'unknown'
745
+ else:
746
+ print(" DEBUG: LLM did not follow any expected simplified format. Attempting verbose parsing fallback.")
747
+ type_match_fallback = re.search(r'Type:\s*([A-Za-z\s-]+)', llm_response_text)
748
+ extracted_type = type_match_fallback.group(1).strip() if type_match_fallback else 'unknown'
749
+ extracted_ethnicity = 'unknown'
750
+ extracted_specific_location = 'unknown'
751
+ else:
752
+ parsed_output = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', line)
753
+ if parsed_output:
754
+ extracted_country = parsed_output.group(1).strip()
755
+ extracted_type = parsed_output.group(2).strip()
756
+ extracted_ethnicity = parsed_output.group(3).strip()
757
+ extracted_specific_location = parsed_output.group(4).strip()
758
+ else:
759
+ print(f" DEBUG: Line did not follow expected 4-field format: {line}")
760
+ parsed_output_2_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', line)
761
+ if parsed_output_2_fields:
762
+ extracted_country = parsed_output_2_fields.group(1).strip()
763
+ extracted_type = parsed_output_2_fields.group(2).strip()
764
+ extracted_ethnicity = 'unknown'
765
+ extracted_specific_location = 'unknown'
766
+ else:
767
+ print(f" DEBUG: Fallback to verbose-style parsing: {line}")
768
+ country_match_fallback = re.search(r'Country:\s*([A-Za-z\s-]+)', line)
769
+ type_match_fallback = re.search(r'Type:\s*([A-Za-z\s-]+)', line)
770
+ extracted_country = country_match_fallback.group(1).strip() if country_match_fallback else 'unknown'
771
+ extracted_type = type_match_fallback.group(1).strip() if type_match_fallback else 'unknown'
772
+ extracted_ethnicity = 'unknown'
773
+ extracted_specific_location = 'unknown'
774
+
775
+ results.append({
776
+ "country": extracted_country,
777
+ "type": extracted_type,
778
+ "ethnicity": extracted_ethnicity,
779
+ "specific_location": extracted_specific_location
780
+ #"country_explain":extracted_country_explain,
781
+ #"type_explain": extracted_type_explain
782
+ })
783
+ # if more than 2 results
784
+ if output_format_str == "ethnicity, specific_location/unknown":
785
+ for result in results:
786
+ if result["ethnicity"] not in output_ethnicity:
787
+ output_ethnicity.append(result["ethnicity"])
788
+ if result["specific_location"] not in output_specific_location:
789
+ output_specific_location.append(result["specific_location"])
790
+ return " or ".join(output_ethnicity), " or ".join(output_specific_location)
791
+ elif output_format_str == "modern/ancient/unknown, ethnicity, specific_location/unknown":
792
+ for result in results:
793
+ if result["type"] not in output_type:
794
+ output_type.append(result["type"])
795
+ if result["ethnicity"] not in output_ethnicity:
796
+ output_ethnicity.append(result["ethnicity"])
797
+ if result["specific_location"] not in output_specific_location:
798
+ output_specific_location.append(result["specific_location"])
799
+
800
+ return " or ".join(output_type)," or ".join(output_ethnicity), " or ".join(output_specific_location)
801
+ else:
802
+ for result in results:
803
+ if result["country"] not in output_country:
804
+ output_country.append(result["country"])
805
+ if result["type"] not in output_type:
806
+ output_type.append(result["type"])
807
+ if result["ethnicity"] not in output_ethnicity:
808
+ output_ethnicity.append(result["ethnicity"])
809
+ if result["specific_location"] not in output_specific_location:
810
+ output_specific_location.append(result["specific_location"])
811
+ return " or ".join(output_country)," or ".join(output_type)," or ".join(output_ethnicity), " or ".join(output_specific_location)
812
+
813
+ def parse_multi_sample_llm_output(raw_response: str, output_format_str):
814
+ """
815
+ Parse LLM output with possibly multiple metadata lines + shared explanations.
816
+ """
817
+ lines = [line.strip() for line in raw_response.strip().splitlines() if line.strip()]
818
+ metadata_list = []
819
+ explanation_lines = []
820
+ if output_format_str == "country_name, modern/ancient/unknown":
821
+ parts = [x.strip() for x in lines[0].split(",")]
822
+ if len(parts)==2:
823
+ metadata_list.append({
824
+ "country": parts[0],
825
+ "sample_type": parts[1]#,
826
+ #"ethnicity": parts[2],
827
+ #"location": parts[3]
828
+ })
829
+ if 1<len(lines):
830
+ line = lines[1]
831
+ if "\n" in line: line = line.split("\n")
832
+ if ". " in line: line = line.split(". ")
833
+ if isinstance(line,str): line = [line]
834
+ explanation_lines += line
835
+ elif output_format_str == "modern/ancient/unknown":
836
+ metadata_list.append({
837
+ "country": "unknown",
838
+ "sample_type": lines[0]#,
839
+ #"ethnicity": parts[2],
840
+ #"location": parts[3]
841
+ })
842
+        if len(lines) > 1: explanation_lines.append(lines[1])
843
+
844
+ # Assign explanations (optional) to each sample — same explanation reused
845
+ for md in metadata_list:
846
+ md["country_explanation"] = None
847
+ md["sample_type_explanation"] = None
848
+
849
+ if md["country"].lower() != "unknown" and len(explanation_lines) >= 1:
850
+ md["country_explanation"] = explanation_lines[0]
851
+
852
+ if md["sample_type"].lower() != "unknown":
853
+ if len(explanation_lines) >= 2:
854
+ md["sample_type_explanation"] = explanation_lines[1]
855
+ elif len(explanation_lines) == 1 and md["country"].lower() == "unknown":
856
+ md["sample_type_explanation"] = explanation_lines[0]
857
+ elif len(explanation_lines) == 1:
858
+ md["sample_type_explanation"] = explanation_lines[0]
859
+ return metadata_list
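For the "country_name, modern/ancient/unknown" format, a two-line response is split into one metadata entry plus shared explanation sentences; the response text below is invented for illustration:

raw = (
    "Thailand, modern\n"
    "The snippets place the samples in northern Thailand. "
    "All individuals are described as present-day volunteers."
)
parse_multi_sample_llm_output(raw, "country_name, modern/ancient/unknown")
# -> [{'country': 'Thailand', 'sample_type': 'modern',
#      'country_explanation': 'The snippets place the samples in northern Thailand',
#      'sample_type_explanation': 'All individuals are described as present-day volunteers.'}]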
860
+
861
+ def merge_metadata_outputs(metadata_list):
862
+ """
863
+ Merge a list of metadata dicts into one, combining differing values with 'or'.
864
+ Assumes all dicts have the same keys.
865
+ """
866
+ if not metadata_list:
867
+ return {}
868
+
869
+ merged = {}
870
+ keys = metadata_list[0].keys()
871
+
872
+ for key in keys:
873
+ values = [md[key] for md in metadata_list if key in md]
874
+ unique_values = list(dict.fromkeys(values)) # preserve order, remove dupes
875
+ if "unknown" in unique_values:
876
+ unique_values.pop(unique_values.index("unknown"))
877
+ if len(unique_values) == 1:
878
+ merged[key] = unique_values[0]
879
+ else:
880
+ merged[key] = " or ".join(unique_values)
881
+
882
+ return merged
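When several metadata lines are returned, the merge keeps agreeing fields as-is, strips 'unknown' from the candidates, and joins the remaining disagreements with 'or':

merge_metadata_outputs([
    {"country": "Thailand", "sample_type": "modern"},
    {"country": "Laos",     "sample_type": "modern"},
])
# -> {'country': 'Thailand or Laos', 'sample_type': 'modern'}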
883
+
884
+
885
+ def query_document_info(query_word, alternative_query_word, metadata, master_structured_lookup, faiss_index, document_chunks, llm_api_function, chunk=None, all_output=None):
886
+ """
887
+ Queries the document using a hybrid approach:
888
+ 1. Local structured lookup (fast, cheap, accurate for known patterns).
889
+ 2. RAG with semantic search and LLM (general, flexible, cost-optimized).
890
+ """
891
+ if metadata:
892
+ extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = metadata["country"], metadata["specific_location"], metadata["ethnicity"], metadata["sample_type"]
893
+ extracted_col_date, extracted_iso, extracted_title, extracted_features = metadata["collection_date"], metadata["isolate"], metadata["title"], metadata["all_features"]
894
+ else:
895
+ extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = "unknown", "unknown", "unknown", "unknown"
896
+        extracted_col_date, extracted_iso, extracted_title, extracted_features = "unknown", "unknown", "unknown", "unknown"
897
+ # --- NEW: Pre-process alternative_query_word to remove '.X' suffix if present ---
898
+ if alternative_query_word:
899
+ alternative_query_word_cleaned = alternative_query_word.split('.')[0]
900
+ else:
901
+ alternative_query_word_cleaned = alternative_query_word
902
+ country_explanation, sample_type_explanation = None, None
903
+
904
+ # Use the consolidated final_structured_entries for direct lookup
905
+ final_structured_entries = master_structured_lookup.get('final_structured_entries', {})
906
+ document_title = master_structured_lookup.get('document_title', 'Unknown Document Title') # Retrieve document title
907
+
908
+ # Default values for all extracted fields. These will be updated.
909
+ method_used = 'unknown' # Will be updated based on the method that yields a result
910
+ population_code_from_sl = 'unknown' # To pass to RAG prompt if available
911
+ total_query_cost = 0
912
+ # Attempt 1: Try primary query_word (e.g., isolate name) with structured lookup
913
+ structured_info = final_structured_entries.get(query_word.upper())
914
+ if structured_info:
915
+ if extracted_country == 'unknown':
916
+ extracted_country = structured_info['country']
917
+ if extracted_type == 'unknown':
918
+ extracted_type = structured_info['type']
919
+
920
+ # if extracted_ethnicity == 'unknown':
921
+ # extracted_ethnicity = structured_info.get('ethnicity', 'unknown') # Get ethnicity from structured lookup
922
+ # if extracted_specific_location == 'unknown':
923
+ # extracted_specific_location = structured_info.get('specific_location', 'unknown') # Get specific_location from structured lookup
924
+ population_code_from_sl = structured_info['population_code']
925
+ method_used = "structured_lookup_direct"
926
+ print(f"'{query_word}' found in structured lookup (direct match).")
927
+
928
+ # Attempt 2: Try primary query_word with heuristic range lookup if direct fails (only if not already resolved)
929
+ if method_used == 'unknown':
930
+ query_prefix, query_num_str = _parse_individual_code_parts(query_word)
931
+ if query_prefix is not None and query_num_str is not None:
932
+ try: query_num = int(query_num_str)
933
+ except ValueError: query_num = None
934
+ if query_num is not None:
935
+ query_prefix_upper = query_prefix.upper()
936
+ contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
937
+ pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
938
+ pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
939
+ pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
940
+
941
+ if query_prefix_upper in contiguous_ranges:
942
+ for start_num, end_num, pop_code_for_range in contiguous_ranges[query_prefix_upper]:
943
+ if start_num <= query_num <= end_num:
944
+ country_from_heuristic = pop_code_to_country.get(pop_code_for_range, 'unknown')
945
+ if country_from_heuristic != 'unknown':
946
+ if extracted_country == 'unknown':
947
+ extracted_country = country_from_heuristic
948
+ if extracted_type == 'unknown':
949
+ extracted_type = 'modern'
950
+ # if extracted_ethnicity == 'unknown':
951
+ # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
952
+ # if extracted_specific_location == 'unknown':
953
+ # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
954
+ population_code_from_sl = pop_code_for_range
955
+ method_used = "structured_lookup_heuristic_range_match"
956
+ print(f"'{query_word}' not direct. Heuristic: Falls within range {query_prefix_upper}{start_num}-{query_prefix_upper}{end_num}.")
957
+ break
958
+ else:
959
+ print(f"'{query_word}' heuristic match found, but country unknown. Will fall to RAG below.")
960
+
961
+ # Attempt 3: If primary query_word failed all structured lookups, try alternative_query_word (cleaned)
962
+ if method_used == 'unknown' and alternative_query_word_cleaned and alternative_query_word_cleaned != query_word:
963
+ print(f"'{query_word}' not found in structured (or heuristic). Trying alternative '{alternative_query_word_cleaned}'.")
964
+
965
+ # Try direct lookup for alternative word
966
+ structured_info_alt = final_structured_entries.get(alternative_query_word_cleaned.upper())
967
+ if structured_info_alt:
968
+ if extracted_country == 'unknown':
969
+ extracted_country = structured_info_alt['country']
970
+ if extracted_type == 'unknown':
971
+ extracted_type = structured_info_alt['type']
972
+ # if extracted_ethnicity == 'unknown':
973
+ # extracted_ethnicity = structured_info_alt.get('ethnicity', 'unknown')
974
+ # if extracted_specific_location == 'unknown':
975
+ # extracted_specific_location = structured_info_alt.get('specific_location', 'unknown')
976
+ population_code_from_sl = structured_info_alt['population_code']
977
+ method_used = "structured_lookup_alt_direct"
978
+ print(f"Alternative '{alternative_query_word_cleaned}' found in structured lookup (direct match).")
979
+ else:
980
+ # Try heuristic lookup for alternative word
981
+ alt_prefix, alt_num_str = _parse_individual_code_parts(alternative_query_word_cleaned)
982
+ if alt_prefix is not None and alt_num_str is not None:
983
+ try: alt_num = int(alt_num_str)
984
+ except ValueError: alt_num = None
985
+ if alt_num is not None:
986
+ alt_prefix_upper = alt_prefix.upper()
987
+ contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
988
+ pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
989
+ pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
990
+ pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
991
+ if alt_prefix_upper in contiguous_ranges:
992
+ for start_num, end_num, pop_code_for_range in contiguous_ranges[alt_prefix_upper]:
993
+ if start_num <= alt_num <= end_num:
994
+ country_from_heuristic_alt = pop_code_to_country.get(pop_code_for_range, 'unknown')
995
+ if country_from_heuristic_alt != 'unknown':
996
+ if extracted_country == 'unknown':
997
+ extracted_country = country_from_heuristic_alt
998
+ if extracted_type == 'unknown':
999
+ extracted_type = 'modern'
1000
+ # if extracted_ethnicity == 'unknown':
1001
+ # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
1002
+ # if extracted_specific_location == 'unknown':
1003
+ # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
1004
+ population_code_from_sl = pop_code_for_range
1005
+ method_used = "structured_lookup_alt_heuristic_range_match"
1006
+ break
1007
+ else:
1008
+ print(f"Alternative '{alternative_query_word_cleaned}' heuristic match found, but country unknown. Will fall to RAG below.")
1009
+
1010
+ # use the context_for_llm to detect present_ancient before using llm model
1011
+ # retrieved_chunks_text = []
1012
+ # if document_chunks:
1013
+ # for idx in range(len(document_chunks)):
1014
+ # retrieved_chunks_text.append(document_chunks[idx])
1015
+ # context_for_llm = ""
1016
+ # all_context = "\n".join(retrieved_chunks_text) #
1017
+ # listOfcontexts = {"chunk": chunk,
1018
+ # "all_output": all_output,
1019
+ # "document_chunk": all_context}
1020
+ # label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
1021
+ # if not context_for_llm:
1022
+ # label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
1023
+ # if not context_for_llm:
1024
+ # context_for_llm = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + extracted_features
1025
+ # if context_for_llm:
1026
+ # extracted_type, explain = mtdna_classifier.detect_ancient_flag(context_for_llm)
1027
+ # extracted_type = extracted_type.lower()
1028
+ # sample_type_explanation = explain
1029
+ # 5. Execute RAG if needed (either full RAG or targeted RAG for missing fields)
1030
+
1031
+ # Determine if a RAG call is necessary
1032
+ run_rag = (extracted_country == 'unknown' or extracted_type == 'unknown')# or \
1033
+ #extracted_ethnicity == 'unknown' or extracted_specific_location == 'unknown')
1034
+ global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
1035
+ if run_rag:
1036
+
1037
+ # Determine the phrase for LLM query
1038
+ rag_query_phrase = f"'{query_word}'"
1039
+ if alternative_query_word_cleaned and alternative_query_word_cleaned != query_word:
1040
+ rag_query_phrase += f" or its alternative word '{alternative_query_word_cleaned}'"
1041
+
1042
+ # Construct a more specific semantic query phrase for embedding if structured info is available
1043
+ semantic_query_for_embedding = rag_query_phrase # Default
1044
+ # if extracted_country != 'unknown': # If country is known from structured lookup (for targeted RAG)
1045
+ # if population_code_from_sl != 'unknown':
1046
+ # semantic_query_for_embedding = f"ethnicity and specific location for {query_word} population {population_code_from_sl} in {extracted_country}"
1047
+ # else: # If pop_code not found in structured, still use country hint
1048
+ # semantic_query_for_embedding = f"ethnicity and specific location for {query_word} in {extracted_country}"
1049
+ # print(f" DEBUG: Semantic query for embedding: '{semantic_query_for_embedding}'")
1050
+
1051
+
1052
+ # Determine fields to ask LLM for and output format based on what's known/needed
1053
+ prompt_instruction_prefix = ""
1054
+ output_format_str = ""
1055
+
1056
+ # Determine if it's a full RAG or targeted RAG scenario based on what's already extracted
1057
+ is_full_rag_scenario = True#(extracted_country == 'unknown')
1058
+
1059
+ if is_full_rag_scenario: # Full RAG scenario
1060
+ output_format_str = "country_name, modern/ancient/unknown"#, ethnicity, specific_location/unknown"
1061
+ method_used = "rag_llm"
1062
+ print(f"Proceeding to FULL RAG for {rag_query_phrase}.")
1063
+ # else: # Targeted RAG scenario (country/type already known, need ethnicity/specific_location)
1064
+ # if extracted_type == "unknown":
1065
+ # prompt_instruction_prefix = (
1066
+ # f"I already know the country is {extracted_country}. "
1067
+ # f"{f'The population code is {population_code_from_sl}. ' if population_code_from_sl != 'unknown' else ''}"
1068
+ # )
1069
+ # #output_format_str = "modern/ancient/unknown, ethnicity, specific_location/unknown"
1070
+ # output_format_str = "modern/ancient/unknown"
1071
+ # # else:
1072
+ # # prompt_instruction_prefix = (
1073
+ # # f"I already know the country is {extracted_country} and the sample type is {extracted_type}. "
1074
+ # # f"{f'The population code is {population_code_from_sl}. ' if population_code_from_sl != 'unknown' else ''}"
1075
+ # # )
1076
+ # # output_format_str = "ethnicity, specific_location/unknown"
1077
+
1078
+ # method_used = "hybrid_sl_rag"
1079
+ # print(f"Proceeding to TARGETED RAG for {rag_query_phrase}.")
1080
+
1081
+
1082
+ # Calculate embedding cost for the primary query word
1083
+ current_embedding_cost = 0
1084
+ try:
1085
+ query_embedding_vector = get_embedding(semantic_query_for_embedding, task_type="RETRIEVAL_QUERY")
1086
+ query_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(semantic_query_for_embedding).total_tokens
1087
+ current_embedding_cost += (query_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
1088
+ print(f" DEBUG: Query embedding tokens (for '{semantic_query_for_embedding}'): {query_embedding_tokens}, cost: ${current_embedding_cost:.6f}")
1089
+
1090
+ if alternative_query_word_cleaned and alternative_query_word_cleaned != query_word:
1091
+ alt_embedding_vector = get_embedding(alternative_query_word_cleaned, task_type="RETRIEVAL_QUERY")
1092
+ alt_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(alternative_query_word_cleaned).total_tokens
1093
+ current_embedding_cost += (alt_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
1094
+ print(f" DEBUG: Alternative query ('{alternative_query_word_cleaned}') embedding tokens: {alt_embedding_tokens}, cost: ${current_embedding_cost:.6f}")
1095
+
1096
+ except Exception as e:
1097
+ print(f"Error getting query embedding for RAG: {e}")
1098
+ return extracted_country, extracted_type, "embedding_failed", extracted_ethnicity, extracted_specific_location, total_query_cost
1099
+
1100
+ if query_embedding_vector is None or query_embedding_vector.shape[0] == 0:
1101
+ return extracted_country, extracted_type, "embedding_failed", extracted_ethnicity, extracted_specific_location, total_query_cost
1102
+
1103
+ D, I = faiss_index.search(np.array([query_embedding_vector]), 4)
1104
+
1105
+ retrieved_chunks_text = []
1106
+ for idx in I[0]:
1107
+ if 0 <= idx < len(document_chunks):
1108
+ retrieved_chunks_text.append(document_chunks[idx])
1109
+
1110
+ context_for_llm = ""
1111
+
1112
+ all_context = "\n".join(retrieved_chunks_text) #
1113
+ listOfcontexts = {"chunk": chunk,
1114
+ "all_output": all_output,
1115
+ "document_chunk": all_context}
1116
+ label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
1117
+ if not context_for_llm:
1118
+ label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
1119
+ if not context_for_llm:
1120
+            context_for_llm = "Collection_date: " + extracted_col_date + ". Isolate: " + extracted_iso + ". Title: " + extracted_title + ". Features: " + extracted_features
1121
+ #print("context for llm: ", label)
1122
+ # prompt_for_llm = (
1123
+ # f"{prompt_instruction_prefix}"
1124
+ # f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} or the mitochondrial DNA sample in general if these specific identifiers are not explicitly found. "
1125
+ # f"Identify its primary associated country/geographic location. "
1126
+ # f"Also, determine if the genetic sample or individual mentioned is from a 'modern' (present-day living individual) "
1127
+ # f"or 'ancient' (e.g., prehistoric remains, archaeological sample) source. "
1128
+ # f"If the text does not mention whether the sample is ancient or modern, assume the sample is modern unless otherwise explicitly described as ancient or archaeological. "
1129
+ # f"Additionally, extract its ethnicity and a more specific location (city/district level) within the predicted country. "
1130
+ # f"If any information is not explicitly present in the provided text snippets, state 'unknown' for that specific piece of information. "
1131
+ # f"Provide only the country, sample type, ethnicity, and specific location, do not add extra explanations.\n\n"
1132
+ # f"Text Snippets:\n{context_for_llm}\n\n"
1133
+ # f"Output Format: {output_format_str}"
1134
+ # )
1135
+ if len(context_for_llm) > 1000*1000:
1136
+ context_for_llm = context_for_llm[:900000]
1137
+ prompt_for_llm = (
1138
+ f"{prompt_instruction_prefix}"
1139
+ f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} or the mitochondrial DNA sample in general if these specific identifiers are not explicitly found. "
1140
+ f"Identify its primary associated country/geographic location. "
1141
+ f"Also, determine if the genetic sample or individual mentioned is from a 'modern' (present-day living individual) "
1142
+ f"or 'ancient' (e.g., prehistoric remains, archaeological sample) source. "
1143
+ f"If the text does not mention whether the sample is ancient or modern, assume the sample is modern unless otherwise explicitly described as ancient or archaeological. "
1144
+ f"Provide only {output_format_str}. "
1145
+ f"If any information is not explicitly present in the provided text snippets, state 'unknown' for that specific piece of information. "
1146
+ f"If the country or sample type (modern/ancient) is not 'unknown', write 1 sentence after the output explaining how you inferred it from the text (one sentence for each)."
1147
+ f"\n\nText Snippets:\n{context_for_llm}\n\n"
1148
+ f"Output Format: {output_format_str}"
1149
+ )
1150
+
1151
+ llm_response_text, model_instance = call_llm_api(prompt_for_llm)
1152
+ print("\n--- DEBUG INFO FOR RAG ---")
1153
+ print("Retrieved Context Sent to LLM (first 500 chars):")
1154
+ print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
1155
+ print("\nRaw LLM Response:")
1156
+ print(llm_response_text)
1157
+ print("--- END DEBUG INFO ---")
1158
+
1159
+ llm_cost = 0
1160
+ if model_instance:
1161
+ try:
1162
+ input_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(prompt_for_llm).total_tokens
1163
+ output_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(llm_response_text).total_tokens
1164
+ print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
1165
+ print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
1166
+ llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
1167
+ (output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
1168
+ print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
1169
+ except Exception as e:
1170
+ print(f" DEBUG: Error counting LLM tokens: {e}")
1171
+ llm_cost = 0
1172
+
1173
+ total_query_cost += current_embedding_cost + llm_cost
1174
+ print(f" DEBUG: Total estimated cost for this RAG query: ${total_query_cost:.6f}")
1175
+ # Parse the LLM's response based on the Output Format actually used
1176
+ # if output_format_str == "ethnicity, specific_location/unknown": # Targeted RAG output
1177
+ # extracted_ethnicity,extracted_specific_location = clean_llm_output(llm_response_text, output_format_str)
1178
+ # elif output_format_str == "modern/ancient/unknown, ethnicity, specific_location/unknown":
1179
+ # extracted_type, extracted_ethnicity,extracted_specific_location=clean_llm_output(llm_response_text, output_format_str)
1180
+ # else: # Full RAG output (country, type, ethnicity, specific_location)
1181
+ # extracted_country,extracted_type, extracted_ethnicity,extracted_specific_location=clean_llm_output(llm_response_text, output_format_str)
1182
+ metadata_list = parse_multi_sample_llm_output(llm_response_text, output_format_str)
1183
+ merge_metadata = merge_metadata_outputs(metadata_list)
1184
+ if output_format_str == "country_name, modern/ancient/unknown":
1185
+ extracted_country, extracted_type = merge_metadata["country"], merge_metadata["sample_type"]
1186
+ country_explanation,sample_type_explanation = merge_metadata["country_explanation"], merge_metadata["sample_type_explanation"]
1187
+ elif output_format_str == "modern/ancient/unknown":
1188
+ extracted_type = merge_metadata["sample_type"]
1189
+ sample_type_explanation = merge_metadata["sample_type_explanation"]
1190
+ # 6. Optional: Second LLM call for specific_location from general knowledge if still unknown
1191
+ # if extracted_specific_location == 'unknown':
1192
+ # # Check if we have enough info to ask general knowledge LLM
1193
+ # if extracted_country != 'unknown' and extracted_ethnicity != 'unknown':
1194
+ # print(f" DEBUG: Specific location still unknown. Querying general knowledge LLM from '{extracted_ethnicity}' and '{extracted_country}'.")
1195
+
1196
+ # general_knowledge_prompt = (
1197
+ # f"Based on general knowledge, what is a highly specific location (city or district) "
1198
+ # f"associated with the ethnicity '{extracted_ethnicity}' in '{extracted_country}'? "
1199
+ # f"Consider the context of scientific studies on human genetics, if known. "
1200
+ # f"If no common specific location is known, state 'unknown'. "
1201
+ # f"Provide only the city or district name, or 'unknown'."
1202
+ # )
1203
+
1204
+ # general_llm_response, general_llm_model_instance = call_llm_api(general_knowledge_prompt, model_name='gemini-1.5-flash-latest')
1205
+
1206
+ # if general_llm_response and general_llm_response.lower().strip() != 'unknown':
1207
+ # extracted_specific_location = general_llm_response.strip() + " (predicted from general knowledge)"
1208
+ # # Add cost of this second LLM call
1209
+ # if general_llm_model_instance:
1210
+ # try:
1211
+ # gk_input_tokens = general_llm_model_instance.count_tokens(general_knowledge_prompt).total_tokens
1212
+ # gk_output_tokens = general_llm_model_instance.count_tokens(general_llm_response).total_tokens
1213
+ # gk_cost = (gk_input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
1214
+ # (gk_output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
1215
+ # print(f" DEBUG: General Knowledge LLM cost to predict specific location alone: ${gk_cost:.6f}")
1216
+ # total_query_cost += gk_cost # Accumulate cost
1217
+ # except Exception as e:
1218
+ # print(f" DEBUG: Error counting GK LLM tokens: {e}")
1219
+ # else:
1220
+ # print(" DEBUG: General knowledge LLM returned unknown or empty for specific location.")
1221
+ # # 6. Optional: Second LLM call for ethnicity from general knowledge if still unknown
1222
+ # if extracted_ethnicity == 'unknown':
1223
+ # # Check if we have enough info to ask general knowledge LLM
1224
+ # if extracted_country != 'unknown' and extracted_specific_location != 'unknown':
1225
+ # print(f" DEBUG: Ethnicity still unknown. Querying general knowledge LLM from '{extracted_specific_location}' and '{extracted_country}'.")
1226
+
1227
+ # general_knowledge_prompt = (
1228
+ # f"Based on general knowledge, what is the most likely ethnicity (population) "
1229
+ # f"associated with the specific location '{extracted_specific_location}' in '{extracted_country}'? "
1230
+ # f"Consider the context of scientific studies on human genetics, if known. "
1231
+ # f"If no common ethnicity is known, state 'unknown'. "
1232
+ # f"Provide only the ethnicity or population name, or 'unknown'."
1233
+ # )
1234
+
1235
+ # general_llm_response, general_llm_model_instance = call_llm_api(general_knowledge_prompt, model_name='gemini-1.5-flash-latest')
1236
+
1237
+ # if general_llm_response and general_llm_response.lower().strip() != 'unknown':
1238
+ # extracted_ethnicity = general_llm_response.strip() + " (predicted from general knowledge)"
1239
+ # # Add cost of this second LLM call
1240
+ # if general_llm_model_instance:
1241
+ # try:
1242
+ # gk_input_tokens = general_llm_model_instance.count_tokens(general_knowledge_prompt).total_tokens
1243
+ # gk_output_tokens = general_llm_model_instance.count_tokens(general_llm_response).total_tokens
1244
+ # gk_cost = (gk_input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
1245
+ # (gk_output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
1246
+ # print(f" DEBUG: General Knowledge LLM cost to predict ethnicity alone: ${gk_cost:.6f}")
1247
+ # total_query_cost += gk_cost # Accumulate cost
1248
+ # except Exception as e:
1249
+ # print(f" DEBUG: Error counting GK LLM tokens: {e}")
1250
+ # else:
1251
+ # print(" DEBUG: General knowledge LLM returned unknown or empty for ethnicity.")
1252
+
1253
+
1254
+ #return extracted_country, extracted_type, method_used, extracted_ethnicity, extracted_specific_location, total_query_cost
1255
+ return extracted_country, extracted_type, method_used, country_explanation, sample_type_explanation, total_query_cost
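For reference, a minimal sketch of the per-query cost arithmetic used in the RAG path above. The prices and token counts below are illustrative assumptions; the real PRICE_PER_1K_* constants are defined elsewhere in model.py and the token counts come from the Gemini token counter.

    # Illustrative numbers only (assumed), mirroring the formula in the code above.
    PRICE_PER_1K_EMBEDDING_INPUT = 0.0001    # $ per 1K embedding input tokens (assumed)
    PRICE_PER_1K_INPUT_LLM = 0.000125        # $ per 1K LLM input tokens (assumed)
    PRICE_PER_1K_OUTPUT_LLM = 0.000375       # $ per 1K LLM output tokens (assumed)

    query_embedding_tokens = 12              # tokens in the semantic query (assumed)
    input_llm_tokens, output_llm_tokens = 8000, 120

    embedding_cost = (query_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
    llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
               (output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
    print(f"Estimated cost for this query: ${embedding_cost + llm_cost:.6f}")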
mtdna_backend.py CHANGED
@@ -3,7 +3,9 @@ from collections import Counter
3
  import csv
4
  import os
5
  from functools import lru_cache
 
6
  from mtdna_classifier import classify_sample_location
 
7
  import subprocess
8
  import json
9
  import pandas as pd
@@ -13,41 +15,47 @@ import tempfile
13
  import gspread
14
  from oauth2client.service_account import ServiceAccountCredentials
15
  from io import StringIO
 
 
16
 
17
- @lru_cache(maxsize=128)
18
- def classify_sample_location_cached(accession):
19
- return classify_sample_location(accession)
 
 
 
 
20
 
21
  # Count and suggest final location
22
- def compute_final_suggested_location(rows):
23
- candidates = [
24
- row.get("Predicted Location", "").strip()
25
- for row in rows
26
- if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
27
- ] + [
28
- row.get("Inferred Region", "").strip()
29
- for row in rows
30
- if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
31
- ]
32
-
33
- if not candidates:
34
- return Counter(), ("Unknown", 0)
35
- # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
36
- tokens = []
37
- for item in candidates:
38
- # Split by comma, whitespace, and newlines
39
- parts = re.split(r'[\s,]+', item)
40
- tokens.extend(parts)
41
-
42
- # Step 2: Clean and normalize tokens
43
- tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
44
-
45
- # Step 3: Count
46
- counts = Counter(tokens)
47
-
48
- # Step 4: Get most common
49
- top_location, count = counts.most_common(1)[0]
50
- return counts, (top_location, count)
51
 
52
  # Store feedback (with required fields)
53
 
@@ -100,74 +108,216 @@ def extract_accessions_from_input(file=None, raw_text=""):
100
  seen.add(acc)
101
 
102
  return list(accessions), None
 
 
 
 
 
 
 
103
 
104
- def summarize_results(accession):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  try:
106
- output, labelAncient_Modern, explain_label = classify_sample_location_cached(accession)
107
- #print(output)
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- return [], f"Error: {e}", f"Error: {e}", f"Error: {e}"
110
 
111
- if accession not in output:
112
- return [], "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
113
 
114
- isolate = next((k for k in output if k != accession), None)
115
  row_score = []
116
  rows = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- for key in [accession, isolate]:
119
- if key not in output:
120
- continue
121
- sample_id_label = f"{key} ({'accession number' if key == accession else 'isolate of accession'})"
122
- for section, techniques in output[key].items():
123
- for technique, content in techniques.items():
124
- source = content.get("source", "")
125
- predicted = content.get("predicted_location", "")
126
- haplogroup = content.get("haplogroup", "")
127
- inferred = content.get("inferred_location", "")
128
- context = content.get("context_snippet", "")[:300] if "context_snippet" in content else ""
129
-
130
- row = {
131
- "Sample ID": sample_id_label,
132
- "Technique": technique,
133
- "Source": f"The region of haplogroup is inferred\nby using this source: {source}" if technique == "haplogroup" else source,
134
- "Predicted Location": "" if technique == "haplogroup" else predicted,
135
- "Haplogroup": haplogroup if technique == "haplogroup" else "",
136
- "Inferred Region": inferred if technique == "haplogroup" else "",
137
- "Context Snippet": context
138
- }
139
-
140
- row_score.append(row)
141
- rows.append(list(row.values()))
142
-
143
- location_counts, (final_location, count) = compute_final_suggested_location(row_score)
144
- summary_lines = [f"### 🧭 Location Frequency Summary", "After counting all predicted and inferred locations:\n"]
145
- summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
146
- summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
147
- summary = "\n".join(summary_lines)
148
- return rows, summary, labelAncient_Modern, explain_label
149
 
150
  # save the batch input in excel file
151
- def save_to_excel(all_rows, summary_text, flag_text, filename):
152
- with pd.ExcelWriter(filename) as writer:
153
- # Save table
154
- df = pd.DataFrame(all_rows, columns=["Sample ID", "Technique", "Source", "Predicted Location", "Haplogroup", "Inferred Region", "Context Snippet"])
155
- df.to_excel(writer, sheet_name="Detailed Results", index=False)
156
-
157
- # Save summary
158
- summary_df = pd.DataFrame({"Summary": [summary_text]})
159
- summary_df.to_excel(writer, sheet_name="Summary", index=False)
 
 
 
 
 
 
 
 
 
 
160
 
161
- # Save flag
162
- flag_df = pd.DataFrame({"Flag": [flag_text]})
163
- flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  # save the batch input in JSON file
166
  def save_to_json(all_rows, summary_text, flag_text, filename):
167
  output_dict = {
168
- "Detailed_Results": all_rows, # <-- make sure this is a plain list, not a DataFrame
169
- "Summary_Text": summary_text,
170
- "Ancient_Modern_Flag": flag_text
171
  }
172
 
173
  # If all_rows is a DataFrame, convert it
@@ -189,13 +339,13 @@ def save_to_txt(all_rows, summary_text, flag_text, filename):
189
  f.write("=== Detailed Results ===\n")
190
  f.write(output + "\n")
191
 
192
- f.write("\n=== Summary ===\n")
193
- f.write(summary_text + "\n")
194
 
195
- f.write("\n=== Ancient/Modern Flag ===\n")
196
- f.write(flag_text + "\n")
197
 
198
- def save_batch_output(all_rows, summary_text, flag_text, output_type):
199
  tmp_dir = tempfile.mkdtemp()
200
 
201
  #html_table = all_rows.value # assuming this is stored somewhere
@@ -219,34 +369,142 @@ def save_batch_output(all_rows, summary_text, flag_text, output_type):
219
  return gr.update(visible=False) # invalid option
220
 
221
  return gr.update(value=file_path, visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  # run the batch
224
- def summarize_batch(file=None, raw_text=""):
 
 
 
 
225
  accessions, error = extract_accessions_from_input(file, raw_text)
226
  if error:
227
- return [], "", "", f"Error: {error}"
 
 
 
 
 
 
 
 
 
228
 
229
  all_rows = []
230
- all_summaries = []
231
- all_flags = []
232
-
233
- for acc in accessions:
 
 
 
 
 
 
 
 
 
 
 
 
234
  try:
235
- rows, summary, label, explain = summarize_results(acc)
 
236
  all_rows.extend(rows)
237
- all_summaries.append(f"**{acc}**\n{summary}")
238
- all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
 
 
 
 
 
 
239
  except Exception as e:
240
- all_summaries.append(f"**{acc}**: Failed - {e}")
241
-
 
 
 
242
  """for row in all_rows:
243
  source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
244
 
245
  if source_column.startswith("http"): # Check if the source is a URL
246
  # Wrap it with HTML anchor tags to make it clickable
247
  row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
248
-
249
-
250
- summary_text = "\n\n---\n\n".join(all_summaries)
251
- flag_text = "\n\n---\n\n".join(all_flags)
252
- return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
3
  import csv
4
  import os
5
  from functools import lru_cache
6
+ import mtdna_ui_app
7
  from mtdna_classifier import classify_sample_location
8
+ from iterate3 import data_preprocess, model, pipeline
9
  import subprocess
10
  import json
11
  import pandas as pd
 
15
  import gspread
16
  from oauth2client.service_account import ServiceAccountCredentials
17
  from io import StringIO
18
+ import hashlib
19
+ import threading
20
 
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ @lru_cache(maxsize=3600)
26
+ def pipeline_classify_sample_location_cached(accession):
27
+ return pipeline.pipeline_with_gemini([accession])
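Because pipeline.pipeline_with_gemini is expensive, wrapping it in functools.lru_cache means repeated requests for the same accession within one process are served from memory. A minimal usage sketch (cache_info() is provided by lru_cache itself):

    out_first = pipeline_classify_sample_location_cached("KU131308")   # runs the full pipeline
    out_again = pipeline_classify_sample_location_cached("KU131308")   # cache hit, no extra API cost
    print(pipeline_classify_sample_location_cached.cache_info())       # e.g. hits=1, misses=1, maxsize=3600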
28
 
29
  # Count and suggest final location
30
+ # def compute_final_suggested_location(rows):
31
+ # candidates = [
32
+ # row.get("Predicted Location", "").strip()
33
+ # for row in rows
34
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
35
+ # ] + [
36
+ # row.get("Inferred Region", "").strip()
37
+ # for row in rows
38
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
39
+ # ]
40
+
41
+ # if not candidates:
42
+ # return Counter(), ("Unknown", 0)
43
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
44
+ # tokens = []
45
+ # for item in candidates:
46
+ # # Split by comma, whitespace, and newlines
47
+ # parts = re.split(r'[\s,]+', item)
48
+ # tokens.extend(parts)
49
+
50
+ # # Step 2: Clean and normalize tokens
51
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
52
+
53
+ # # Step 3: Count
54
+ # counts = Counter(tokens)
55
+
56
+ # # Step 4: Get most common
57
+ # top_location, count = counts.most_common(1)[0]
58
+ # return counts, (top_location, count)
59
 
60
  # Store feedback (with required fields)
61
 
 
108
  seen.add(acc)
109
 
110
  return list(accessions), None
111
+ # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
112
+ def get_incomplete_accessions(file_path):
113
+ df = pd.read_excel(file_path)
114
+
115
+ incomplete_accessions = []
116
+ for _, row in df.iterrows():
117
+ sample_id = str(row.get("Sample ID", "")).strip()
118
 
119
+ # Skip if no sample ID
120
+ if not sample_id:
121
+ continue
122
+
123
+ # Drop the Sample ID and check if the rest is empty
124
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
125
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
126
+ # Extract the accession number from the sample ID using regex
127
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
128
+ if match:
129
+ incomplete_accessions.append(match.group(0))
130
+ print(len(incomplete_accessions))
131
+ return incomplete_accessions
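The accession-extraction regex above is reused in check_known_output further down; a quick illustration of what it pulls out of a decorated Sample ID label:

    import re
    m = re.search(r"\b[A-Z]{2,4}\d{4,}", "KU131308(Isolate: BRU18)")
    print(m.group(0) if m else None)   # -> KU131308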
132
+
133
+ def summarize_results(accession, KNOWN_OUTPUT_PATH = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
134
+ # try cache first
135
+ cached = check_known_output(accession)
136
+ if cached:
137
+ print(f"✅ Using cached result for {accession}")
138
+ return [[
139
+ cached["Sample ID"],
140
+ cached["Predicted Country"],
141
+ cached["Country Explanation"],
142
+ cached["Predicted Sample Type"],
143
+ cached["Sample Type Explanation"],
144
+ cached["Sources"],
145
+ cached["Time cost"]
146
+ ]]
147
+ # only run when nothing in the cache
148
  try:
149
+ outputs = pipeline_classify_sample_location_cached(accession)
150
+ # outputs = {'KU131308': {'isolate':'BRU18',
151
+ # 'country': {'brunei': ['ncbi',
152
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
153
+ # 'sample_type': {'modern':
154
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
155
+ # 'query_cost': 9.754999999999999e-05,
156
+ # 'time_cost': '24.776 seconds',
157
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
158
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
159
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
160
  except Exception as e:
161
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
162
 
163
+ if accession not in outputs:
164
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
165
 
 
166
  row_score = []
167
  rows = []
168
+ save_rows = []
169
+ for key in outputs:
170
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
171
+ for section, results in outputs[key].items():
172
+ if section == "country" or section =="sample_type":
173
+ pred_output = "\n".join(list(results.keys()))
174
+ output_explanation = ""
175
+ for result, content in results.items():
176
+ if len(result) == 0: result = "unknown"
177
+ if len(content) == 0: output_explanation = "unknown"
178
+ else:
179
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
180
+ if section == "country":
181
+ pred_country, country_explanation = pred_output, output_explanation
182
+ elif section == "sample_type":
183
+ pred_sample, sample_explanation = pred_output, output_explanation
184
+ if outputs[key]["isolate"].lower()!="unknown":
185
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
186
+ else: label = key
187
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
188
+ row = {
189
+ "Sample ID": label,
190
+ "Predicted Country": pred_country,
191
+ "Country Explanation": country_explanation,
192
+ "Predicted Sample Type":pred_sample,
193
+ "Sample Type Explanation":sample_explanation,
194
+ "Sources": "\n".join(outputs[key]["source"]),
195
+ "Time cost": outputs[key]["time_cost"]
196
+ }
197
+ #row_score.append(row)
198
+ rows.append(list(row.values()))
199
+
200
+ save_row = {
201
+ "Sample ID": label,
202
+ "Predicted Country": pred_country,
203
+ "Country Explanation": country_explanation,
204
+ "Predicted Sample Type":pred_sample,
205
+ "Sample Type Explanation":sample_explanation,
206
+ "Sources": "\n".join(outputs[key]["source"]),
207
+ "Query_cost": outputs[key]["query_cost"],
208
+ "Time cost": outputs[key]["time_cost"]
209
+ }
210
+ #row_score.append(row)
211
+ save_rows.append(list(save_row.values()))
212
+
213
+ # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
214
+ # summary_lines = [f"### 🧭 Location Summary:\n"]
215
+ # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
216
+ # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
217
+ # summary = "\n".join(summary_lines)
218
+
219
+ # save the new running sample to known excel file
220
+ try:
221
+ df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
222
+ if os.path.exists(KNOWN_OUTPUT_PATH):
223
+ df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
224
+ df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
225
+ else:
226
+ df_combined = df_new
227
+ df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
228
+ except Exception as e:
229
+ print(f"⚠️ Failed to save known output: {e}")
230
 
231
+ return rows#, summary, labelAncient_Modern, explain_label
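Based on the cached example embedded in the comments above, each element of the returned rows list is a flat list in the column order expected by save_to_excel; an abridged illustration:

    example_row = [
        "KU131308(Isolate: BRU18)",                              # Sample ID
        "brunei",                                                # Predicted Country
        "Method: ncbi\nMethod: rag_llm-The text mentions ...",   # Country Explanation (abridged)
        "modern",                                                # Predicted Sample Type
        "Method: rag_llm-The text mentions ...",                 # Sample Type Explanation (abridged)
        "https://doi.org/10.1007/s00439-015-1620-z\n...",        # Sources (abridged)
        "24.776 seconds",                                        # Time cost
    ]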
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  # save the batch input in excel file
234
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
235
+ # with pd.ExcelWriter(filename) as writer:
236
+ # # Save table
237
+ # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
238
+ # df.to_excel(writer, sheet_name="Detailed Results", index=False)
239
+ # try:
240
+ # df_old = pd.read_excel(filename)
241
+ # except:
242
+ # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
243
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
244
+ # # if os.path.exists(filename):
245
+ # # df_old = pd.read_excel(filename)
246
+ # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
247
+ # # else:
248
+ # # df_combined = df_new
249
+ # df_combined.to_excel(filename, index=False)
250
+ # # # Save summary
251
+ # # summary_df = pd.DataFrame({"Summary": [summary_text]})
252
+ # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
253
 
254
+ # # # Save flag
255
+ # # flag_df = pd.DataFrame({"Flag": [flag_text]})
256
+ # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
257
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
258
+ # df_new = pd.DataFrame(all_rows, columns=[
259
+ # "Sample ID", "Predicted Country", "Country Explanation",
260
+ # "Predicted Sample Type", "Sample Type Explanation",
261
+ # "Sources", "Time cost"
262
+ # ])
263
+
264
+ # try:
265
+ # if os.path.exists(filename):
266
+ # df_old = pd.read_excel(filename)
267
+ # else:
268
+ # df_old = pd.DataFrame(columns=df_new.columns)
269
+ # except Exception as e:
270
+ # print(f"⚠️ Warning reading old Excel file: {e}")
271
+ # df_old = pd.DataFrame(columns=df_new.columns)
272
+
273
+ # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
274
+ # df_old.set_index("Sample ID", inplace=True)
275
+ # df_new.set_index("Sample ID", inplace=True)
276
+
277
+ # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
278
+
279
+ # df_combined = df_old.reset_index()
280
+
281
+ # try:
282
+ # df_combined.to_excel(filename, index=False)
283
+ # except Exception as e:
284
+ # print(f"❌ Failed to write Excel file {filename}: {e}")
285
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
286
+ df_new = pd.DataFrame(all_rows, columns=[
287
+ "Sample ID", "Predicted Country", "Country Explanation",
288
+ "Predicted Sample Type", "Sample Type Explanation",
289
+ "Sources", "Time cost"
290
+ ])
291
+
292
+ if is_resume and os.path.exists(filename):
293
+ try:
294
+ df_old = pd.read_excel(filename)
295
+ except Exception as e:
296
+ print(f"⚠️ Warning reading old Excel file: {e}")
297
+ df_old = pd.DataFrame(columns=df_new.columns)
298
+
299
+ # Set index and update existing rows
300
+ df_old.set_index("Sample ID", inplace=True)
301
+ df_new.set_index("Sample ID", inplace=True)
302
+ df_old.update(df_new)
303
+
304
+ df_combined = df_old.reset_index()
305
+ else:
306
+ # If not resuming or file doesn't exist, just use new rows
307
+ df_combined = df_new
308
+
309
+ try:
310
+ df_combined.to_excel(filename, index=False)
311
+ except Exception as e:
312
+ print(f"❌ Failed to write Excel file {filename}: {e}")
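A note on the resume path: pandas.DataFrame.update only overwrites rows whose Sample ID already exists in the old file and silently ignores unknown index values, which matches a resume run (every accession there was taken from that same file); a fresh run simply rewrites the file with the new rows. Minimal sketch of the update behaviour:

    import pandas as pd
    old = pd.DataFrame({"Sample ID": ["KU131308"], "Predicted Country": [None]}).set_index("Sample ID")
    new = pd.DataFrame({"Sample ID": ["KU131308"], "Predicted Country": ["brunei"]}).set_index("Sample ID")
    old.update(new)              # fills the previously empty row in place
    print(old.reset_index())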
313
+
314
 
315
  # save the batch input in JSON file
316
  def save_to_json(all_rows, summary_text, flag_text, filename):
317
  output_dict = {
318
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
319
+ # "Summary_Text": summary_text,
320
+ # "Ancient_Modern_Flag": flag_text
321
  }
322
 
323
  # If all_rows is a DataFrame, convert it
 
339
  f.write("=== Detailed Results ===\n")
340
  f.write(output + "\n")
341
 
342
+ # f.write("\n=== Summary ===\n")
343
+ # f.write(summary_text + "\n")
344
 
345
+ # f.write("\n=== Ancient/Modern Flag ===\n")
346
+ # f.write(flag_text + "\n")
347
 
348
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
349
  tmp_dir = tempfile.mkdtemp()
350
 
351
  #html_table = all_rows.value # assuming this is stored somewhere
 
369
  return gr.update(visible=False) # invalid option
370
 
371
  return gr.update(value=file_path, visible=True)
372
+ # save cost by checking the known outputs
373
+
374
+ def check_known_output(accession, KNOWN_OUTPUT_PATH = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/known_samples.xlsx"):
375
+ if not os.path.exists(KNOWN_OUTPUT_PATH):
376
+ return None
377
+
378
+ try:
379
+ df = pd.read_excel(KNOWN_OUTPUT_PATH)
380
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
381
+ if match:
382
+ accession = match.group(0)
383
+
384
+ matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
385
+ if not matched.empty:
386
+ return matched.iloc[0].to_dict() # Return the cached row
387
+ except Exception as e:
388
+ print(f"⚠️ Failed to load known samples: {e}")
389
+ return None
390
+
391
+ USER_USAGE_TRACK_FILE = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/iterate3/user_usage_log.json"
392
+
393
+ def hash_user_id(user_input):
394
+ return hashlib.sha256(user_input.encode()).hexdigest()
395
+
396
+ # ✅ Load and save usage count
397
+
398
+ # def load_user_usage():
399
+ # if os.path.exists(USER_USAGE_TRACK_FILE):
400
+ # with open(USER_USAGE_TRACK_FILE, "r") as f:
401
+ # return json.load(f)
402
+ # return {}
403
+
404
+ def load_user_usage():
405
+ if not os.path.exists(USER_USAGE_TRACK_FILE):
406
+ return {}
407
+
408
+ try:
409
+ with open(USER_USAGE_TRACK_FILE, "r") as f:
410
+ content = f.read().strip()
411
+ if not content:
412
+ return {} # file is empty
413
+ return json.loads(content)
414
+ except (json.JSONDecodeError, ValueError):
415
+ print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
416
+ return {} # fallback to empty dict
417
+
418
+
419
+ def save_user_usage(usage):
420
+ with open(USER_USAGE_TRACK_FILE, "w") as f:
421
+ json.dump(usage, f, indent=2)
422
+
423
+ # def increment_usage(user_id, num_samples=1):
424
+ # usage = load_user_usage()
425
+ # if user_id not in usage:
426
+ # usage[user_id] = 0
427
+ # usage[user_id] += num_samples
428
+ # save_user_usage(usage)
429
+ # return usage[user_id]
430
+ def increment_usage(email: str, count: int):
431
+ usage = load_user_usage()
432
+ email_key = email.strip().lower()
433
+ usage[email_key] = usage.get(email_key, 0) + count
434
+ save_user_usage(usage)
435
+ return usage[email_key]
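As wired up in summarize_batch below, the caller hashes the e-mail with hash_user_id first, so the usage log ends up keyed by SHA-256 digests rather than raw addresses. A sketch of the resulting JSON (the address is hypothetical and the digest is shown as a placeholder):

    user_key = hash_user_id("researcher@example.org")   # 64-character hex digest
    increment_usage(user_key, 3)
    # user_usage_log.json then looks roughly like:
    # { "<sha256-hex-digest>": 3 }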
436
 
437
  # run the batch
438
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
439
+ stop_flag=None, output_file_path=None,
440
+ limited_acc=50, yield_callback=None):
441
+ if user_email:
442
+ limited_acc += 10
443
  accessions, error = extract_accessions_from_input(file, raw_text)
444
  if error:
445
+ #return [], "", "", f"Error: {error}"
446
+ return [], f"Error: {error}", 0, "", ""
447
+ if resume_file:
448
+ accessions = get_incomplete_accessions(resume_file)
449
+ tmp_dir = tempfile.mkdtemp()
450
+ if not output_file_path:
451
+ if resume_file:
452
+ output_file_path = os.path.join(tmp_dir, resume_file)
453
+ else:
454
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
455
 
456
  all_rows = []
457
+ # all_summaries = []
458
+ # all_flags = []
459
+ progress_lines = []
460
+ warning = ""
461
+ if len(accessions) > limited_acc:
462
+ accessions = accessions[:limited_acc]
463
+ warning = f"You submitted more than {limited_acc} accessions; only the first {limited_acc} will be processed."
464
+ for i, acc in enumerate(accessions):
465
+ if stop_flag and stop_flag.value:
466
+ line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
467
+ progress_lines.append(line)
468
+ if yield_callback:
469
+ yield_callback(line)
470
+ print("🛑 User requested stop.")
471
+ break
472
+ print(f"[{i+1}/{len(accessions)}] Processing {acc}")
473
  try:
474
+ # rows, summary, label, explain = summarize_results(acc)
475
+ rows = summarize_results(acc)
476
  all_rows.extend(rows)
477
+ # all_summaries.append(f"**{acc}**\n{summary}")
478
+ # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
479
+ #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
480
+ save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
481
+ line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
482
+ progress_lines.append(line)
483
+ if yield_callback:
484
+ yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
485
  except Exception as e:
486
+ print(f" Failed to process {acc}: {e}")
487
+ continue
488
+ #all_summaries.append(f"**{acc}**: Failed - {e}")
489
+ #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
490
+ limited_acc -= 1
491
  """for row in all_rows:
492
  source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
493
 
494
  if source_column.startswith("http"): # Check if the source is a URL
495
  # Wrap it with HTML anchor tags to make it clickable
496
  row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
497
+ if not warning:
498
+ warning = f"You have {limited_acc} accessions left."
499
+ if user_email.strip():
500
+ user_hash = hash_user_id(user_email)
501
+ total_queries = increment_usage(user_hash, len(all_rows))
502
+ else:
503
+ total_queries = 0
504
+ if yield_callback: yield_callback("✅ Finished!")
505
+
506
+ # summary_text = "\n\n---\n\n".join(all_summaries)
507
+ # flag_text = "\n\n---\n\n".join(all_flags)
508
+ #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
509
+ #return all_rows, gr.update(visible=True), gr.update(visible=False)
510
+ return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
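A minimal driver sketch for the batch entry point, assuming it is called from the Gradio layer; the stop-flag object, e-mail, and callback here are hypothetical stand-ins for whatever the UI actually passes in:

    import types

    stop_flag = types.SimpleNamespace(value=False)    # UI flips .value to True to stop early
    progress = []

    rows, out_path, total_queries, log, warning = summarize_batch(
        raw_text="MF362736.1, KU131308",
        user_email="researcher@example.org",          # hypothetical; a non-empty email raises the limit by 10
        stop_flag=stop_flag,
        limited_acc=50,
        yield_callback=progress.append,               # receives one status line per accession
    )
    print(out_path, warning)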
mtdna_classifier.py CHANGED
@@ -1,524 +1,707 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- import subprocess
5
- import re
6
- from Bio import Entrez
7
- import fitz
8
- import spacy
9
- from spacy.cli import download
10
- from NER.PDF import pdf
11
- from NER.WordDoc import wordDoc
12
- from NER.html import extractHTML
13
- from NER.word2Vec import word2vec
14
- from transformers import pipeline
15
- import urllib.parse, requests
16
- from pathlib import Path
17
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
18
- # Set your email (required by NCBI Entrez)
19
- #Entrez.email = "[email protected]"
20
- import nltk
21
-
22
- nltk.download("stopwords")
23
- #nltk.download("punkt")
24
- nltk.download('punkt', download_dir='/home/user/nltk_data')
25
-
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
-
29
- '''def get_info_from_accession(accession):
30
- cmd = f'{os.environ["HOME"]}/edirect/esummary -db nuccore -id {accession} -format medline | egrep "PUBMED|isolate"'
31
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
32
- output = result.stdout
33
- pubmedID, isolate = "", ""
34
- for line in output.split("\n"):
35
- if len(line) > 0:
36
- if "PUBMED" in line:
37
- pubmedID = line.split()[-1]
38
- if "isolate" in line: # Check for isolate information
39
- # Try direct GenBank annotation: /isolate="XXX"
40
- match1 = re.search(r'/isolate\s*=\s*"([^"]+)"', line) # search on current line
41
- if match1:
42
- isolate = match1.group(1)
43
- else:
44
- # Try from DEFINITION line: ...isolate XXX...
45
- match2 = re.search(r'isolate\s+([A-Za-z0-9_-]+)', line) # search on current line
46
- if match2:
47
- isolate = match2.group(1)'''
48
- from Bio import Entrez, Medline
49
- import re
50
-
51
- Entrez.email = "[email protected]"
52
-
53
- def get_info_from_accession(accession):
54
- try:
55
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
56
- text = handle.read()
57
- handle.close()
58
-
59
- # Extract PUBMED ID from the Medline text
60
- pubmed_match = re.search(r'PUBMED\s+(\d+)', text)
61
- pubmed_id = pubmed_match.group(1) if pubmed_match else ""
62
-
63
- # Extract isolate if available
64
- isolate_match = re.search(r'/isolate="([^"]+)"', text)
65
- if not isolate_match:
66
- isolate_match = re.search(r'isolate\s+([A-Za-z0-9_-]+)', text)
67
- isolate = isolate_match.group(1) if isolate_match else ""
68
-
69
- if not pubmed_id:
70
- print(f"⚠️ No PubMed ID found for accession {accession}")
71
-
72
- return pubmed_id, isolate
73
-
74
- except Exception as e:
75
- print(" Entrez error:", e)
76
- return "", ""
77
- # Step 2: Get doi link to access the paper
78
- '''def get_doi_from_pubmed_id(pubmed_id):
79
- cmd = f'{os.environ["HOME"]}/edirect/esummary -db pubmed -id {pubmed_id} -format medline | grep -i "AID"'
80
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
81
- output = result.stdout
82
-
83
- doi_pattern = r'10\.\d{4,9}/[-._;()/:A-Z0-9]+(?=\s*\[doi\])'
84
- match = re.search(doi_pattern, output, re.IGNORECASE)
85
-
86
- if match:
87
- return match.group(0)
88
- else:
89
- return None # or raise an Exception with a helpful message'''
90
-
91
- def get_doi_from_pubmed_id(pubmed_id):
92
- try:
93
- handle = Entrez.efetch(db="pubmed", id=pubmed_id, rettype="medline", retmode="text")
94
- records = list(Medline.parse(handle))
95
- handle.close()
96
-
97
- if not records:
98
- return None
99
-
100
- record = records[0]
101
- if "AID" in record:
102
- for aid in record["AID"]:
103
- if "[doi]" in aid:
104
- return aid.split(" ")[0] # extract the DOI
105
-
106
- return None
107
-
108
- except Exception as e:
109
- print(f"❌ Failed to get DOI from PubMed ID {pubmed_id}: {e}")
110
- return None
111
-
112
-
113
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
114
- # Step 3.1: Extract Text
115
- # sub: download excel file
116
- def download_excel_file(url, save_path="temp.xlsx"):
117
- if "view.officeapps.live.com" in url:
118
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
119
- real_url = urllib.parse.unquote(parsed_url["src"][0])
120
- response = requests.get(real_url)
121
- with open(save_path, "wb") as f:
122
- f.write(response.content)
123
- return save_path
124
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
125
- response = requests.get(url)
126
- response.raise_for_status() # Raises error if download fails
127
- with open(save_path, "wb") as f:
128
- f.write(response.content)
129
- return save_path
130
- else:
131
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
132
- return url
133
- def get_paper_text(doi,id,manualLinks=None):
134
- # create the temporary folder to contain the texts
135
- '''folder_path = Path("data/"+str(id))
136
- if not folder_path.exists():
137
- cmd = f'mkdir data/{id}'
138
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
139
- print("data/"+str(id) +" created.")
140
- else:
141
- print("data/"+str(id) +" already exists.")'''
142
-
143
- cmd = f'mkdir data/{id}'
144
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
145
- saveLinkFolder = "data/"+id
146
-
147
- link = 'https://doi.org/' + doi
148
- '''textsToExtract = { "doiLink":"paperText"
149
- "file1.pdf":"text1",
150
- "file2.doc":"text2",
151
- "file3.xlsx":excelText3'''
152
- textsToExtract = {}
153
- # get the file to create listOfFile for each id
154
- html = extractHTML.HTML("",link)
155
- jsonSM = html.getSupMaterial()
156
- text = ""
157
- links = [link] + sum((jsonSM[key] for key in jsonSM),[])
158
- if manualLinks != None:
159
- links += manualLinks
160
- for l in links:
161
- # get the main paper
162
- name = l.split("/")[-1]
163
- #file_path = folder_path / name
164
- if l == link:
165
- text = html.getListSection()
166
- textsToExtract[link] = text
167
- elif l.endswith(".pdf"):
168
- '''if file_path.is_file():
169
- l = saveLinkFolder + "/" + name
170
- print("File exists.")'''
171
- p = pdf.PDF(l,saveLinkFolder,doi)
172
- f = p.openPDFFile()
173
- pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
174
- doc = fitz.open(pdf_path)
175
- text = "\n".join([page.get_text() for page in doc])
176
- textsToExtract[l] = text
177
- elif l.endswith(".doc") or l.endswith(".docx"):
178
- d = wordDoc.wordDoc(l,saveLinkFolder)
179
- text = d.extractTextByPage()
180
- textsToExtract[l] = text
181
- elif l.split(".")[-1].lower() in "xlsx":
182
- wc = word2vec.word2Vec()
183
- # download excel file if it not downloaded yet
184
- savePath = saveLinkFolder +"/"+ l.split("/")[-1]
185
- excelPath = download_excel_file(l, savePath)
186
- corpus = wc.tableTransformToCorpusText([],excelPath)
187
- text = ''
188
- for c in corpus:
189
- para = corpus[c]
190
- for words in para:
191
- text += " ".join(words)
192
- textsToExtract[l] = text
193
- # delete folder after finishing getting text
194
- cmd = f'rm -r data/{id}'
195
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
196
- return textsToExtract
197
- # Step 3.2: Extract context
198
- def extract_context(text, keyword, window=500):
199
- # firstly try accession number
200
- idx = text.find(keyword)
201
- if idx == -1:
202
- return "Sample ID not found."
203
- return text[max(0, idx-window): idx+window]
204
- def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
205
- if keep_if is None:
206
- keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
207
-
208
- outputs = ""
209
- text = text.lower()
210
-
211
- # If isolate is provided, prioritize paragraphs that mention it
212
- # If isolate is provided, prioritize paragraphs that mention it
213
- if accession and accession.lower() in text:
214
- if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
215
- outputs += extract_context(text, accession.lower(), window=700)
216
- if isolate and isolate.lower() in text:
217
- if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
218
- outputs += extract_context(text, isolate.lower(), window=700)
219
- for keyword in keep_if:
220
- para = extract_context(text, keyword)
221
- if para and para not in outputs:
222
- outputs += para + "\n"
223
- return outputs
224
- # Step 4: Classification for now (demo purposes)
225
- # 4.1: Using a HuggingFace model (question-answering)
226
- def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
227
- try:
228
- qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
229
- result = qa({"context": context, "question": question})
230
- return result.get("answer", "Unknown")
231
- except Exception as e:
232
- return f"Error: {str(e)}"
233
-
234
- # 4.2: Infer from haplogroup
235
- # Load pre-trained spaCy model for NER
236
- try:
237
- nlp = spacy.load("en_core_web_sm")
238
- except OSError:
239
- download("en_core_web_sm")
240
- nlp = spacy.load("en_core_web_sm")
241
-
242
- # Define the haplogroup-to-region mapping (simple rule-based)
243
- import csv
244
-
245
- def load_haplogroup_mapping(csv_path):
246
- mapping = {}
247
- with open(csv_path) as f:
248
- reader = csv.DictReader(f)
249
- for row in reader:
250
- mapping[row["haplogroup"]] = [row["region"],row["source"]]
251
- return mapping
252
-
253
- # Function to extract haplogroup from the text
254
- def extract_haplogroup(text):
255
- match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
256
- if match:
257
- submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
258
- if submatch:
259
- return submatch.group(0)
260
- else:
261
- return match.group(1) # fallback
262
- fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
263
- if fallback:
264
- return fallback.group(1)
265
- return None
266
-
267
-
268
- # Function to extract location based on NER
269
- def extract_location(text):
270
- doc = nlp(text)
271
- locations = []
272
- for ent in doc.ents:
273
- if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
274
- locations.append(ent.text)
275
- return locations
276
-
277
- # Function to infer location from haplogroup
278
- def infer_location_from_haplogroup(haplogroup):
279
- haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
280
- return haplo_map.get(haplogroup, ["Unknown","Unknown"])
281
-
282
- # Function to classify the mtDNA sample
283
- def classify_mtDNA_sample_from_haplo(text):
284
- # Extract haplogroup
285
- haplogroup = extract_haplogroup(text)
286
- # Extract location based on NER
287
- locations = extract_location(text)
288
- # Infer location based on haplogroup
289
- inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
290
- return {
291
- "source":sourceHaplo,
292
- "locations_found_in_context": locations,
293
- "haplogroup": haplogroup,
294
- "inferred_location": inferred_location
295
-
296
- }
297
- # 4.3 Get from available NCBI
298
- def infer_location_fromNCBI(accession):
299
- try:
300
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
301
- text = handle.read()
302
- handle.close()
303
- match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
304
- if match:
305
- return match.group(2), match.group(0) # This is the value like "Brunei"
306
- return "Not found", "Not found"
307
-
308
- except Exception as e:
309
- print("❌ Entrez error:", e)
310
- return "Not found", "Not found"
311
-
312
- ### ANCIENT/MODERN FLAG
313
- from Bio import Entrez
314
- import re
315
-
316
- def flag_ancient_modern(accession, textsToExtract, isolate=None):
317
- """
318
- Try to classify a sample as Ancient or Modern using:
319
- 1. NCBI accession (if available)
320
- 2. Supplementary text or context fallback
321
- """
322
- context = ""
323
- label, explain = "", ""
324
-
325
- try:
326
- # Check if we can fetch metadata from NCBI using the accession
327
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
328
- text = handle.read()
329
- handle.close()
330
-
331
- isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
332
- if isolate_source:
333
- context += isolate_source.group(0) + " "
334
-
335
- specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
336
- if specimen:
337
- context += specimen.group(0) + " "
338
-
339
- if context.strip():
340
- label, explain = detect_ancient_flag(context)
341
- if label!="Unknown":
342
- return label, explain + " from NCBI\n(" + context + ")"
343
-
344
- # If no useful NCBI metadata, check supplementary texts
345
- if textsToExtract:
346
- labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
347
-
348
- for source in textsToExtract:
349
- text_block = textsToExtract[source]
350
- context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
351
- label, explain = detect_ancient_flag(context)
352
-
353
- if label == "Ancient":
354
- labels["ancient"][0] += 1
355
- labels["ancient"][1] += f"{source}:\n{explain}\n\n"
356
- elif label == "Modern":
357
- labels["modern"][0] += 1
358
- labels["modern"][1] += f"{source}:\n{explain}\n\n"
359
- else:
360
- labels["unknown"] += 1
361
-
362
- if max(labels["modern"][0],labels["ancient"][0]) > 0:
363
- if labels["modern"][0] > labels["ancient"][0]:
364
- return "Modern", labels["modern"][1]
365
- else:
366
- return "Ancient", labels["ancient"][1]
367
- else:
368
- return "Unknown", "No strong keywords detected"
369
- else:
370
- print("No DOI or PubMed ID available for inference.")
371
- return "", ""
372
-
373
- except Exception as e:
374
- print("Error:", e)
375
- return "", ""
376
-
377
-
378
- def detect_ancient_flag(context_snippet):
379
- context = context_snippet.lower()
380
-
381
- ancient_keywords = [
382
- "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
383
- "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
384
- "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
385
- "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
386
- ]
387
-
388
- modern_keywords = [
389
- "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
390
- "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
391
- "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
392
- "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
393
- "bioinformatic analysis", "samples from", "population genetics", "genome-wide data"
394
- ]
395
-
396
- ancient_hits = [k for k in ancient_keywords if k in context]
397
- modern_hits = [k for k in modern_keywords if k in context]
398
-
399
- if ancient_hits and not modern_hits:
400
- return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
401
- elif modern_hits and not ancient_hits:
402
- return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
403
- elif ancient_hits and modern_hits:
404
- if len(ancient_hits) >= len(modern_hits):
405
- return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
406
- else:
407
- return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
408
-
409
- # Fallback to QA
410
- answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
411
- if answer.startswith("Error"):
412
- return "Unknown", answer
413
- if "ancient" in answer.lower():
414
- return "Ancient", f"Leaning ancient based on QA: {answer}"
415
- elif "modern" in answer.lower():
416
- return "Modern", f"Leaning modern based on QA: {answer}"
417
- else:
418
- return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
419
-
420
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
421
- def classify_sample_location(accession):
422
- outputs = {}
423
- keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
424
- # Step 1: get pubmed id and isolate
425
- pubmedID, isolate = get_info_from_accession(accession)
426
- '''if not pubmedID:
427
- return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
428
- if not isolate:
429
- isolate = "UNKNOWN_ISOLATE"
430
- # Step 2: get doi
431
- doi = get_doi_from_pubmed_id(pubmedID)
432
- '''if not doi:
433
- return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
434
- # Step 3: get text
435
- '''textsToExtract = { "doiLink":"paperText"
436
- "file1.pdf":"text1",
437
- "file2.doc":"text2",
438
- "file3.xlsx":excelText3'''
439
- if doi and pubmedID:
440
- textsToExtract = get_paper_text(doi,pubmedID)
441
- else: textsToExtract = {}
442
- '''if not textsToExtract:
443
- return {"error": f"No texts extracted for DOI {doi}"}'''
444
- if isolate not in [None, "UNKNOWN_ISOLATE"]:
445
- label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
446
- else:
447
- label, explain = flag_ancient_modern(accession,textsToExtract)
448
- # Step 4: prediction
449
- outputs[accession] = {}
450
- outputs[isolate] = {}
451
- # 4.0 Infer from NCBI
452
- location, outputNCBI = infer_location_fromNCBI(accession)
453
- NCBI_result = {
454
- "source": "NCBI",
455
- "sample_id": accession,
456
- "predicted_location": location,
457
- "context_snippet": outputNCBI}
458
- outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
459
- if textsToExtract:
460
- long_text = ""
461
- for key in textsToExtract:
462
- text = textsToExtract[key]
463
- # try accession number first
464
- outputs[accession][key] = {}
465
- keyword = accession
466
- context = extract_context(text, keyword, window=500)
467
- # 4.1: Using a HuggingFace model (question-answering)
468
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
469
- qa_result = {
470
- "source": key,
471
- "sample_id": keyword,
472
- "predicted_location": location,
473
- "context_snippet": context
474
- }
475
- outputs[keyword][key]["QAModel"] = qa_result
476
- # 4.2: Infer from haplogroup
477
- haplo_result = classify_mtDNA_sample_from_haplo(context)
478
- outputs[keyword][key]["haplogroup"] = haplo_result
479
- # try isolate
480
- keyword = isolate
481
- outputs[isolate][key] = {}
482
- context = extract_context(text, keyword, window=500)
483
- # 4.1.1: Using a HuggingFace model (question-answering)
484
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
485
- qa_result = {
486
- "source": key,
487
- "sample_id": keyword,
488
- "predicted_location": location,
489
- "context_snippet": context
490
- }
491
- outputs[keyword][key]["QAModel"] = qa_result
492
- # 4.2.1: Infer from haplogroup
493
- haplo_result = classify_mtDNA_sample_from_haplo(context)
494
- outputs[keyword][key]["haplogroup"] = haplo_result
495
- # add long text
496
- long_text += text + ". \n"
497
- # 4.3: UpgradeClassify
498
- # try sample_id as accession number
499
- sample_id = accession
500
- if sample_id:
501
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
502
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
503
- if locations!="No clear location found in top matches":
504
- outputs[sample_id]["upgradeClassifier"] = {}
505
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
506
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
507
- "sample_id": sample_id,
508
- "predicted_location": ", ".join(locations),
509
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
510
- }
511
- # try sample_id as isolate name
512
- sample_id = isolate
513
- if sample_id:
514
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
515
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
516
- if locations!="No clear location found in top matches":
517
- outputs[sample_id]["upgradeClassifier"] = {}
518
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
519
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
520
- "sample_id": sample_id,
521
- "predicted_location": ", ".join(locations),
522
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
523
- }
524
- return outputs, label, explain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mtDNA Location Classifier MVP (Google Colab)
2
+ # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
+ import os
4
+ #import streamlit as st
5
+ import subprocess
6
+ import re
7
+ from Bio import Entrez
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
+
20
+ # Set your email (required by NCBI Entrez)
21
+ #Entrez.email = "[email protected]"
22
+ import nltk
23
+
24
+ nltk.download("stopwords")
25
+ nltk.download("punkt")
26
+ nltk.download('punkt_tab')
27
+ # Step 1: Get PubMed ID from Accession using EDirect
28
+ from Bio import Entrez, Medline
29
+ import re
30
+
31
+ Entrez.email = "your_email@example.com"
32
+
33
+ # --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
+ def fetch_ncbi_metadata(accession_number):
36
+ """
37
+ Fetches metadata directly from NCBI GenBank using Entrez.
38
+ Includes robust error handling and improved field extraction.
39
+ Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
+ Also attempts to extract ethnicity and sample_type (ancient/modern).
41
+
42
+ Args:
43
+ accession_number (str): The NCBI accession number (e.g., "ON792208").
44
+
45
+ Returns:
46
+ dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
+ 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
+ """
49
+ Entrez.email = "[email protected]" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
+
51
+ country = "unknown"
52
+ specific_location = "unknown"
53
+ ethnicity = "unknown"
54
+ sample_type = "unknown"
55
+ collection_date = "unknown"
56
+ isolate = "unknown"
57
+ title = "unknown"
58
+ doi = "unknown"
59
+ pubmed_id = None
60
+ all_feature = "unknown"
61
+
62
+ KNOWN_COUNTRIES = [
63
+ "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
+ "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
+ "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
+ "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
+ "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
+ "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
+ "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
+ "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
+ "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
+ "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
+ "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
+ "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
+ "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
+ "Yemen", "Zambia", "Zimbabwe"
77
+ ]
78
+ COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
+
80
+ try:
81
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
+ record = Entrez.read(handle)
83
+ handle.close()
84
+
85
+ gb_seq = None
86
+ # Validate record structure: It should be a list with at least one element (a dict)
87
+ if isinstance(record, list) and len(record) > 0:
88
+ if isinstance(record[0], dict):
89
+ gb_seq = record[0]
90
+ else:
91
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
+ else:
93
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
+
95
+ # If gb_seq is still None, return defaults
96
+ if gb_seq is None:
97
+ return {"country": "unknown", "specific_location": "unknown", "ethnicity": "unknown",
98
+ "sample_type": "unknown", "collection_date": "unknown", "isolate": "unknown",
99
+ "title": "unknown", "doi": "unknown", "pubmed_id": None}
100
+
101
+
102
+ # If gb_seq is valid, proceed with extraction
103
+ collection_date = gb_seq.get("GBSeq_create-date","unknown")
104
+
105
+ references = gb_seq.get("GBSeq_references", [])
106
+ for ref in references:
107
+ if not pubmed_id:
108
+ pubmed_id = ref.get("GBReference_pubmed",None)
109
+ if title == "unknown":
110
+ title = ref.get("GBReference_title","unknown")
111
+ for xref in ref.get("GBReference_xref", []):
112
+ if xref.get("GBXref_dbname") == "doi":
113
+ doi = xref.get("GBXref_id")
114
+ break
115
+
116
+ features = gb_seq.get("GBSeq_feature-table", [])
117
+
118
+ context_for_flagging = "" # Accumulate text for ancient/modern detection
119
+ features_context = ""
120
+ for feature in features:
121
+ if feature.get("GBFeature_key") == "source":
122
+ feature_context = ""
123
+ qualifiers = feature.get("GBFeature_quals", [])
124
+ found_country = "unknown"
125
+ found_specific_location = "unknown"
126
+ found_ethnicity = "unknown"
127
+
128
+ temp_geo_loc_name = "unknown"
129
+ temp_note_origin_locality = "unknown"
130
+ temp_country_qual = "unknown"
131
+ temp_locality_qual = "unknown"
132
+ temp_collection_location_qual = "unknown"
133
+ temp_isolation_source_qual = "unknown"
134
+ temp_env_sample_qual = "unknown"
135
+ temp_pop_qual = "unknown"
136
+ temp_organism_qual = "unknown"
137
+ temp_specimen_qual = "unknown"
138
+ temp_strain_qual = "unknown"
139
+
140
+ for qual in qualifiers:
141
+ qual_name = qual.get("GBQualifier_name")
142
+ qual_value = qual.get("GBQualifier_value")
143
+ feature_context += qual_name + ": " + qual_value +"\n"
144
+ if qual_name == "collection_date":
145
+ collection_date = qual_value
146
+ elif qual_name == "isolate":
147
+ isolate = qual_value
148
+ elif qual_name == "population":
149
+ temp_pop_qual = qual_value
150
+ elif qual_name == "organism":
151
+ temp_organism_qual = qual_value
152
+ elif qual_name == "specimen_voucher" or qual_name == "specimen":
153
+ temp_specimen_qual = qual_value
154
+ elif qual_name == "strain":
155
+ temp_strain_qual = qual_value
156
+ elif qual_name == "isolation_source":
157
+ temp_isolation_source_qual = qual_value
158
+ elif qual_name == "environmental_sample":
159
+ temp_env_sample_qual = qual_value
160
+
161
+ if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
162
+ elif qual_name == "note":
163
+ if qual_value.startswith("origin_locality:"):
164
+ temp_note_origin_locality = qual_value
165
+ context_for_flagging += qual_value + " " # Capture all notes for flagging
166
+ elif qual_name == "country": temp_country_qual = qual_value
167
+ elif qual_name == "locality": temp_locality_qual = qual_value
168
+ elif qual_name == "collection_location": temp_collection_location_qual = qual_value
169
+
170
+
171
+ # --- Aggregate all relevant info into context_for_flagging ---
172
+ context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
173
+ context_for_flagging = context_for_flagging.strip()
174
+
175
+ # --- Determine final country and specific_location based on priority ---
176
+ if temp_geo_loc_name != "unknown":
177
+ parts = [p.strip() for p in temp_geo_loc_name.split(':')]
178
+ if len(parts) > 1:
179
+ found_specific_location = parts[-1]; found_country = parts[0]
180
+ else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
181
+ elif temp_note_origin_locality != "unknown":
182
+ match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
183
+ if match:
184
+ location_string = match.group(1).strip()
185
+ parts = [p.strip() for p in location_string.split(':')]
186
+ if len(parts) > 1: found_country = parts[-1]; found_specific_location = parts[0]
187
+ else: found_country = location_string; found_specific_location = "unknown"
188
+ elif temp_locality_qual != "unknown":
189
+ found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
190
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
191
+ else: found_specific_location = temp_locality_qual; found_country = "unknown"
192
+ elif temp_collection_location_qual != "unknown":
193
+ found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
194
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
195
+ else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
196
+ elif temp_isolation_source_qual != "unknown":
197
+ found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
198
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
199
+ else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
200
+ elif temp_env_sample_qual != "unknown":
201
+ found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
202
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
203
+ else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
204
+ if found_country == "unknown" and temp_country_qual != "unknown":
205
+ found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
206
+ if found_country_match: found_country = found_country_match.group(1)
207
+
208
+ country = found_country
209
+ specific_location = found_specific_location
210
+ # --- Determine final ethnicity ---
211
+ if temp_pop_qual != "unknown":
212
+ found_ethnicity = temp_pop_qual
213
+ elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
214
+ found_ethnicity = isolate
215
+ elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
216
+ eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
217
+ if eth_match:
218
+ found_ethnicity = eth_match.group(1).strip()
219
+
220
+ ethnicity = found_ethnicity
221
+
222
+ # --- Determine sample_type (ancient/modern) ---
223
+ if context_for_flagging:
224
+ sample_type, explain = detect_ancient_flag(context_for_flagging)
225
+ features_context += feature_context + "\n"
226
+ break
227
+
228
+ if specific_location != "unknown" and specific_location.lower() == country.lower():
229
+ specific_location = "unknown"
230
+ if not features_context: features_context = "unknown"
231
+ return {"country": country.lower(),
232
+ "specific_location": specific_location.lower(),
233
+ "ethnicity": ethnicity.lower(),
234
+ "sample_type": sample_type.lower(),
235
+ "collection_date": collection_date,
236
+ "isolate": isolate,
237
+ "title": title,
238
+ "doi": doi,
239
+ "pubmed_id": pubmed_id,
240
+ "all_features": features_context}
241
+
242
+ except Exception as e:
243
+ print(f"Error fetching NCBI data for {accession_number}: {e}")
244
+ return {"country": "unknown",
245
+ "specific_location": "unknown",
246
+ "ethnicity": "unknown",
247
+ "sample_type": "unknown",
248
+ "collection_date": "unknown",
249
+ "isolate": "unknown",
250
+ "title": "unknown",
251
+ "doi": "unknown",
252
+ "pubmed_id": None,
253
+ "all_features": "unknown"}
254
+
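A minimal usage sketch for fetch_ncbi_metadata, assuming Biopython is installed, NCBI is reachable, and Entrez.email has been set to a real address; the accession is the one from the docstring.

from mtdna_classifier import fetch_ncbi_metadata

# Fetch GenBank metadata for a single accession and inspect a few fields.
meta = fetch_ncbi_metadata("ON792208")
print(meta["country"], meta["specific_location"], meta["sample_type"])
print("PubMed:", meta["pubmed_id"], "DOI:", meta["doi"])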
255
+ # --- Helper function for country matching (re-defined from main code to be self-contained) ---
256
+ _country_keywords = {
257
+ "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
258
+ "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
259
+ "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
260
+ "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
261
+ "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
262
+ "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
263
+ "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
264
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
265
+ "central india": "India", "east india": "India", "northeast india": "India",
266
+ "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
267
+ "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
268
+ }
269
+
270
+ def get_country_from_text(text):
271
+ text_lower = text.lower()
272
+ for keyword, country in _country_keywords.items():
273
+ if keyword in text_lower:
274
+ return country
275
+ return "unknown"
276
+ # The result will be seen as manualLink for the function get_paper_text
277
+ def search_google_custom(query, max_results=3):
278
+ # query should be the title from ncbi or paper/source title
279
+ GOOGLE_CSE_API_KEY = "AIzaSyAg_Hi5DPit2bvvwCs1PpUkAPRZun7yCRQ"
280
+ GOOGLE_CSE_CX = "25a51c433f148490c"
281
+ endpoint = "https://www.googleapis.com/customsearch/v1"
282
+ params = {
283
+ "key": GOOGLE_CSE_API_KEY,
284
+ "cx": GOOGLE_CSE_CX,
285
+ "q": query,
286
+ "num": max_results
287
+ }
288
+ try:
289
+ response = requests.get(endpoint, params=params)
290
+ if response.status_code == 429:
291
+ print("Rate limit hit. Try again later.")
292
+ return []
293
+ response.raise_for_status()
294
+ data = response.json().get("items", [])
295
+ return [item.get("link") for item in data if item.get("link")]
296
+ except Exception as e:
297
+ print("Google CSE error:", e)
298
+ return []
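A hedged usage sketch for search_google_custom; results depend on the Custom Search key, engine ID, and remaining quota configured above. The query title is the one referenced in pipeline.py.

from mtdna_classifier import search_google_custom

links = search_google_custom("South Asian maternal and paternal lineages in southern Thailand", max_results=3)
for url in links:   # empty list on quota errors or no hits
    print(url)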
299
+ # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
300
+ # Step 3.1: Extract Text
301
+ # sub: download excel file
302
+ def download_excel_file(url, save_path="temp.xlsx"):
303
+ if "view.officeapps.live.com" in url:
304
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
305
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
306
+ response = requests.get(real_url)
307
+ with open(save_path, "wb") as f:
308
+ f.write(response.content)
309
+ return save_path
310
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
311
+ response = requests.get(url)
312
+ response.raise_for_status() # Raises error if download fails
313
+ with open(save_path, "wb") as f:
314
+ f.write(response.content)
315
+ return save_path
316
+ else:
317
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
318
+ return url
319
+ def get_paper_text(doi,id,manualLinks=None):
320
+ # create the temporary folder to contain the texts
321
+ folder_path = Path("data/"+str(id))
322
+ if not folder_path.exists():
323
+ cmd = f'mkdir data/{id}'
324
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
325
+ print("data/"+str(id) +" created.")
326
+ else:
327
+ print("data/"+str(id) +" already exists.")
328
+ saveLinkFolder = "data/"+id
329
+
330
+ link = 'https://doi.org/' + doi
331
+ '''textsToExtract = { "doiLink":"paperText"
332
+ "file1.pdf":"text1",
333
+ "file2.doc":"text2",
334
+ "file3.xlsx":excelText3'''
335
+ textsToExtract = {}
336
+ # get the file to create listOfFile for each id
337
+ html = extractHTML.HTML("",link)
338
+ jsonSM = html.getSupMaterial()
339
+ text = ""
340
+ links = [link] + sum((jsonSM[key] for key in jsonSM),[])
341
+ if manualLinks != None:
342
+ links += manualLinks
343
+ for l in links:
344
+ # get the main paper
345
+ name = l.split("/")[-1]
346
+ file_path = folder_path / name
347
+ if l == link:
348
+ text = html.getListSection()
349
+ textsToExtract[link] = text
350
+ elif l.endswith(".pdf"):
351
+ if file_path.is_file():
352
+ l = saveLinkFolder + "/" + name
353
+ print("File exists.")
354
+ p = pdf.PDF(l,saveLinkFolder,doi)
355
+ f = p.openPDFFile()
356
+ pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
357
+ doc = fitz.open(pdf_path)
358
+ text = "\n".join([page.get_text() for page in doc])
359
+ textsToExtract[l] = text
360
+ elif l.endswith(".doc") or l.endswith(".docx"):
361
+ d = wordDoc.wordDoc(l,saveLinkFolder)
362
+ text = d.extractTextByPage()
363
+ textsToExtract[l] = text
364
+ elif l.split(".")[-1].lower() in "xlsx":
365
+ wc = word2vec.word2Vec()
366
+ # download excel file if it not downloaded yet
367
+ savePath = saveLinkFolder +"/"+ l.split("/")[-1]
368
+ excelPath = download_excel_file(l, savePath)
369
+ corpus = wc.tableTransformToCorpusText([],excelPath)
370
+ text = ''
371
+ for c in corpus:
372
+ para = corpus[c]
373
+ for words in para:
374
+ text += " ".join(words)
375
+ textsToExtract[l] = text
376
+ # delete folder after finishing getting text
377
+ #cmd = f'rm -r data/{id}'
378
+ #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
379
+ return textsToExtract
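A sketch of calling get_paper_text; the DOI and PubMed ID below are placeholders, and the call needs network access plus the NER helpers imported at the top of this file. The returned dict maps each source link or file to its extracted text.

from mtdna_classifier import get_paper_text

texts = get_paper_text(doi="10.1000/example-doi", id="12345678")  # placeholder values
for source, text in texts.items():
    print(source, "->", len(text), "characters")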
380
+ # Step 3.2: Extract context
381
+ def extract_context(text, keyword, window=500):
382
+ # firstly try accession number
383
+ idx = text.find(keyword)
384
+ if idx == -1:
385
+ return "Sample ID not found."
386
+ return text[max(0, idx-window): idx+window]
387
+ def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
388
+ if keep_if is None:
389
+ keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
390
+
391
+ outputs = ""
392
+ text = text.lower()
393
+
394
+ # If isolate is provided, prioritize paragraphs that mention it
396
+ if accession and accession.lower() in text:
397
+ if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
398
+ outputs += extract_context(text, accession.lower(), window=700)
399
+ if isolate and isolate.lower() in text:
400
+ if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
401
+ outputs += extract_context(text, isolate.lower(), window=700)
402
+ for keyword in keep_if:
403
+ para = extract_context(text, keyword)
404
+ if para and para not in outputs:
405
+ outputs += para + "\n"
406
+ return outputs
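A tiny illustration of the character-window extraction above; the toy sentence is invented.

from mtdna_classifier import extract_context

toy = "The mtDNA sample ON792208 was collected in northern Thailand from a healthy donor."
print(extract_context(toy, "ON792208", window=40))
print(extract_context(toy, "NOT_PRESENT"))  # -> "Sample ID not found."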
407
+ # Step 4: Classification for now (demo purposes)
408
+ # 4.1: Using a HuggingFace model (question-answering)
409
+ def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
410
+ try:
411
+ qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
412
+ result = qa({"context": context, "question": question})
413
+ return result.get("answer", "Unknown")
414
+ except Exception as e:
415
+ return f"Error: {str(e)}"
416
+
417
+ # 4.2: Infer from haplogroup
418
+ # Load pre-trained spaCy model for NER
419
+ try:
420
+ nlp = spacy.load("en_core_web_sm")
421
+ except OSError:
422
+ download("en_core_web_sm")
423
+ nlp = spacy.load("en_core_web_sm")
424
+
425
+ # Define the haplogroup-to-region mapping (simple rule-based)
426
+ import csv
427
+
428
+ def load_haplogroup_mapping(csv_path):
429
+ mapping = {}
430
+ with open(csv_path) as f:
431
+ reader = csv.DictReader(f)
432
+ for row in reader:
433
+ mapping[row["haplogroup"]] = [row["region"],row["source"]]
434
+ return mapping
435
+
436
+ # Function to extract haplogroup from the text
437
+ def extract_haplogroup(text):
438
+ match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
439
+ if match:
440
+ submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
441
+ if submatch:
442
+ return submatch.group(0)
443
+ else:
444
+ return match.group(1) # fallback
445
+ fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
446
+ if fallback:
447
+ return fallback.group(1)
448
+ return None
449
+
450
+
451
+ # Function to extract location based on NER
452
+ def extract_location(text):
453
+ doc = nlp(text)
454
+ locations = []
455
+ for ent in doc.ents:
456
+ if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
457
+ locations.append(ent.text)
458
+ return locations
459
+
460
+ # Function to infer location from haplogroup
461
+ def infer_location_from_haplogroup(haplogroup):
462
+ haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
463
+ return haplo_map.get(haplogroup, ["Unknown","Unknown"])
464
+
465
+ # Function to classify the mtDNA sample
466
+ def classify_mtDNA_sample_from_haplo(text):
467
+ # Extract haplogroup
468
+ haplogroup = extract_haplogroup(text)
469
+ # Extract location based on NER
470
+ locations = extract_location(text)
471
+ # Infer location based on haplogroup
472
+ inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
473
+ return {
474
+ "source":sourceHaplo,
475
+ "locations_found_in_context": locations,
476
+ "haplogroup": haplogroup,
477
+ "inferred_location": inferred_location
478
+
479
+ }
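A sketch of the haplogroup route; it assumes en_core_web_sm is installed and data/haplogroup_regions_extended.csv is present, and the inferred region falls back to "Unknown" if the haplogroup is not in that CSV. The sentence is invented.

from mtdna_classifier import classify_mtDNA_sample_from_haplo

result = classify_mtDNA_sample_from_haplo("The individual carried haplogroup B4 and was sampled in Taiwan.")
print(result["haplogroup"], result["inferred_location"], result["locations_found_in_context"])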
480
+ # 4.3 Get from available NCBI
481
+ def infer_location_fromNCBI(accession):
482
+ try:
483
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
484
+ text = handle.read()
485
+ handle.close()
486
+ match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
487
+ if match:
488
+ return match.group(2), match.group(0) # This is the value like "Brunei"
489
+ return "Not found", "Not found"
490
+
491
+ except Exception as e:
492
+ print("❌ Entrez error:", e)
493
+ return "Not found", "Not found"
494
+
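A minimal sketch for infer_location_fromNCBI, assuming network access; it fetches the MEDLINE-format record and greps the geo_loc_name/country/location qualifier.

from mtdna_classifier import infer_location_fromNCBI

location, raw_qualifier = infer_location_fromNCBI("ON792208")
print(location)       # a country/region string, or "Not found"
print(raw_qualifier)  # the matched qualifier line, e.g. /geo_loc_name="..."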
495
+ ### ANCIENT/MODERN FLAG
496
+ from Bio import Entrez
497
+ import re
498
+
499
+ def flag_ancient_modern(accession, textsToExtract, isolate=None):
500
+ """
501
+ Try to classify a sample as Ancient or Modern using:
502
+ 1. NCBI accession (if available)
503
+ 2. Supplementary text or context fallback
504
+ """
505
+ context = ""
506
+ label, explain = "", ""
507
+
508
+ try:
509
+ # Check if we can fetch metadata from NCBI using the accession
510
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
511
+ text = handle.read()
512
+ handle.close()
513
+
514
+ isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
515
+ if isolate_source:
516
+ context += isolate_source.group(0) + " "
517
+
518
+ specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
519
+ if specimen:
520
+ context += specimen.group(0) + " "
521
+
522
+ if context.strip():
523
+ label, explain = detect_ancient_flag(context)
524
+ if label!="Unknown":
525
+ return label, explain + " from NCBI\n(" + context + ")"
526
+
527
+ # If no useful NCBI metadata, check supplementary texts
528
+ if textsToExtract:
529
+ labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
530
+
531
+ for source in textsToExtract:
532
+ text_block = textsToExtract[source]
533
+ context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
534
+ label, explain = detect_ancient_flag(context)
535
+
536
+ if label == "Ancient":
537
+ labels["ancient"][0] += 1
538
+ labels["ancient"][1] += f"{source}:\n{explain}\n\n"
539
+ elif label == "Modern":
540
+ labels["modern"][0] += 1
541
+ labels["modern"][1] += f"{source}:\n{explain}\n\n"
542
+ else:
543
+ labels["unknown"] += 1
544
+
545
+ if max(labels["modern"][0],labels["ancient"][0]) > 0:
546
+ if labels["modern"][0] > labels["ancient"][0]:
547
+ return "Modern", labels["modern"][1]
548
+ else:
549
+ return "Ancient", labels["ancient"][1]
550
+ else:
551
+ return "Unknown", "No strong keywords detected"
552
+ else:
553
+ print("No DOI or PubMed ID available for inference.")
554
+ return "", ""
555
+
556
+ except Exception as e:
557
+ print("Error:", e)
558
+ return "", ""
559
+
560
+
561
+ def detect_ancient_flag(context_snippet):
562
+ context = context_snippet.lower()
563
+
564
+ ancient_keywords = [
565
+ "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
566
+ "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
567
+ "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
568
+ "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
569
+ ]
570
+
571
+ modern_keywords = [
572
+ "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
573
+ "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
574
+ "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
575
+ "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
576
+ "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
577
+ ]
578
+
579
+ ancient_hits = [k for k in ancient_keywords if k in context]
580
+ modern_hits = [k for k in modern_keywords if k in context]
581
+
582
+ if ancient_hits and not modern_hits:
583
+ return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
584
+ elif modern_hits and not ancient_hits:
585
+ return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
586
+ elif ancient_hits and modern_hits:
587
+ if len(ancient_hits) >= len(modern_hits):
588
+ return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
589
+ else:
590
+ return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
591
+
592
+ # Fallback to QA
593
+ answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
594
+ if answer.startswith("Error"):
595
+ return "Unknown", answer
596
+ if "ancient" in answer.lower():
597
+ return "Ancient", f"Leaning ancient based on QA: {answer}"
598
+ elif "modern" in answer.lower():
599
+ return "Modern", f"Leaning modern based on QA: {answer}"
600
+ else:
601
+ return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
602
+
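Two invented snippets showing the keyword heuristic above; neither should reach the QA fallback.

from mtdna_classifier import detect_ancient_flag

label, why = detect_ancient_flag("Teeth were recovered from a Bronze Age burial and radiocarbon dated.")
print(label, "-", why)   # expected: Ancient, with the matched keywords listed

label, why = detect_ancient_flag("Buccal swabs were collected from healthy donors with informed consent.")
print(label, "-", why)   # expected: Modern, with the matched keywords listed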
603
+ # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
604
+ def classify_sample_location(accession):
605
+ outputs = {}
606
+ keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
607
+ # Step 1: get pubmed id and isolate
608
+ pubmedID, isolate = get_info_from_accession(accession)
609
+ '''if not pubmedID:
610
+ return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
611
+ if not isolate:
612
+ isolate = "UNKNOWN_ISOLATE"
613
+ # Step 2: get doi
614
+ doi = get_doi_from_pubmed_id(pubmedID)
615
+ '''if not doi:
616
+ return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
617
+ # Step 3: get text
618
+ '''textsToExtract = { "doiLink":"paperText"
619
+ "file1.pdf":"text1",
620
+ "file2.doc":"text2",
621
+ "file3.xlsx":excelText3'''
622
+ if doi and pubmedID:
623
+ textsToExtract = get_paper_text(doi,pubmedID)
624
+ else: textsToExtract = {}
625
+ '''if not textsToExtract:
626
+ return {"error": f"No texts extracted for DOI {doi}"}'''
627
+ if isolate not in [None, "UNKNOWN_ISOLATE"]:
628
+ label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
629
+ else:
630
+ label, explain = flag_ancient_modern(accession,textsToExtract)
631
+ # Step 4: prediction
632
+ outputs[accession] = {}
633
+ outputs[isolate] = {}
634
+ # 4.0 Infer from NCBI
635
+ location, outputNCBI = infer_location_fromNCBI(accession)
636
+ NCBI_result = {
637
+ "source": "NCBI",
638
+ "sample_id": accession,
639
+ "predicted_location": location,
640
+ "context_snippet": outputNCBI}
641
+ outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
642
+ if textsToExtract:
643
+ long_text = ""
644
+ for key in textsToExtract:
645
+ text = textsToExtract[key]
646
+ # try accession number first
647
+ outputs[accession][key] = {}
648
+ keyword = accession
649
+ context = extract_context(text, keyword, window=500)
650
+ # 4.1: Using a HuggingFace model (question-answering)
651
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
652
+ qa_result = {
653
+ "source": key,
654
+ "sample_id": keyword,
655
+ "predicted_location": location,
656
+ "context_snippet": context
657
+ }
658
+ outputs[keyword][key]["QAModel"] = qa_result
659
+ # 4.2: Infer from haplogroup
660
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
661
+ outputs[keyword][key]["haplogroup"] = haplo_result
662
+ # try isolate
663
+ keyword = isolate
664
+ outputs[isolate][key] = {}
665
+ context = extract_context(text, keyword, window=500)
666
+ # 4.1.1: Using a HuggingFace model (question-answering)
667
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
668
+ qa_result = {
669
+ "source": key,
670
+ "sample_id": keyword,
671
+ "predicted_location": location,
672
+ "context_snippet": context
673
+ }
674
+ outputs[keyword][key]["QAModel"] = qa_result
675
+ # 4.2.1: Infer from haplogroup
676
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
677
+ outputs[keyword][key]["haplogroup"] = haplo_result
678
+ # add long text
679
+ long_text += text + ". \n"
680
+ # 4.3: UpgradeClassify
681
+ # try sample_id as accession number
682
+ sample_id = accession
683
+ if sample_id:
684
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
685
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
686
+ if locations!="No clear location found in top matches":
687
+ outputs[sample_id]["upgradeClassifier"] = {}
688
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
689
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
690
+ "sample_id": sample_id,
691
+ "predicted_location": ", ".join(locations),
692
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
693
+ }
694
+ # try sample_id as isolate name
695
+ sample_id = isolate
696
+ if sample_id:
697
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
698
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
699
+ if locations!="No clear location found in top matches":
700
+ outputs[sample_id]["upgradeClassifier"] = {}
701
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
702
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
703
+ "sample_id": sample_id,
704
+ "predicted_location": ", ".join(locations),
705
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
706
+ }
707
+ return outputs, label, explain
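An end-to-end sketch for the pipeline entry point above. It assumes the helper lookups it calls (get_info_from_accession, get_doi_from_pubmed_id) are available in the module, plus network access and the NER/spaCy dependencies.

from mtdna_classifier import classify_sample_location

outputs, label, explain = classify_sample_location("ON792208")
print(label)   # Ancient / Modern / Unknown, or "" if nothing could be inferred
for sample_id, per_source in outputs.items():
    print(sample_id, "->", list(per_source.keys()))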
pipeline.py ADDED
@@ -0,0 +1,347 @@
1
+ # test1: MJ17 direct
2
+ # test2: "A1YU101" thailand cross-ref
3
+ # test3: "EBK109" thailand cross-ref
4
+ # test4: "OQ731952"/"BST115" for search query title: "South Asian maternal and paternal lineages in southern Thailand and"
5
+ from iterate3 import data_preprocess, model
6
+ import mtdna_classifier
7
+ import app
8
+ import pandas as pd
9
+ from pathlib import Path
10
+ import subprocess
11
+ from NER.html import extractHTML
12
+ import os
13
+ import google.generativeai as genai
14
+ import re
15
+ import standardize_location
16
+ # Helper functions in for this pipeline
17
+ # Track time
18
+ import time
19
+ import multiprocessing
20
+
21
+ def run_with_timeout(func, args=(), kwargs={}, timeout=20):
22
+ """
23
+ Runs `func` with timeout in seconds. Kills if it exceeds.
24
+ Returns: (success, result or None)
25
+ """
26
+ def wrapper(q, *args, **kwargs):
27
+ try:
28
+ q.put(func(*args, **kwargs))
29
+ except Exception as e:
30
+ q.put(e)
31
+
32
+ q = multiprocessing.Queue()
33
+ p = multiprocessing.Process(target=wrapper, args=(q, *args), kwargs=kwargs)
34
+ p.start()
35
+ p.join(timeout)
36
+
37
+ if p.is_alive():
38
+ p.terminate()
39
+ p.join()
40
+ print(f"⏱️ Timeout exceeded ({timeout} sec) — function killed.")
41
+ return False, None
42
+ else:
43
+ result = q.get()
44
+ if isinstance(result, Exception):
45
+ raise result
46
+ return True, result
47
+
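A usage sketch for run_with_timeout. It assumes pipeline.py is importable with its dependencies and a fork-based multiprocessing start method (Linux, as in Colab/Spaces), since the wrapper closure cannot be pickled under spawn.

import time
from pipeline import run_with_timeout

def slow_fetch(url):
    time.sleep(60)   # stands in for a network call that may hang
    return url

if __name__ == "__main__":
    ok, result = run_with_timeout(slow_fetch, args=("https://example.org",), timeout=2)
    print(ok, result)  # expected: False None; the worker process is terminated after 2 s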
48
+ def time_it(func, *args, **kwargs):
49
+ """
50
+ Measure how long a function takes to run and return its result + time.
51
+ """
52
+ start = time.time()
53
+ result = func(*args, **kwargs)
54
+ end = time.time()
55
+ elapsed = end - start
56
+ print(f"⏱️ '{func.__name__}' took {elapsed:.3f} seconds")
57
+ return result, elapsed
58
+ # --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
59
+ def track_gemini_cost():
60
+ # Prices are per 1,000 tokens
61
+ PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
62
+ PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
63
+ PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
64
+ return True
65
+
66
+ def unique_preserve_order(seq):
67
+ seen = set()
68
+ return [x for x in seq if not (x in seen or seen.add(x))]
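Illustrative call: duplicates are dropped while the first-seen order of links is preserved.

from pipeline import unique_preserve_order

links = ["https://doi.org/x", "https://doi.org/y", "https://doi.org/x"]
print(unique_preserve_order(links))  # -> ['https://doi.org/x', 'https://doi.org/y']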
69
+ # Main execution
70
+ def pipeline_with_gemini(accessions):
71
+ # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
72
+ # there can be one accession number in the accessions
73
+ # Prices are per 1,000 tokens
74
+ PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
75
+ PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
76
+ PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
77
+ if not accessions:
78
+ print("no input")
79
+ return None
80
+ else:
81
+ accs_output = {}
82
+ os.environ["GOOGLE_API_KEY"] = "AIzaSyDi0CNKBgEtnr6YuPaY6YNEuC5wT0cdKhk"
83
+ genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
84
+ for acc in accessions:
85
+ start = time.time()
86
+ total_cost_title = 0
87
+ jsonSM, links, article_text = {},[], ""
88
+ acc_score = { "isolate": "",
89
+ "country":{},
90
+ "sample_type":{},
91
+ #"specific_location":{},
92
+ #"ethnicity":{},
93
+ "query_cost":total_cost_title,
94
+ "time_cost":None,
95
+ "source":links}
96
+ meta = mtdna_classifier.fetch_ncbi_metadata(acc)
97
+ country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"], meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
98
+ acc_score["isolate"] = iso
99
+ # set up step: create the folder to save document
100
+ chunk, all_output = "",""
101
+ if pudID:
102
+ id = pudID
103
+ saveTitle = title
104
+ else:
105
+ saveTitle = title + "_" + col_date
106
+ id = "DirectSubmission"
107
+ folder_path = Path("/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id))
108
+ if not folder_path.exists():
109
+ cmd = f'mkdir /content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/{id}'
110
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
111
+ print("data/"+str(id) +" created.")
112
+ else:
113
+ print("data/"+str(id) +" already exists.")
114
+ saveLinkFolder = "/content/drive/MyDrive/CollectData/MVP/mtDNA-Location-Classifier/data/"+str(id)
115
+ # first way: ncbi method
116
+ if country.lower() != "unknown":
117
+ stand_country = standardize_location.smart_country_lookup(country.lower())
118
+ if stand_country.lower() != "not found":
119
+ acc_score["country"][stand_country.lower()] = ["ncbi"]
120
+ else: acc_score["country"][country.lower()] = ["ncbi"]
121
+ # if spe_loc.lower() != "unknown":
122
+ # acc_score["specific_location"][spe_loc.lower()] = ["ncbi"]
123
+ # if ethnic.lower() != "unknown":
124
+ # acc_score["ethnicity"][ethnic.lower()] = ["ncbi"]
125
+ if sample_type.lower() != "unknown":
126
+ acc_score["sample_type"][sample_type.lower()] = ["ncbi"]
127
+ # second way: LLM model
128
+ # Preprocess the input token
129
+ accession, isolate = None, None
130
+ if acc != "unknown": accession = acc
131
+ if iso != "unknown": isolate = iso
132
+ # check doi first
133
+ if doi != "unknown":
134
+ link = 'https://doi.org/' + doi
135
+ # get the file to create listOfFile for each id
136
+ html = extractHTML.HTML("",link)
137
+ jsonSM = html.getSupMaterial()
138
+ article_text = html.getListSection()
139
+ if article_text:
140
+ if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text.lower() or "403 Forbidden Request".lower() not in article_text.lower():
141
+ links.append(link)
142
+ if jsonSM:
143
+ links += sum((jsonSM[key] for key in jsonSM),[])
144
+ # no doi then google custom search api
145
+ if len(article_text) == 0 or "Just a moment...Enable JavaScript and cookies to continue".lower() in article_text.lower() or "403 Forbidden Request".lower() in article_text.lower():
146
+ # might find the article
147
+ tem_links = mtdna_classifier.search_google_custom(title, 2)
148
+ # get supplementary of that article
149
+ for link in tem_links:
150
+ html = extractHTML.HTML("",link)
151
+ jsonSM = html.getSupMaterial()
152
+ article_text_tem = html.getListSection()
153
+ if article_text_tem:
154
+ if "Just a moment...Enable JavaScript and cookies to continue".lower() not in article_text_tem.lower() or "403 Forbidden Request".lower() not in article_text_tem.lower():
155
+ links.append(link)
156
+ if jsonSM:
157
+ links += sum((jsonSM[key] for key in jsonSM),[])
158
+ print(links)
159
+ links = unique_preserve_order(links)
160
+ acc_score["source"] = links
161
+ chunk_path = "/"+saveTitle+"_merged_document.docx"
162
+ all_path = "/"+saveTitle+"_all_merged_document.docx"
163
+ # if chunk and all output not exist yet
164
+ file_chunk_path = saveLinkFolder + chunk_path
165
+ file_all_path = saveLinkFolder + all_path
166
+ if os.path.exists(file_chunk_path):
167
+ print("File chunk exists!")
168
+ if not chunk:
169
+ text, table, document_title = model.read_docx_text(file_chunk_path)
170
+ chunk = data_preprocess.normalize_for_overlap(text) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table))
171
+ if os.path.exists(file_all_path):
172
+ print("File all output exists!")
173
+ if not all_output:
174
+ text_all, table_all, document_title_all = model.read_docx_text(file_all_path)
175
+ all_output = data_preprocess.normalize_for_overlap(text_all) + "\n" + data_preprocess.normalize_for_overlap(". ".join(table_all))
176
+ if not chunk and not all_output:
177
+ # else: check if we can reuse these chunk and all output of existed accession to find another
178
+ if links:
179
+ for link in links:
180
+ print(link)
181
+ # if len(all_output) > 1000*1000:
182
+ # all_output = data_preprocess.normalize_for_overlap(all_output)
183
+ # print("after normalizing all output: ", len(all_output))
184
+ if len(data_preprocess.normalize_for_overlap(all_output)) > 600000:
185
+ print("break here")
186
+ break
187
+ if iso != "unknown": query_kw = iso
188
+ else: query_kw = acc
189
+ #text_link, tables_link, final_input_link = data_preprocess.preprocess_document(link,saveLinkFolder, isolate=query_kw)
190
+ success_process, output_process = run_with_timeout(data_preprocess.preprocess_document,args=(link,saveLinkFolder),kwargs={"isolate":query_kw},timeout=180)
191
+ if success_process:
192
+ text_link, tables_link, final_input_link = output_process[0], output_process[1], output_process[2]
193
+ print("yes succeed for process document")
194
+ else: text_link, tables_link, final_input_link = "", "", ""
195
+ context = data_preprocess.extract_context(final_input_link, query_kw)
196
+ if context != "Sample ID not found.":
197
+ if len(data_preprocess.normalize_for_overlap(chunk)) < 1000*1000:
198
+ success_chunk, the_output_chunk = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(chunk, context))
199
+ if success_chunk:
200
+ chunk = the_output_chunk#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
201
+ print("yes succeed for chunk")
202
+ else:
203
+ chunk += context
204
+ print("len context: ", len(context))
205
+ print("basic fall back")
206
+ print("len chunk after: ", len(chunk))
207
+ if len(final_input_link) > 1000*1000:
208
+ if context != "Sample ID not found.":
209
+ final_input_link = context
210
+ else:
211
+ final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
212
+ if len(final_input_link) > 1000 *1000:
213
+ final_input_link = final_input_link[:100000]
214
+ if len(data_preprocess.normalize_for_overlap(all_output)) < 1000*1000:
215
+ success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link))
216
+ if success:
217
+ all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
218
+ print("yes succeed")
219
+ else:
220
+ all_output += final_input_link
221
+ print("len final input: ", len(final_input_link))
222
+ print("basic fall back")
223
+ print("len all output after: ", len(all_output))
224
+ #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
225
+
226
+ else:
227
+ chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
228
+ all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
229
+ if not chunk: chunk = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
230
+ if not all_output: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
231
+ if len(all_output) > 1*1024*1024:
232
+ all_output = data_preprocess.normalize_for_overlap(all_output)
233
+ if len(all_output) > 1*1024*1024:
234
+ all_output = all_output[:1*1024*1024]
235
+ print("chunk len: ", len(chunk))
236
+ print("all output len: ", len(all_output))
237
+ data_preprocess.save_text_to_docx(chunk, file_chunk_path)
238
+ data_preprocess.save_text_to_docx(all_output, file_all_path)
239
+ # else:
240
+ # final_input = ""
241
+ # if all_output:
242
+ # final_input = all_output
243
+ # else:
244
+ # if chunk: final_input = chunk
245
+ # #data_preprocess.merge_texts_skipping_overlap(final_input, all_output)
246
+ # if final_input:
247
+ # keywords = []
248
+ # if iso != "unknown": keywords.append(iso)
249
+ # if acc != "unknown": keywords.append(acc)
250
+ # for keyword in keywords:
251
+ # chunkBFS = data_preprocess.get_contextual_sentences_BFS(final_input, keyword)
252
+ # countryDFS, chunkDFS = data_preprocess.get_contextual_sentences_DFS(final_input, keyword)
253
+ # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkDFS)
254
+ # chunk = data_preprocess.merge_texts_skipping_overlap(chunk, chunkBFS)
255
+
256
+ # Define paths for cached RAG assets
257
+ faiss_index_path = saveLinkFolder+"/faiss_index.bin"
258
+ document_chunks_path = saveLinkFolder+"/document_chunks.json"
259
+ structured_lookup_path = saveLinkFolder+"/structured_lookup.json"
260
+
261
+ master_structured_lookup, faiss_index, document_chunks = model.load_rag_assets(
262
+ faiss_index_path, document_chunks_path, structured_lookup_path
263
+ )
264
+
265
+ global_llm_model_for_counting_tokens = genai.GenerativeModel('gemini-1.5-flash-latest')
266
+ if not all_output:
267
+ if chunk: all_output = chunk
268
+ else: all_output = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + features
269
+ if faiss_index is None:
270
+ print("\nBuilding RAG assets (structured lookup, FAISS index, chunks)...")
271
+ total_doc_embedding_tokens = global_llm_model_for_counting_tokens.count_tokens(
272
+ all_output
273
+ ).total_tokens
274
+
275
+ initial_embedding_cost = (total_doc_embedding_tokens / 1000) * PRICE_PER_1K_EMBEDDING_INPUT
276
+ total_cost_title += initial_embedding_cost
277
+ print(f"Initial one-time embedding cost for '{file_all_path}' ({total_doc_embedding_tokens} tokens): ${initial_embedding_cost:.6f}")
278
+
279
+
280
+ master_structured_lookup, faiss_index, document_chunks, plain_text_content = model.build_vector_index_and_data(
281
+ file_all_path, faiss_index_path, document_chunks_path, structured_lookup_path
282
+ )
283
+ else:
284
+ print("\nRAG assets loaded from file. No re-embedding of entire document will occur.")
285
+ plain_text_content_all, table_strings_all, document_title_all = model.read_docx_text(file_all_path)
286
+ master_structured_lookup['document_title'] = master_structured_lookup.get('document_title', document_title_all)
287
+
288
+ primary_word = iso
289
+ alternative_word = acc
290
+ print(f"\n--- General Query: Primary='{primary_word}' (Alternative='{alternative_word}') ---")
291
+ if features.lower() not in all_output.lower():
292
+ all_output += ". NCBI Features: " + features
293
+ # country, sample_type, method_used, ethnic, spe_loc, total_query_cost = model.query_document_info(
294
+ # primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
295
+ # model.call_llm_api, chunk=chunk, all_output=all_output)
296
+ country, sample_type, method_used, country_explanation, sample_type_explanation, total_query_cost = model.query_document_info(
297
+ primary_word, alternative_word, meta, master_structured_lookup, faiss_index, document_chunks,
298
+ model.call_llm_api, chunk=chunk, all_output=all_output)
299
+ if len(country) == 0: country = "unknown"
300
+ if len(sample_type) == 0: sample_type = "unknown"
301
+ if country_explanation: country_explanation = "-"+country_explanation
302
+ else: country_explanation = ""
303
+ if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
304
+ else: sample_type_explanation = ""
305
+ if method_used == "unknown": method_used = ""
306
+ if country.lower() != "unknown":
307
+ stand_country = standardize_location.smart_country_lookup(country.lower())
308
+ if stand_country.lower() != "not found":
309
+ if stand_country.lower() in acc_score["country"]:
310
+ if country_explanation:
311
+ acc_score["country"][stand_country.lower()].append(method_used + country_explanation)
312
+ else:
313
+ acc_score["country"][stand_country.lower()] = [method_used + country_explanation]
314
+ else:
315
+ if country.lower() in acc_score["country"]:
316
+ if country_explanation:
317
+ if len(method_used + country_explanation) > 0:
318
+ acc_score["country"][country.lower()].append(method_used + country_explanation)
319
+ else:
320
+ if len(method_used + country_explanation) > 0:
321
+ acc_score["country"][country.lower()] = [method_used + country_explanation]
322
+ # if spe_loc.lower() != "unknown":
323
+ # if spe_loc.lower() in acc_score["specific_location"]:
324
+ # acc_score["specific_location"][spe_loc.lower()].append(method_used)
325
+ # else:
326
+ # acc_score["specific_location"][spe_loc.lower()] = [method_used]
327
+ # if ethnic.lower() != "unknown":
328
+ # if ethnic.lower() in acc_score["ethnicity"]:
329
+ # acc_score["ethnicity"][ethnic.lower()].append(method_used)
330
+ # else:
331
+ # acc_score["ethnicity"][ethnic.lower()] = [method_used]
332
+ if sample_type.lower() != "unknown":
333
+ if sample_type.lower() in acc_score["sample_type"]:
334
+ if len(method_used + sample_type_explanation) > 0:
335
+ acc_score["sample_type"][sample_type.lower()].append(method_used + sample_type_explanation)
336
+ else:
337
+ if len(method_used + sample_type_explanation)> 0:
338
+ acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
339
+ end = time.time()
340
+ total_cost_title += total_query_cost
341
+ acc_score["query_cost"] = total_cost_title
342
+ elapsed = end - start
343
+ acc_score["time_cost"] = f"{elapsed:.3f} seconds"
344
+ accs_output[acc] = acc_score
345
+ print(accs_output[acc])
346
+
347
+ return accs_output
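A hedged driver for pipeline_with_gemini. It needs the Google Generative AI key configured above, network access, and the Colab/Drive paths hard-coded in the function; the per-accession keys match the acc_score dict built above.

from pipeline import pipeline_with_gemini

results = pipeline_with_gemini(["ON792208"])
for acc, score in results.items():
    print(acc, score["country"], score["sample_type"])
    print("  cost:", score["query_cost"], "time:", score["time_cost"])
    print("  sources:", score["source"])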
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  biopython==1.85
2
  bs4==0.0.2
3
  gensim==4.3.3
4
- gradio==5.29.0
5
  gspread==6.2.0
6
  gspread-dataframe==4.0.0
7
  huggingface-hub==0.30.2
@@ -23,9 +23,22 @@ Spire.Xls==14.12.0
23
  statsmodels==0.14.4
24
  tabula-py==2.10.0
25
  thefuzz==0.22.1
26
- torch @ https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp310-cp310-linux_x86_64.whl
27
  transformers==4.51.3
28
  wordsegment==1.3.1
29
  xlrd==2.0.1
30
  sentence-transformers
31
- lxml
1
  biopython==1.85
2
  bs4==0.0.2
3
  gensim==4.3.3
4
+ gradio
5
  gspread==6.2.0
6
  gspread-dataframe==4.0.0
7
  huggingface-hub==0.30.2
 
23
  statsmodels==0.14.4
24
  tabula-py==2.10.0
25
  thefuzz==0.22.1
26
+ torch
27
  transformers==4.51.3
28
  wordsegment==1.3.1
29
  xlrd==2.0.1
30
  sentence-transformers
31
+ lxml
32
+ streamlit
33
+ requests
34
+ google-generativeai
35
+ PyPDF2
36
+ beautifulsoup4
37
+ # For Claude
38
+ anthropic
39
+ faiss-cpu
40
+ python-docx
41
+ pycountry
42
+ # For Deepseek (If direct DeepseekLLM client library is available, use it.
43
+ # Otherwise, 'requests' covers it for simple API calls, but a dedicated client is better for full features)
44
+ # deepseek-llm # Uncomment this if Deepseek provides a dedicated pip package for their LLM