VyLala committed on
Commit
20f8860
·
verified ·
1 Parent(s): 64d0b74

Update mtdna_backend.py

Files changed (1)
  1. mtdna_backend.py +906 -896
mtdna_backend.py CHANGED
@@ -1,897 +1,907 @@
1
- import gradio as gr
2
- from collections import Counter
3
- import csv
4
- import os
5
- from functools import lru_cache
6
- #import app
7
- from mtdna_classifier import classify_sample_location
8
- import data_preprocess, model, pipeline
9
- import subprocess
10
- import json
11
- import pandas as pd
12
- import io
13
- import re
14
- import tempfile
15
- import gspread
16
- from oauth2client.service_account import ServiceAccountCredentials
17
- from io import StringIO
18
- import hashlib
19
- import threading
20
-
21
- # @lru_cache(maxsize=3600)
22
- # def classify_sample_location_cached(accession):
23
- # return classify_sample_location(accession)
24
-
25
- @lru_cache(maxsize=3600)
26
- def pipeline_classify_sample_location_cached(accession):
27
- print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
- return pipeline.pipeline_with_gemini([accession])
29
-
30
- # Count and suggest final location
31
- # def compute_final_suggested_location(rows):
32
- # candidates = [
33
- # row.get("Predicted Location", "").strip()
34
- # for row in rows
35
- # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
36
- # ] + [
37
- # row.get("Inferred Region", "").strip()
38
- # for row in rows
39
- # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
40
- # ]
41
-
42
- # if not candidates:
43
- # return Counter(), ("Unknown", 0)
44
- # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
45
- # tokens = []
46
- # for item in candidates:
47
- # # Split by comma, whitespace, and newlines
48
- # parts = re.split(r'[\s,]+', item)
49
- # tokens.extend(parts)
50
-
51
- # # Step 2: Clean and normalize tokens
52
- # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
53
-
54
- # # Step 3: Count
55
- # counts = Counter(tokens)
56
-
57
- # # Step 4: Get most common
58
- # top_location, count = counts.most_common(1)[0]
59
- # return counts, (top_location, count)
60
-
61
- # Store feedback (with required fields)
62
-
63
- def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
64
- if not answer1.strip() or not answer2.strip():
65
- return "⚠️ Please answer both questions before submitting."
66
-
67
- try:
68
-         # ✅ Step: Load credentials from Hugging Face secret
69
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
70
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
71
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
72
-
73
- # Connect to Google Sheet
74
- client = gspread.authorize(creds)
75
- sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
76
-
77
- # Append feedback
78
- sheet.append_row([accession, answer1, answer2, contact])
79
-         return "✅ Feedback submitted. Thank you!"
80
-
81
- except Exception as e:
82
- return f"❌ Error submitting feedback: {e}"
83
-
84
- # helper function to extract accessions
85
- def extract_accessions_from_input(file=None, raw_text=""):
86
- print(f"RAW TEXT RECEIVED: {raw_text}")
87
- accessions = []
88
- seen = set()
89
- if file:
90
- try:
91
- if file.name.endswith(".csv"):
92
- df = pd.read_csv(file)
93
- elif file.name.endswith(".xlsx"):
94
- df = pd.read_excel(file)
95
- else:
96
- return [], "Unsupported file format. Please upload CSV or Excel."
97
- for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
98
- if acc not in seen:
99
- accessions.append(acc)
100
- seen.add(acc)
101
- except Exception as e:
102
- return [], f"Failed to read file: {e}"
103
-
104
- if raw_text:
105
- text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
106
- for acc in text_ids:
107
- if acc not in seen:
108
- accessions.append(acc)
109
- seen.add(acc)
110
-
111
- return list(accessions), None
112
- # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
113
- def get_incomplete_accessions(file_path):
114
- df = pd.read_excel(file_path)
115
-
116
- incomplete_accessions = []
117
- for _, row in df.iterrows():
118
- sample_id = str(row.get("Sample ID", "")).strip()
119
-
120
- # Skip if no sample ID
121
- if not sample_id:
122
- continue
123
-
124
- # Drop the Sample ID and check if the rest is empty
125
- other_cols = row.drop(labels=["Sample ID"], errors="ignore")
126
- if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
127
- # Extract the accession number from the sample ID using regex
128
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
129
- if match:
130
- incomplete_accessions.append(match.group(0))
131
- print(len(incomplete_accessions))
132
- return incomplete_accessions
133
-
134
- # GOOGLE_SHEET_NAME = "known_samples"
135
- # USAGE_DRIVE_FILENAME = "user_usage_log.json"
136
-
137
- def summarize_results(accession):
138
- # try cache first
139
- cached = check_known_output(accession)
140
- if cached:
141
-         print(f"✅ Using cached result for {accession}")
142
- return [[
143
- cached["Sample ID"] or "unknown",
144
- cached["Predicted Country"] or "unknown",
145
- cached["Country Explanation"] or "unknown",
146
- cached["Predicted Sample Type"] or "unknown",
147
- cached["Sample Type Explanation"] or "unknown",
148
- cached["Sources"] or "No Links",
149
- cached["Time cost"]
150
- ]]
151
- # only run when nothing in the cache
152
- try:
153
- print("try gemini pipeline: ",accession)
154
- outputs = pipeline_classify_sample_location_cached(accession)
155
- # outputs = {'KU131308': {'isolate':'BRU18',
156
- # 'country': {'brunei': ['ncbi',
157
- # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
158
- # 'sample_type': {'modern':
159
- # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
160
- # 'query_cost': 9.754999999999999e-05,
161
- # 'time_cost': '24.776 seconds',
162
- # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
163
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
164
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
165
- except Exception as e:
166
- return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
167
-
168
- if accession not in outputs:
169
- print("no accession in output ", accession)
170
- return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
171
-
172
- row_score = []
173
- rows = []
174
- save_rows = []
175
- for key in outputs:
176
- pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
177
- for section, results in outputs[key].items():
178
- if section == "country" or section =="sample_type":
179
- pred_output = []#"\n".join(list(results.keys()))
180
- output_explanation = ""
181
- for result, content in results.items():
182
- if len(result) == 0: result = "unknown"
183
- if len(content) == 0: output_explanation = "unknown"
184
- else:
185
- output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
186
- pred_output.append(result)
187
- pred_output = "\n".join(pred_output)
188
- if section == "country":
189
- pred_country, country_explanation = pred_output, output_explanation
190
- elif section == "sample_type":
191
- pred_sample, sample_explanation = pred_output, output_explanation
192
- if outputs[key]["isolate"].lower()!="unknown":
193
- label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
194
- else: label = key
195
- if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
196
- row = {
197
- "Sample ID": label or "unknown",
198
- "Predicted Country": pred_country or "unknown",
199
- "Country Explanation": country_explanation or "unknown",
200
- "Predicted Sample Type":pred_sample or "unknown",
201
- "Sample Type Explanation":sample_explanation or "unknown",
202
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
203
- "Time cost": outputs[key]["time_cost"]
204
- }
205
- #row_score.append(row)
206
- rows.append(list(row.values()))
207
-
208
- save_row = {
209
- "Sample ID": label or "unknown",
210
- "Predicted Country": pred_country or "unknown",
211
- "Country Explanation": country_explanation or "unknown",
212
- "Predicted Sample Type":pred_sample or "unknown",
213
- "Sample Type Explanation":sample_explanation or "unknown",
214
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
215
- "Query_cost": outputs[key]["query_cost"],
216
- "Time cost": outputs[key]["time_cost"]
217
- }
218
- #row_score.append(row)
219
- save_rows.append(list(save_row.values()))
220
-
221
- # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
222
- # summary_lines = [f"### 🧭 Location Summary:\n"]
223
- # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
224
-     # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
225
- # summary = "\n".join(summary_lines)
226
-
227
- # save the new running sample to known excel file
228
- # try:
229
- # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
230
- # if os.path.exists(KNOWN_OUTPUT_PATH):
231
- # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
232
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
233
- # else:
234
- # df_combined = df_new
235
- # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
236
- # except Exception as e:
237
- # print(f"⚠️ Failed to save known output: {e}")
238
- # try:
239
- # df_new = pd.DataFrame(save_rows, columns=[
240
- # "Sample ID", "Predicted Country", "Country Explanation",
241
- # "Predicted Sample Type", "Sample Type Explanation",
242
- # "Sources", "Query_cost", "Time cost"
243
- # ])
244
-
245
-     # # ✅ Google Sheets API setup
246
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
247
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
248
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
249
- # client = gspread.authorize(creds)
250
-
251
-     # # ✅ Open the known_samples sheet
252
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
253
- # sheet = spreadsheet.sheet1
254
-
255
-     # # ✅ Read old data
256
- # existing_data = sheet.get_all_values()
257
- # if existing_data:
258
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
259
- # else:
260
- # df_old = pd.DataFrame(columns=df_new.columns)
261
-
262
-     # # ✅ Combine and remove duplicates
263
- # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
264
-
265
-     # # ✅ Clear and write back
266
- # sheet.clear()
267
- # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
268
-
269
- # except Exception as e:
270
- # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
271
- try:
272
- # Prepare as DataFrame
273
- df_new = pd.DataFrame(save_rows, columns=[
274
- "Sample ID", "Predicted Country", "Country Explanation",
275
- "Predicted Sample Type", "Sample Type Explanation",
276
- "Sources", "Query_cost", "Time cost"
277
- ])
278
-
279
-         # ✅ Setup Google Sheets
280
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
281
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
282
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
283
- client = gspread.authorize(creds)
284
- spreadsheet = client.open("known_samples")
285
- sheet = spreadsheet.sheet1
286
-
287
-         # ✅ Read existing data
288
- existing_data = sheet.get_all_values()
289
- if existing_data:
290
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
291
- else:
292
- df_old = pd.DataFrame(columns=[
293
- "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
294
- "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
295
- "Query_cost", "Sample Type Explanation", "Sources", "Time cost"
296
- ])
297
-
298
-         # ✅ Index by Sample ID
299
- df_old.set_index("Sample ID", inplace=True)
300
- df_new.set_index("Sample ID", inplace=True)
301
-
302
-         # ✅ Update only matching fields
303
- update_columns = [
304
- "Predicted Country", "Predicted Sample Type", "Country Explanation",
305
- "Sample Type Explanation", "Sources", "Query_cost", "Time cost"
306
- ]
307
- for idx, row in df_new.iterrows():
308
- if idx not in df_old.index:
309
- df_old.loc[idx] = "" # new row, fill empty first
310
- for col in update_columns:
311
- if pd.notna(row[col]) and row[col] != "":
312
- df_old.at[idx, col] = row[col]
313
-
314
-         # ✅ Reset and write back
315
- df_old.reset_index(inplace=True)
316
- sheet.clear()
317
- sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
318
-         print("✅ Match results saved to known_samples.")
319
-
320
- except Exception as e:
321
- print(f"❌ Failed to update known_samples: {e}")
322
-
323
-
324
- return rows#, summary, labelAncient_Modern, explain_label
325
-
326
- # save the batch input in excel file
327
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
328
- # with pd.ExcelWriter(filename) as writer:
329
- # # Save table
330
- # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
331
- # df.to_excel(writer, sheet_name="Detailed Results", index=False)
332
- # try:
333
- # df_old = pd.read_excel(filename)
334
- # except:
335
- # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
336
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
337
- # # if os.path.exists(filename):
338
- # # df_old = pd.read_excel(filename)
339
- # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
340
- # # else:
341
- # # df_combined = df_new
342
- # df_combined.to_excel(filename, index=False)
343
- # # # Save summary
344
- # # summary_df = pd.DataFrame({"Summary": [summary_text]})
345
- # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
346
-
347
- # # # Save flag
348
- # # flag_df = pd.DataFrame({"Flag": [flag_text]})
349
- # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
350
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
351
- # df_new = pd.DataFrame(all_rows, columns=[
352
- # "Sample ID", "Predicted Country", "Country Explanation",
353
- # "Predicted Sample Type", "Sample Type Explanation",
354
- # "Sources", "Time cost"
355
- # ])
356
-
357
- # try:
358
- # if os.path.exists(filename):
359
- # df_old = pd.read_excel(filename)
360
- # else:
361
- # df_old = pd.DataFrame(columns=df_new.columns)
362
- # except Exception as e:
363
- # print(f"⚠️ Warning reading old Excel file: {e}")
364
- # df_old = pd.DataFrame(columns=df_new.columns)
365
-
366
- # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
367
- # df_old.set_index("Sample ID", inplace=True)
368
- # df_new.set_index("Sample ID", inplace=True)
369
-
370
- # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
371
-
372
- # df_combined = df_old.reset_index()
373
-
374
- # try:
375
- # df_combined.to_excel(filename, index=False)
376
- # except Exception as e:
377
- # print(f"❌ Failed to write Excel file {filename}: {e}")
378
- def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
379
- df_new = pd.DataFrame(all_rows, columns=[
380
- "Sample ID", "Predicted Country", "Country Explanation",
381
- "Predicted Sample Type", "Sample Type Explanation",
382
- "Sources", "Time cost"
383
- ])
384
-
385
- if is_resume and os.path.exists(filename):
386
- try:
387
- df_old = pd.read_excel(filename)
388
- except Exception as e:
389
- print(f"⚠️ Warning reading old Excel file: {e}")
390
- df_old = pd.DataFrame(columns=df_new.columns)
391
-
392
- # Set index and update existing rows
393
- df_old.set_index("Sample ID", inplace=True)
394
- df_new.set_index("Sample ID", inplace=True)
395
- df_old.update(df_new)
396
-
397
- df_combined = df_old.reset_index()
398
- else:
399
- # If not resuming or file doesn't exist, just use new rows
400
- df_combined = df_new
401
-
402
- try:
403
- df_combined.to_excel(filename, index=False)
404
- except Exception as e:
405
- print(f"❌ Failed to write Excel file {filename}: {e}")
406
-
407
-
408
- # save the batch input in JSON file
409
- def save_to_json(all_rows, summary_text, flag_text, filename):
410
- output_dict = {
411
- "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
412
- # "Summary_Text": summary_text,
413
- # "Ancient_Modern_Flag": flag_text
414
- }
415
-
416
- # If all_rows is a DataFrame, convert it
417
- if isinstance(all_rows, pd.DataFrame):
418
- output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
419
-
420
- with open(filename, "w") as external_file:
421
- json.dump(output_dict, external_file, indent=2)
422
-
423
- # save the batch input in Text file
424
- def save_to_txt(all_rows, summary_text, flag_text, filename):
425
- if isinstance(all_rows, pd.DataFrame):
426
- detailed_results = all_rows.to_dict(orient="records")
427
- output = ""
428
- #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
429
- output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
430
- for r in detailed_results:
431
- output += ",".join([str(v) for v in r.values()]) + "\n\n"
432
- with open(filename, "w") as f:
433
- f.write("=== Detailed Results ===\n")
434
- f.write(output + "\n")
435
-
436
- # f.write("\n=== Summary ===\n")
437
- # f.write(summary_text + "\n")
438
-
439
- # f.write("\n=== Ancient/Modern Flag ===\n")
440
- # f.write(flag_text + "\n")
441
-
442
- def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
443
- tmp_dir = tempfile.mkdtemp()
444
-
445
- #html_table = all_rows.value # assuming this is stored somewhere
446
-
447
- # Parse back to DataFrame
448
- #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
449
- all_rows = pd.read_html(StringIO(all_rows))[0]
450
- print(all_rows)
451
-
452
- if output_type == "Excel":
453
- file_path = f"{tmp_dir}/batch_output.xlsx"
454
- save_to_excel(all_rows, summary_text, flag_text, file_path)
455
- elif output_type == "JSON":
456
- file_path = f"{tmp_dir}/batch_output.json"
457
- save_to_json(all_rows, summary_text, flag_text, file_path)
458
- print("Done with JSON")
459
- elif output_type == "TXT":
460
- file_path = f"{tmp_dir}/batch_output.txt"
461
- save_to_txt(all_rows, summary_text, flag_text, file_path)
462
- else:
463
- return gr.update(visible=False) # invalid option
464
-
465
- return gr.update(value=file_path, visible=True)
466
- # save cost by checking the known outputs
467
-
468
- # def check_known_output(accession):
469
- # if not os.path.exists(KNOWN_OUTPUT_PATH):
470
- # return None
471
-
472
- # try:
473
- # df = pd.read_excel(KNOWN_OUTPUT_PATH)
474
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
475
- # if match:
476
- # accession = match.group(0)
477
-
478
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
479
- # if not matched.empty:
480
- # return matched.iloc[0].to_dict() # Return the cached row
481
- # except Exception as e:
482
- # print(f"⚠️ Failed to load known samples: {e}")
483
- # return None
484
-
485
- # def check_known_output(accession):
486
- # try:
487
- #         # ✅ Load credentials from Hugging Face secret
488
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
489
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
490
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
491
- # client = gspread.authorize(creds)
492
-
493
- #         # ✅ Open the known_samples sheet
494
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
495
- # sheet = spreadsheet.sheet1
496
-
497
- #         # ✅ Read all rows
498
- # data = sheet.get_all_values()
499
- # if not data:
500
- # return None
501
-
502
- # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
503
-
504
- #         # ✅ Normalize accession pattern
505
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
506
- # if match:
507
- # accession = match.group(0)
508
-
509
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
510
- # if not matched.empty:
511
- # return matched.iloc[0].to_dict()
512
-
513
- # except Exception as e:
514
- # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
515
- # return None
516
- def check_known_output(accession):
517
- try:
518
-         # ✅ Load credentials from Hugging Face secret
519
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
520
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
521
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
522
- client = gspread.authorize(creds)
523
-
524
- spreadsheet = client.open("known_samples")
525
- sheet = spreadsheet.sheet1
526
-
527
- data = sheet.get_all_values()
528
- if not data:
529
- print("⚠️ Google Sheet 'known_samples' is empty.")
530
- return None
531
-
532
- df = pd.DataFrame(data[1:], columns=data[0])
533
- if "Sample ID" not in df.columns:
534
- print("❌ Column 'Sample ID' not found in Google Sheet.")
535
- return None
536
-
537
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
538
- if match:
539
- accession = match.group(0)
540
-
541
- matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
542
- if not matched.empty:
543
- #return matched.iloc[0].to_dict()
544
- row = matched.iloc[0]
545
- country = row.get("Predicted Country", "").strip().lower()
546
- sample_type = row.get("Predicted Sample Type", "").strip().lower()
547
-
548
- if country and country != "unknown" and sample_type and sample_type != "unknown":
549
- return row.to_dict()
550
- else:
551
- print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
552
- return None
553
- else:
554
-             print(f"🔍 Accession {accession} not found in known_samples.")
555
- return None
556
-
557
- except Exception as e:
558
- import traceback
559
- print("❌ Exception occurred during check_known_output:")
560
- traceback.print_exc()
561
- return None
562
-
563
-
564
- def hash_user_id(user_input):
565
- return hashlib.sha256(user_input.encode()).hexdigest()
566
-
567
- # ✅ Load and save usage count
568
-
569
- # def load_user_usage():
570
- # if not os.path.exists(USER_USAGE_TRACK_FILE):
571
- # return {}
572
-
573
- # try:
574
- # with open(USER_USAGE_TRACK_FILE, "r") as f:
575
- # content = f.read().strip()
576
- # if not content:
577
- # return {} # file is empty
578
- # return json.loads(content)
579
- # except (json.JSONDecodeError, ValueError):
580
- # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
581
- # return {} # fallback to empty dict
582
- # def load_user_usage():
583
- # try:
584
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
585
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
586
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
587
- # client = gspread.authorize(creds)
588
-
589
- # sheet = client.open("user_usage_log").sheet1
590
- # data = sheet.get_all_records() # Assumes columns: email, usage_count
591
-
592
- # usage = {}
593
- # for row in data:
594
- # email = row.get("email", "").strip().lower()
595
- # count = int(row.get("usage_count", 0))
596
- # if email:
597
- # usage[email] = count
598
- # return usage
599
- # except Exception as e:
600
- # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
601
- # return {}
602
- # def load_user_usage():
603
- # try:
604
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
605
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
606
-
607
- # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
608
- # if not found:
609
- # return {} # not found, start fresh
610
-
611
- # #file_id = found[0]["id"]
612
- # file_id = found
613
- # content = pipeline.download_drive_file_content(file_id)
614
- # return json.loads(content.strip()) if content.strip() else {}
615
-
616
- # except Exception as e:
617
- # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
618
- # return {}
619
- def load_user_usage():
620
- try:
621
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
622
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
623
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
624
- client = gspread.authorize(creds)
625
-
626
- sheet = client.open("user_usage_log").sheet1
627
- data = sheet.get_all_values()
628
- print("data: ", data)
629
-         print("🧪 Raw header row from sheet:", data[0])
630
-         print("🧪 Character codes in each header:")
631
- for h in data[0]:
632
- print([ord(c) for c in h])
633
-
634
- if not data or len(data) < 2:
635
- print("⚠️ Sheet is empty or missing rows.")
636
- return {}
637
-
638
- headers = [h.strip().lower() for h in data[0]]
639
- if "email" not in headers or "usage_count" not in headers:
640
- print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
641
- return {}
642
-
643
- permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
644
- df = pd.DataFrame(data[1:], columns=headers)
645
-
646
- usage = {}
647
- permitted = {}
648
- for _, row in df.iterrows():
649
- email = row.get("email", "").strip().lower()
650
- try:
651
- #count = int(row.get("usage_count", 0))
652
- try:
653
- count = int(float(row.get("usage_count", 0)))
654
- except Exception:
655
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
656
- count = 0
657
-
658
- if email:
659
- usage[email] = count
660
- if permitted_index is not None:
661
- try:
662
- permitted_count = int(float(row.get("permitted_samples", 50)))
663
- permitted[email] = permitted_count
664
- except:
665
- permitted[email] = 50
666
-
667
- except ValueError:
668
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
669
- return usage, permitted
670
-
671
- except Exception as e:
672
- print(f"❌ Error in load_user_usage: {e}")
673
- return {}, {}
674
-
675
-
676
-
677
- # def save_user_usage(usage):
678
- # with open(USER_USAGE_TRACK_FILE, "w") as f:
679
- # json.dump(usage, f, indent=2)
680
-
681
- # def save_user_usage(usage_dict):
682
- # try:
683
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
684
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
685
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
686
- # client = gspread.authorize(creds)
687
-
688
- # sheet = client.open("user_usage_log").sheet1
689
- # sheet.clear() # clear old contents first
690
-
691
- # # Write header + rows
692
- # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
693
- # sheet.update(rows)
694
- # except Exception as e:
695
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
696
- # def save_user_usage(usage_dict):
697
- # try:
698
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
699
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
700
-
701
- # import tempfile
702
- # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
703
- #         print("💾 Saving this usage dict:", usage_dict)
704
- # with open(tmp_path, "w") as f:
705
- # json.dump(usage_dict, f, indent=2)
706
-
707
- # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
708
-
709
- # except Exception as e:
710
- # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
711
- # def save_user_usage(usage_dict):
712
- # try:
713
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
714
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
715
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
716
- # client = gspread.authorize(creds)
717
-
718
- # spreadsheet = client.open("user_usage_log")
719
- # sheet = spreadsheet.sheet1
720
-
721
- # # Step 1: Convert new usage to DataFrame
722
- # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
723
- # df_new["email"] = df_new["email"].str.strip().str.lower()
724
-
725
- # # Step 2: Load existing data
726
- # existing_data = sheet.get_all_values()
727
- #         print("🧪 Sheet existing_data:", existing_data)
728
-
729
- # # Try to load old data
730
- # if existing_data and len(existing_data[0]) >= 1:
731
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
732
-
733
- # # Fix missing columns
734
- # if "email" not in df_old.columns:
735
- # df_old["email"] = ""
736
- # if "usage_count" not in df_old.columns:
737
- # df_old["usage_count"] = 0
738
-
739
- # df_old["email"] = df_old["email"].str.strip().str.lower()
740
- # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
741
- # else:
742
- # df_old = pd.DataFrame(columns=["email", "usage_count"])
743
-
744
- # # Step 3: Merge
745
- # df_combined = pd.concat([df_old, df_new], ignore_index=True)
746
- # df_combined = df_combined.groupby("email", as_index=False).sum()
747
-
748
- # # Step 4: Write back
749
- # sheet.clear()
750
- # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
751
- #         print("✅ Saved user usage to user_usage_log sheet.")
752
-
753
- # except Exception as e:
754
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
755
- def save_user_usage(usage_dict):
756
- try:
757
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
758
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
759
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
760
- client = gspread.authorize(creds)
761
-
762
- spreadsheet = client.open("user_usage_log")
763
- sheet = spreadsheet.sheet1
764
-
765
- # Build new df
766
- df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
767
- df_new["email"] = df_new["email"].str.strip().str.lower()
768
- df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
769
-
770
- # Read existing data
771
- existing_data = sheet.get_all_values()
772
- if existing_data and len(existing_data[0]) >= 2:
773
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
774
- df_old["email"] = df_old["email"].str.strip().str.lower()
775
- df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
776
- else:
777
- df_old = pd.DataFrame(columns=["email", "usage_count"])
778
-
779
-         # ✅ Overwrite specific emails only
780
- df_old = df_old.set_index("email")
781
- for email, count in usage_dict.items():
782
- email = email.strip().lower()
783
- df_old.loc[email, "usage_count"] = count
784
- df_old = df_old.reset_index()
785
-
786
- # Save
787
- sheet.clear()
788
- sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
789
-         print("✅ Saved user usage to user_usage_log sheet.")
790
-
791
- except Exception as e:
792
- print(f"❌ Failed to save user usage to Google Sheets: {e}")
793
-
794
-
795
-
796
-
797
- # def increment_usage(user_id, num_samples=1):
798
- # usage = load_user_usage()
799
- # if user_id not in usage:
800
- # usage[user_id] = 0
801
- # usage[user_id] += num_samples
802
- # save_user_usage(usage)
803
- # return usage[user_id]
804
- # def increment_usage(email: str, count: int):
805
- # usage = load_user_usage()
806
- # email_key = email.strip().lower()
807
- # usage[email_key] = usage.get(email_key, 0) + count
808
- # save_user_usage(usage)
809
- # return usage[email_key]
810
- def increment_usage(email: str, count: int = 1):
811
- usage, permitted = load_user_usage()
812
- email_key = email.strip().lower()
813
- #usage[email_key] = usage.get(email_key, 0) + count
814
- current = usage.get(email_key, 0)
815
- new_value = current + count
816
- max_allowed = permitted.get(email_key) or 50
817
-     usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
818
-     print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
819
- print("max allow is: ", max_allowed)
820
- save_user_usage(usage)
821
- return usage[email_key], max_allowed
822
-
823
-
824
- # run the batch
825
- def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
826
- stop_flag=None, output_file_path=None,
827
- limited_acc=50, yield_callback=None):
828
- if user_email:
829
- limited_acc += 10
830
- accessions, error = extract_accessions_from_input(file, raw_text)
831
- if error:
832
- #return [], "", "", f"Error: {error}"
833
- return [], f"Error: {error}", 0, "", ""
834
- if resume_file:
835
- accessions = get_incomplete_accessions(resume_file)
836
- tmp_dir = tempfile.mkdtemp()
837
- if not output_file_path:
838
- if resume_file:
839
- output_file_path = os.path.join(tmp_dir, resume_file)
840
- else:
841
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
842
-
843
- all_rows = []
844
- # all_summaries = []
845
- # all_flags = []
846
- progress_lines = []
847
- warning = ""
848
- if len(accessions) > limited_acc:
849
- accessions = accessions[:limited_acc]
850
- warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
851
- for i, acc in enumerate(accessions):
852
- if stop_flag and stop_flag.value:
853
-             line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
854
- progress_lines.append(line)
855
- if yield_callback:
856
- yield_callback(line)
857
-             print("🛑 User requested stop.")
858
- break
859
- print(f"[{i+1}/{len(accessions)}] Processing {acc}")
860
- try:
861
- # rows, summary, label, explain = summarize_results(acc)
862
- rows = summarize_results(acc)
863
- all_rows.extend(rows)
864
- # all_summaries.append(f"**{acc}**\n{summary}")
865
- # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
866
- #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
867
- save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
868
-             line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
869
- progress_lines.append(line)
870
- if yield_callback:
871
-                 yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
872
- except Exception as e:
873
- print(f"❌ Failed to process {acc}: {e}")
874
- continue
875
- #all_summaries.append(f"**{acc}**: Failed - {e}")
876
- #progress_lines.append(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
877
- limited_acc -= 1
878
- """for row in all_rows:
879
- source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
880
-
881
- if source_column.startswith("http"): # Check if the source is a URL
882
- # Wrap it with HTML anchor tags to make it clickable
883
- row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
884
- if not warning:
885
- warning = f"You only have {limited_acc} left"
886
- if user_email.strip():
887
- user_hash = hash_user_id(user_email)
888
- total_queries = increment_usage(user_hash, len(all_rows))
889
- else:
890
- total_queries = 0
891
-     yield_callback("✅ Finished!")
892
-
893
- # summary_text = "\n\n---\n\n".join(all_summaries)
894
- # flag_text = "\n\n---\n\n".join(all_flags)
895
- #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
896
- #return all_rows, gr.update(visible=True), gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
897
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
 
1
+ import gradio as gr
2
+ from collections import Counter
3
+ import csv
4
+ import os
5
+ from functools import lru_cache
6
+ #import app
7
+ from mtdna_classifier import classify_sample_location
8
+ import data_preprocess, model, pipeline
9
+ import subprocess
10
+ import json
11
+ import pandas as pd
12
+ import io
13
+ import re
14
+ import tempfile
15
+ import gspread
16
+ from oauth2client.service_account import ServiceAccountCredentials
17
+ from io import StringIO
18
+ import hashlib
19
+ import threading
20
+
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ @lru_cache(maxsize=3600)
26
+ def pipeline_classify_sample_location_cached(accession,stop_flag):
27
+ print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
+ if stop_flag is not None and stop_flag.value:
29
+         print(f"🛑 Skipped {accession} mid-pipeline.")
30
+ return []
31
+ return pipeline.pipeline_with_gemini([accession],stop_flag)
32
+
33
+ # Count and suggest final location
34
+ # def compute_final_suggested_location(rows):
35
+ # candidates = [
36
+ # row.get("Predicted Location", "").strip()
37
+ # for row in rows
38
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
39
+ # ] + [
40
+ # row.get("Inferred Region", "").strip()
41
+ # for row in rows
42
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
43
+ # ]
44
+
45
+ # if not candidates:
46
+ # return Counter(), ("Unknown", 0)
47
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
48
+ # tokens = []
49
+ # for item in candidates:
50
+ # # Split by comma, whitespace, and newlines
51
+ # parts = re.split(r'[\s,]+', item)
52
+ # tokens.extend(parts)
53
+
54
+ # # Step 2: Clean and normalize tokens
55
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
56
+
57
+ # # Step 3: Count
58
+ # counts = Counter(tokens)
59
+
60
+ # # Step 4: Get most common
61
+ # top_location, count = counts.most_common(1)[0]
62
+ # return counts, (top_location, count)
63
+
64
+ # Store feedback (with required fields)
65
+
66
+ def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
67
+ if not answer1.strip() or not answer2.strip():
68
+ return "⚠️ Please answer both questions before submitting."
69
+
70
+ try:
71
+         # ✅ Step: Load credentials from Hugging Face secret
72
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
73
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
74
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
75
+
76
+ # Connect to Google Sheet
77
+ client = gspread.authorize(creds)
78
+ sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
79
+
80
+ # Append feedback
81
+ sheet.append_row([accession, answer1, answer2, contact])
82
+         return "✅ Feedback submitted. Thank you!"
83
+
84
+ except Exception as e:
85
+ return f"❌ Error submitting feedback: {e}"
86
+
87
+ # helper function to extract accessions
88
+ def extract_accessions_from_input(file=None, raw_text=""):
89
+ print(f"RAW TEXT RECEIVED: {raw_text}")
90
+ accessions = []
91
+ seen = set()
92
+ if file:
93
+ try:
94
+ if file.name.endswith(".csv"):
95
+ df = pd.read_csv(file)
96
+ elif file.name.endswith(".xlsx"):
97
+ df = pd.read_excel(file)
98
+ else:
99
+ return [], "Unsupported file format. Please upload CSV or Excel."
100
+ for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
101
+ if acc not in seen:
102
+ accessions.append(acc)
103
+ seen.add(acc)
104
+ except Exception as e:
105
+ return [], f"Failed to read file: {e}"
106
+
107
+ if raw_text:
108
+ text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
109
+ for acc in text_ids:
110
+ if acc not in seen:
111
+ accessions.append(acc)
112
+ seen.add(acc)
113
+
114
+ return list(accessions), None
115
+ # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
116
+ def get_incomplete_accessions(file_path):
117
+ df = pd.read_excel(file_path)
118
+
119
+ incomplete_accessions = []
120
+ for _, row in df.iterrows():
121
+ sample_id = str(row.get("Sample ID", "")).strip()
122
+
123
+ # Skip if no sample ID
124
+ if not sample_id:
125
+ continue
126
+
127
+ # Drop the Sample ID and check if the rest is empty
128
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
129
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
130
+ # Extract the accession number from the sample ID using regex
131
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
132
+ if match:
133
+ incomplete_accessions.append(match.group(0))
134
+ print(len(incomplete_accessions))
135
+ return incomplete_accessions
136
+
137
+ # GOOGLE_SHEET_NAME = "known_samples"
138
+ # USAGE_DRIVE_FILENAME = "user_usage_log.json"
139
+
140
+ def summarize_results(accession, stop_flag=None):
141
+ # Early bail
142
+ if stop_flag is not None and stop_flag.value:
143
+         print(f"🛑 Skipping {accession} before starting.")
144
+ return []
145
+ # try cache first
146
+ cached = check_known_output(accession)
147
+ if cached:
148
+         print(f"✅ Using cached result for {accession}")
149
+ return [[
150
+ cached["Sample ID"] or "unknown",
151
+ cached["Predicted Country"] or "unknown",
152
+ cached["Country Explanation"] or "unknown",
153
+ cached["Predicted Sample Type"] or "unknown",
154
+ cached["Sample Type Explanation"] or "unknown",
155
+ cached["Sources"] or "No Links",
156
+ cached["Time cost"]
157
+ ]]
158
+ # only run when nothing in the cache
159
+ try:
160
+ print("try gemini pipeline: ",accession)
161
+ outputs = pipeline_classify_sample_location_cached(accession, stop_flag)
162
+ if stop_flag is not None and stop_flag.value:
163
+             print(f"🛑 Skipped {accession} mid-pipeline.")
164
+ return []
165
+ # outputs = {'KU131308': {'isolate':'BRU18',
166
+ # 'country': {'brunei': ['ncbi',
167
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
168
+ # 'sample_type': {'modern':
169
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
170
+ # 'query_cost': 9.754999999999999e-05,
171
+ # 'time_cost': '24.776 seconds',
172
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
173
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
174
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
175
+ except Exception as e:
176
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
177
+
178
+ if accession not in outputs:
179
+ print("no accession in output ", accession)
180
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
181
+
182
+ row_score = []
183
+ rows = []
184
+ save_rows = []
185
+ for key in outputs:
186
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
187
+ for section, results in outputs[key].items():
188
+ if section == "country" or section =="sample_type":
189
+ pred_output = []#"\n".join(list(results.keys()))
190
+ output_explanation = ""
191
+ for result, content in results.items():
192
+ if len(result) == 0: result = "unknown"
193
+ if len(content) == 0: output_explanation = "unknown"
194
+ else:
195
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
196
+ pred_output.append(result)
197
+ pred_output = "\n".join(pred_output)
198
+ if section == "country":
199
+ pred_country, country_explanation = pred_output, output_explanation
200
+ elif section == "sample_type":
201
+ pred_sample, sample_explanation = pred_output, output_explanation
202
+ if outputs[key]["isolate"].lower()!="unknown":
203
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
204
+ else: label = key
205
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
206
+ row = {
207
+ "Sample ID": label or "unknown",
208
+ "Predicted Country": pred_country or "unknown",
209
+ "Country Explanation": country_explanation or "unknown",
210
+ "Predicted Sample Type":pred_sample or "unknown",
211
+ "Sample Type Explanation":sample_explanation or "unknown",
212
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
213
+ "Time cost": outputs[key]["time_cost"]
214
+ }
215
+ #row_score.append(row)
216
+ rows.append(list(row.values()))
217
+
218
+ save_row = {
219
+ "Sample ID": label or "unknown",
220
+ "Predicted Country": pred_country or "unknown",
221
+ "Country Explanation": country_explanation or "unknown",
222
+ "Predicted Sample Type":pred_sample or "unknown",
223
+ "Sample Type Explanation":sample_explanation or "unknown",
224
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
225
+ "Query_cost": outputs[key]["query_cost"],
226
+ "Time cost": outputs[key]["time_cost"]
227
+ }
228
+ #row_score.append(row)
229
+ save_rows.append(list(save_row.values()))
230
+
231
+ # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
232
+ # summary_lines = [f"### 🧭 Location Summary:\n"]
233
+ # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
234
+     # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
235
+ # summary = "\n".join(summary_lines)
236
+
237
+ # save the new running sample to known excel file
238
+ # try:
239
+ # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
240
+ # if os.path.exists(KNOWN_OUTPUT_PATH):
241
+ # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
242
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
243
+ # else:
244
+ # df_combined = df_new
245
+ # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
246
+ # except Exception as e:
247
+ # print(f"⚠️ Failed to save known output: {e}")
248
+ # try:
249
+ # df_new = pd.DataFrame(save_rows, columns=[
250
+ # "Sample ID", "Predicted Country", "Country Explanation",
251
+ # "Predicted Sample Type", "Sample Type Explanation",
252
+ # "Sources", "Query_cost", "Time cost"
253
+ # ])
254
+
255
+     # # ✅ Google Sheets API setup
256
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
257
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
258
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
259
+ # client = gspread.authorize(creds)
260
+
261
+     # # ✅ Open the known_samples sheet
262
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
263
+ # sheet = spreadsheet.sheet1
264
+
265
+     # # ✅ Read old data
266
+ # existing_data = sheet.get_all_values()
267
+ # if existing_data:
268
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
269
+ # else:
270
+ # df_old = pd.DataFrame(columns=df_new.columns)
271
+
272
+     # # ✅ Combine and remove duplicates
273
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
274
+
275
+     # # ✅ Clear and write back
276
+ # sheet.clear()
277
+ # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
278
+
279
+ # except Exception as e:
280
+ # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
281
+ try:
282
+ # Prepare as DataFrame
283
+ df_new = pd.DataFrame(save_rows, columns=[
284
+ "Sample ID", "Predicted Country", "Country Explanation",
285
+ "Predicted Sample Type", "Sample Type Explanation",
286
+ "Sources", "Query_cost", "Time cost"
287
+ ])
288
+
289
+         # ✅ Setup Google Sheets
290
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
291
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
292
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
293
+ client = gspread.authorize(creds)
294
+ spreadsheet = client.open("known_samples")
295
+ sheet = spreadsheet.sheet1
296
+
297
+         # ✅ Read existing data
298
+ existing_data = sheet.get_all_values()
299
+ if existing_data:
300
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
301
+ else:
302
+ df_old = pd.DataFrame(columns=[
303
+ "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
304
+ "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
305
+ "Query_cost", "Sample Type Explanation", "Sources", "Time cost"
306
+ ])
307
+
308
+         # ✅ Index by Sample ID
309
+ df_old.set_index("Sample ID", inplace=True)
310
+ df_new.set_index("Sample ID", inplace=True)
311
+
312
+         # ✅ Update only matching fields
313
+ update_columns = [
314
+ "Predicted Country", "Predicted Sample Type", "Country Explanation",
315
+ "Sample Type Explanation", "Sources", "Query_cost", "Time cost"
316
+ ]
317
+ for idx, row in df_new.iterrows():
318
+ if idx not in df_old.index:
319
+ df_old.loc[idx] = "" # new row, fill empty first
320
+ for col in update_columns:
321
+ if pd.notna(row[col]) and row[col] != "":
322
+ df_old.at[idx, col] = row[col]
323
+
324
+         # ✅ Reset and write back
325
+ df_old.reset_index(inplace=True)
326
+ sheet.clear()
327
+ sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
328
+         print("✅ Match results saved to known_samples.")
329
+
330
+ except Exception as e:
331
+ print(f"❌ Failed to update known_samples: {e}")
332
+
333
+
334
+ return rows#, summary, labelAncient_Modern, explain_label
335
+
336
+ # save the batch input in excel file
337
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
338
+ # with pd.ExcelWriter(filename) as writer:
339
+ # # Save table
340
+ # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
341
+ # df.to_excel(writer, sheet_name="Detailed Results", index=False)
342
+ # try:
343
+ # df_old = pd.read_excel(filename)
344
+ # except:
345
+ # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
346
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
347
+ # # if os.path.exists(filename):
348
+ # # df_old = pd.read_excel(filename)
349
+ # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
350
+ # # else:
351
+ # # df_combined = df_new
352
+ # df_combined.to_excel(filename, index=False)
353
+ # # # Save summary
354
+ # # summary_df = pd.DataFrame({"Summary": [summary_text]})
355
+ # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
356
+
357
+ # # # Save flag
358
+ # # flag_df = pd.DataFrame({"Flag": [flag_text]})
359
+ # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
360
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
361
+ # df_new = pd.DataFrame(all_rows, columns=[
362
+ # "Sample ID", "Predicted Country", "Country Explanation",
363
+ # "Predicted Sample Type", "Sample Type Explanation",
364
+ # "Sources", "Time cost"
365
+ # ])
366
+
367
+ # try:
368
+ # if os.path.exists(filename):
369
+ # df_old = pd.read_excel(filename)
370
+ # else:
371
+ # df_old = pd.DataFrame(columns=df_new.columns)
372
+ # except Exception as e:
373
+ # print(f"⚠️ Warning reading old Excel file: {e}")
374
+ # df_old = pd.DataFrame(columns=df_new.columns)
375
+
376
+ # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
377
+ # df_old.set_index("Sample ID", inplace=True)
378
+ # df_new.set_index("Sample ID", inplace=True)
379
+
380
+ # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
381
+
382
+ # df_combined = df_old.reset_index()
383
+
384
+ # try:
385
+ # df_combined.to_excel(filename, index=False)
386
+ # except Exception as e:
387
+ # print(f"❌ Failed to write Excel file {filename}: {e}")
388
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
389
+ df_new = pd.DataFrame(all_rows, columns=[
390
+ "Sample ID", "Predicted Country", "Country Explanation",
391
+ "Predicted Sample Type", "Sample Type Explanation",
392
+ "Sources", "Time cost"
393
+ ])
394
+
395
+ if is_resume and os.path.exists(filename):
396
+ try:
397
+ df_old = pd.read_excel(filename)
398
+ except Exception as e:
399
+ print(f"⚠️ Warning reading old Excel file: {e}")
400
+ df_old = pd.DataFrame(columns=df_new.columns)
401
+
402
+ # Set index and update existing rows
403
+ df_old.set_index("Sample ID", inplace=True)
404
+ df_new.set_index("Sample ID", inplace=True)
405
+ df_old.update(df_new)
406
+
407
+ df_combined = df_old.reset_index()
408
+ else:
409
+ # If not resuming or file doesn't exist, just use new rows
410
+ df_combined = df_new
411
+
412
+ try:
413
+ df_combined.to_excel(filename, index=False)
414
+ except Exception as e:
415
+ print(f"❌ Failed to write Excel file {filename}: {e}")
416
+
417
+
418
+ # save the batch input in JSON file
419
+ def save_to_json(all_rows, summary_text, flag_text, filename):
420
+ output_dict = {
421
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
422
+ # "Summary_Text": summary_text,
423
+ # "Ancient_Modern_Flag": flag_text
424
+ }
425
+
426
+ # If all_rows is a DataFrame, convert it
427
+ if isinstance(all_rows, pd.DataFrame):
428
+ output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
429
+
430
+ with open(filename, "w") as external_file:
431
+ json.dump(output_dict, external_file, indent=2)
432
+
433
+ # save the batch input in Text file
434
+ def save_to_txt(all_rows, summary_text, flag_text, filename):
435
+ if isinstance(all_rows, pd.DataFrame):
436
+ detailed_results = all_rows.to_dict(orient="records")
437
+ output = ""
438
+ #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
439
+ output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
440
+ for r in detailed_results:
441
+ output += ",".join([str(v) for v in r.values()]) + "\n\n"
442
+ with open(filename, "w") as f:
443
+ f.write("=== Detailed Results ===\n")
444
+ f.write(output + "\n")
445
+
446
+ # f.write("\n=== Summary ===\n")
447
+ # f.write(summary_text + "\n")
448
+
449
+ # f.write("\n=== Ancient/Modern Flag ===\n")
450
+ # f.write(flag_text + "\n")
451
+
452
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
453
+ tmp_dir = tempfile.mkdtemp()
454
+
455
+ #html_table = all_rows.value # assuming this is stored somewhere
456
+
457
+ # Parse back to DataFrame
458
+ #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
459
+ all_rows = pd.read_html(StringIO(all_rows))[0]
460
+ print(all_rows)
461
+
462
+ if output_type == "Excel":
463
+ file_path = f"{tmp_dir}/batch_output.xlsx"
464
+ save_to_excel(all_rows, summary_text, flag_text, file_path)
465
+ elif output_type == "JSON":
466
+ file_path = f"{tmp_dir}/batch_output.json"
467
+ save_to_json(all_rows, summary_text, flag_text, file_path)
468
+ print("Done with JSON")
469
+ elif output_type == "TXT":
470
+ file_path = f"{tmp_dir}/batch_output.txt"
471
+ save_to_txt(all_rows, summary_text, flag_text, file_path)
472
+ else:
473
+ return gr.update(visible=False) # invalid option
474
+
475
+ return gr.update(value=file_path, visible=True)
476
+ # save cost by checking the known outputs
+
+ # def check_known_output(accession):
+ #     if not os.path.exists(KNOWN_OUTPUT_PATH):
+ #         return None
+
+ #     try:
+ #         df = pd.read_excel(KNOWN_OUTPUT_PATH)
+ #         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
+ #         if match:
+ #             accession = match.group(0)
+
+ #         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
+ #         if not matched.empty:
+ #             return matched.iloc[0].to_dict()  # Return the cached row
+ #     except Exception as e:
+ #         print(f"⚠️ Failed to load known samples: {e}")
+ #     return None
+
+ # def check_known_output(accession):
+ #     try:
+ #         # ✅ Load credentials from Hugging Face secret
+ #         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+ #         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+ #         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+ #         client = gspread.authorize(creds)
+
+ #         # ✅ Open the known_samples sheet
+ #         spreadsheet = client.open("known_samples")  # Replace with your sheet name
+ #         sheet = spreadsheet.sheet1
+
+ #         # ✅ Read all rows
+ #         data = sheet.get_all_values()
+ #         if not data:
+ #             return None
+
+ #         df = pd.DataFrame(data[1:], columns=data[0])  # Skip header row
+
+ #         # ✅ Normalize accession pattern
+ #         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
+ #         if match:
+ #             accession = match.group(0)
+
+ #         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
+ #         if not matched.empty:
+ #             return matched.iloc[0].to_dict()
+
+ #     except Exception as e:
+ #         print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
+ #     return None
526
+ def check_known_output(accession):
+     try:
+         # ✅ Load credentials from Hugging Face secret
+         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+         client = gspread.authorize(creds)
+
+         spreadsheet = client.open("known_samples")
+         sheet = spreadsheet.sheet1
+
+         data = sheet.get_all_values()
+         if not data:
+             print("⚠️ Google Sheet 'known_samples' is empty.")
+             return None
+
+         df = pd.DataFrame(data[1:], columns=data[0])
+         if "Sample ID" not in df.columns:
+             print("❌ Column 'Sample ID' not found in Google Sheet.")
+             return None
+
+         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
+         if match:
+             accession = match.group(0)
+
+         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
+         if not matched.empty:
+             # return matched.iloc[0].to_dict()
+             row = matched.iloc[0]
+             country = row.get("Predicted Country", "").strip().lower()
+             sample_type = row.get("Predicted Sample Type", "").strip().lower()
+
+             if country and country != "unknown" and sample_type and sample_type != "unknown":
+                 return row.to_dict()
+             else:
+                 print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
+                 return None
+         else:
+             print(f"🔍 Accession {accession} not found in known_samples.")
+             return None
+
+     except Exception as e:
+         import traceback
+         print("❌ Exception occurred during check_known_output:")
+         traceback.print_exc()
+         return None
+
+
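The regex normalization above strips version suffixes and surrounding text before the lookup against the sheet's "Sample ID" column, e.g.:

import re
m = re.search(r"\b[A-Z]{2,4}\d{4,}", "KX289214.1 Homo sapiens isolate mtDNA")
print(m.group(0) if m else None)  # -> "KX289214"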
574
+ def hash_user_id(user_input):
+     return hashlib.sha256(user_input.encode()).hexdigest()
+
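summarize_batch keys usage counts by this SHA-256 digest rather than the raw address; a quick check of the output shape (the address is hypothetical):

digest = hash_user_id("user@example.com")
print(len(digest), digest[:8])  # 64 hex characters; a stable key for the usage log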
577
+ # ✅ Load and save usage count
+
579
+ # def load_user_usage():
+ #     if not os.path.exists(USER_USAGE_TRACK_FILE):
+ #         return {}
+
+ #     try:
+ #         with open(USER_USAGE_TRACK_FILE, "r") as f:
+ #             content = f.read().strip()
+ #             if not content:
+ #                 return {}  # file is empty
+ #             return json.loads(content)
+ #     except (json.JSONDecodeError, ValueError):
+ #         print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
+ #         return {}  # fallback to empty dict
+ # def load_user_usage():
+ #     try:
+ #         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+ #         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+ #         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+ #         client = gspread.authorize(creds)
+
+ #         sheet = client.open("user_usage_log").sheet1
+ #         data = sheet.get_all_records()  # Assumes columns: email, usage_count
+
+ #         usage = {}
+ #         for row in data:
+ #             email = row.get("email", "").strip().lower()
+ #             count = int(row.get("usage_count", 0))
+ #             if email:
+ #                 usage[email] = count
+ #         return usage
+ #     except Exception as e:
+ #         print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
+ #         return {}
+ # def load_user_usage():
+ #     try:
+ #         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
+ #         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
+
+ #         found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
+ #         if not found:
+ #             return {}  # not found, start fresh
+
+ #         # file_id = found[0]["id"]
+ #         file_id = found
+ #         content = pipeline.download_drive_file_content(file_id)
+ #         return json.loads(content.strip()) if content.strip() else {}
+
+ #     except Exception as e:
+ #         print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
+ #         return {}
629
+ def load_user_usage():
+     try:
+         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+         client = gspread.authorize(creds)
+
+         sheet = client.open("user_usage_log").sheet1
+         data = sheet.get_all_values()
+         print("data: ", data)
+         print("🧪 Raw header row from sheet:", data[0])
+         print("🧪 Character codes in each header:")
+         for h in data[0]:
+             print([ord(c) for c in h])
+
+         if not data or len(data) < 2:
+             print("⚠️ Sheet is empty or missing rows.")
+             return {}, {}  # always return (usage, permitted) so callers can unpack
+
+         headers = [h.strip().lower() for h in data[0]]
+         if "email" not in headers or "usage_count" not in headers:
+             print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
+             return {}, {}
+
+         permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
+         df = pd.DataFrame(data[1:], columns=headers)
+
+         usage = {}
+         permitted = {}
+         for _, row in df.iterrows():
+             email = row.get("email", "").strip().lower()
+             try:
+                 # count = int(row.get("usage_count", 0))
+                 try:
+                     count = int(float(row.get("usage_count", 0)))
+                 except Exception:
+                     print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
+                     count = 0
+
+                 if email:
+                     usage[email] = count
+                     if permitted_index is not None:
+                         try:
+                             permitted_count = int(float(row.get("permitted_samples", 50)))
+                             permitted[email] = permitted_count
+                         except:
+                             permitted[email] = 50
+
+             except ValueError:
+                 print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
+         return usage, permitted
+
+     except Exception as e:
+         print(f"❌ Error in load_user_usage: {e}")
+         return {}, {}
+
+
+
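The loader expects a header row containing at least "email" and "usage_count" ("permitted_samples" is optional); a self-contained sketch of the same normalization applied to an in-memory row with invented values:

import pandas as pd
data = [["Email ", "usage_count", "permitted_samples"],
        ["A@Example.com", "12.0", "50"]]
headers = [h.strip().lower() for h in data[0]]  # -> ["email", "usage_count", "permitted_samples"]
df = pd.DataFrame(data[1:], columns=headers)
row = df.iloc[0]
print(row["email"].strip().lower(), int(float(row["usage_count"])))  # -> a@example.com 12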
687
+ # def save_user_usage(usage):
+ #     with open(USER_USAGE_TRACK_FILE, "w") as f:
+ #         json.dump(usage, f, indent=2)
+
+ # def save_user_usage(usage_dict):
+ #     try:
+ #         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+ #         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+ #         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+ #         client = gspread.authorize(creds)
+
+ #         sheet = client.open("user_usage_log").sheet1
+ #         sheet.clear()  # clear old contents first
+
+ #         # Write header + rows
+ #         rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
+ #         sheet.update(rows)
+ #     except Exception as e:
+ #         print(f"❌ Failed to save user usage to Google Sheets: {e}")
+ # def save_user_usage(usage_dict):
+ #     try:
+ #         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
+ #         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
+
+ #         import tempfile
+ #         tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
+ #         print("💾 Saving this usage dict:", usage_dict)
+ #         with open(tmp_path, "w") as f:
+ #             json.dump(usage_dict, f, indent=2)
+
+ #         pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
+
+ #     except Exception as e:
+ #         print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
+ # def save_user_usage(usage_dict):
+ #     try:
+ #         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+ #         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+ #         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+ #         client = gspread.authorize(creds)
+
+ #         spreadsheet = client.open("user_usage_log")
+ #         sheet = spreadsheet.sheet1
+
+ #         # Step 1: Convert new usage to DataFrame
+ #         df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
+ #         df_new["email"] = df_new["email"].str.strip().str.lower()
+
+ #         # Step 2: Load existing data
+ #         existing_data = sheet.get_all_values()
+ #         print("🧪 Sheet existing_data:", existing_data)
+
+ #         # Try to load old data
+ #         if existing_data and len(existing_data[0]) >= 1:
+ #             df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
+
+ #             # Fix missing columns
+ #             if "email" not in df_old.columns:
+ #                 df_old["email"] = ""
+ #             if "usage_count" not in df_old.columns:
+ #                 df_old["usage_count"] = 0
+
+ #             df_old["email"] = df_old["email"].str.strip().str.lower()
+ #             df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
+ #         else:
+ #             df_old = pd.DataFrame(columns=["email", "usage_count"])
+
+ #         # Step 3: Merge
+ #         df_combined = pd.concat([df_old, df_new], ignore_index=True)
+ #         df_combined = df_combined.groupby("email", as_index=False).sum()
+
+ #         # Step 4: Write back
+ #         sheet.clear()
+ #         sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
+ #         print("✅ Saved user usage to user_usage_log sheet.")
+
+ #     except Exception as e:
+ #         print(f"❌ Failed to save user usage to Google Sheets: {e}")
765
+ def save_user_usage(usage_dict):
+     try:
+         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
+         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
+         client = gspread.authorize(creds)
+
+         spreadsheet = client.open("user_usage_log")
+         sheet = spreadsheet.sheet1
+
+         # Build new df
+         df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
+         df_new["email"] = df_new["email"].str.strip().str.lower()
+         df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
+
+         # Read existing data
+         existing_data = sheet.get_all_values()
+         if existing_data and len(existing_data[0]) >= 2:
+             df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
+             df_old["email"] = df_old["email"].str.strip().str.lower()
+             df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
+         else:
+             df_old = pd.DataFrame(columns=["email", "usage_count"])
+
+         # ✅ Overwrite specific emails only
+         df_old = df_old.set_index("email")
+         for email, count in usage_dict.items():
+             email = email.strip().lower()
+             df_old.loc[email, "usage_count"] = count
+         df_old = df_old.reset_index()
+
+         # Save
+         sheet.clear()
+         sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
+         print("✅ Saved user usage to user_usage_log sheet.")
+
+     except Exception as e:
+         print(f"❌ Failed to save user usage to Google Sheets: {e}")
+
+
807
+ # def increment_usage(user_id, num_samples=1):
+ #     usage = load_user_usage()
+ #     if user_id not in usage:
+ #         usage[user_id] = 0
+ #     usage[user_id] += num_samples
+ #     save_user_usage(usage)
+ #     return usage[user_id]
+ # def increment_usage(email: str, count: int):
+ #     usage = load_user_usage()
+ #     email_key = email.strip().lower()
+ #     usage[email_key] = usage.get(email_key, 0) + count
+ #     save_user_usage(usage)
+ #     return usage[email_key]
820
+ def increment_usage(email: str, count: int = 1):
+     usage, permitted = load_user_usage()
+     email_key = email.strip().lower()
+     # usage[email_key] = usage.get(email_key, 0) + count
+     current = usage.get(email_key, 0)
+     new_value = current + count
+     max_allowed = permitted.get(email_key) or 50
+     usage[email_key] = max(current, new_value)  # ✅ Prevent overwrite with a lower value
+     print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
+     print("max allowed is: ", max_allowed)
+     save_user_usage(usage)
+     return usage[email_key], max_allowed
+
+
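A hedged usage sketch: this call reads and rewrites the live "user_usage_log" sheet, so GCP_CREDS_JSON must be set; the address is hypothetical:

used, allowed = increment_usage("user@example.com", count=3)
if used >= allowed:
    print("Quota reached; stop before submitting more accessions.")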
834
+ # run the batch
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
+                     stop_flag=None, output_file_path=None,
+                     limited_acc=50, yield_callback=None):
+     if user_email:
+         limited_acc += 10
+     accessions, error = extract_accessions_from_input(file, raw_text)
+     if error:
+         # return [], "", "", f"Error: {error}"
+         return [], f"Error: {error}", 0, "", ""
+     if resume_file:
+         accessions = get_incomplete_accessions(resume_file)
+     tmp_dir = tempfile.mkdtemp()
+     if not output_file_path:
+         if resume_file:
+             output_file_path = os.path.join(tmp_dir, resume_file)
+         else:
+             output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
+
+     all_rows = []
+     # all_summaries = []
+     # all_flags = []
+     progress_lines = []
+     warning = ""
858
+     if len(accessions) > limited_acc:
+         accessions = accessions[:limited_acc]
+         warning = f"You supplied more than {limited_acc} accessions; only the first {limited_acc} will be processed."
+     for i, acc in enumerate(accessions):
+         if stop_flag and stop_flag.value:
+             line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
+             progress_lines.append(line)
+             if yield_callback:
+                 yield_callback(line)
+             print("🛑 User requested stop.")
+             break
+         print(f"[{i+1}/{len(accessions)}] Processing {acc}")
+         try:
+             # rows, summary, label, explain = summarize_results(acc)
+             rows = summarize_results(acc)
+             all_rows.extend(rows)
+             # all_summaries.append(f"**{acc}**\n{summary}")
+             # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
+             # save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
+             save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
+             line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
+             progress_lines.append(line)
+             if yield_callback:
+                 yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
+         except Exception as e:
+             print(f"❌ Failed to process {acc}: {e}")
+             continue
+             # all_summaries.append(f"**{acc}**: Failed - {e}")
+             # progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
+         limited_acc -= 1
+     """for row in all_rows:
+         source_column = row[2]  # Assuming the "Source" is in the 3rd column (index 2)
+
+         if source_column.startswith("http"):  # Check if the source is a URL
+             # Wrap it with HTML anchor tags to make it clickable
+             row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
+     if not warning:
+         warning = f"You only have {limited_acc} accessions left."
+     if user_email.strip():
+         user_hash = hash_user_id(user_email)
+         # increment_usage returns (usage_count, max_allowed); only the count is reported here
+         total_queries, _ = increment_usage(user_hash, len(all_rows))
+     else:
+         total_queries = 0
+     if yield_callback:
+         yield_callback("✅ Finished!")
+
+     # summary_text = "\n\n---\n\n".join(all_summaries)
+     # flag_text = "\n\n---\n\n".join(all_flags)
+     # return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
+     # return all_rows, gr.update(visible=True), gr.update(visible=False)
      return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
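A hedged calling sketch for the batch runner: it drives the full classification pipeline, so it needs network access and the configured credentials; the accession and address are placeholders:

class StopFlag:            # any object exposing a boolean .value works
    value = False

live_updates = []
rows, out_path, used, progress_log, warning = summarize_batch(
    raw_text="KX289214",            # placeholder accession
    user_email="user@example.com",  # placeholder address
    stop_flag=StopFlag(),
    yield_callback=live_updates.append,
)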