import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
#import app
from mtdna_classifier import classify_sample_location
import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading

# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
#     return classify_sample_location(accession)

#@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession, stop_flag=None, save_df=None):
    print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
    print("len of save df: ", len(save_df))
    return pipeline.pipeline_with_gemini([accession], stop_flag=stop_flag, save_df=save_df)

# Count and suggest final location
# def compute_final_suggested_location(rows):
#     candidates = [
#         row.get("Predicted Location", "").strip()
#         for row in rows
#         if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ] + [
#         row.get("Inferred Region", "").strip()
#         for row in rows
#         if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ]
#     if not candidates:
#         return Counter(), ("Unknown", 0)
#     # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
#     tokens = []
#     for item in candidates:
#         # Split by comma, whitespace, and newlines
#         parts = re.split(r'[\s,]+', item)
#         tokens.extend(parts)
#     # Step 2: Clean and normalize tokens
#     tokens = [word.strip() for word in tokens if word.strip().isalpha()]  # Keep only alphabetic tokens
#     # Step 3: Count
#     counts = Counter(tokens)
#     # Step 4: Get most common
#     top_location, count = counts.most_common(1)[0]
#     return counts, (top_location, count)

# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
    if not answer1.strip() or not answer2.strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # ✅ Step: Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        return f"❌ Error submitting feedback: {e}"
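# The Google Sheets client setup above is repeated in several functions in this
# file (feedback logging, the known_samples cache, usage tracking). A minimal
# helper like this sketch could centralize it; `_get_gspread_client` is only a
# suggested name and is not currently called by the rest of the module.
def _get_gspread_client():
    """Build an authorized gspread client from the GCP_CREDS_JSON secret."""
    creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
    scope = ["https://spreadsheets.google.com/feeds",
             "https://www.googleapis.com/auth/drive"]
    creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
    return gspread.authorize(creds)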
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
    print(f"RAW TEXT RECEIVED: {raw_text}")
    accessions = []
    seen = set()
    if file:
        try:
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            elif file.name.endswith(".xlsx"):
                df = pd.read_excel(file)
            else:
                return [], "Unsupported file format. Please upload CSV or Excel."
            for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
                if acc not in seen:
                    accessions.append(acc)
                    seen.add(acc)
        except Exception as e:
            return [], f"Failed to read file: {e}"
    if raw_text:
        text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
        for acc in text_ids:
            if acc not in seen:
                accessions.append(acc)
                seen.add(acc)
    return list(accessions), None

# ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
def get_incomplete_accessions(file_path):
    df = pd.read_excel(file_path)
    incomplete_accessions = []
    for _, row in df.iterrows():
        sample_id = str(row.get("Sample ID", "")).strip()
        # Skip if no sample ID
        if not sample_id:
            continue
        # Drop the Sample ID and check if the rest is empty
        other_cols = row.drop(labels=["Sample ID"], errors="ignore")
        if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
            # Extract the accession number from the sample ID using regex
            match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
            if match:
                incomplete_accessions.append(match.group(0))
    print(len(incomplete_accessions))
    return incomplete_accessions

# GOOGLE_SHEET_NAME = "known_samples"
# USAGE_DRIVE_FILENAME = "user_usage_log.json"
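# summarize_results and summarize_batch below only assume that `stop_flag`
# exposes a mutable boolean `.value` attribute. The Gradio app presumably
# supplies its own flag object; this tiny stand-in is just a sketch of that
# contract and is not referenced anywhere else in this file.
class SimpleStopFlag:
    """Minimal stop flag with the `.value` attribute the batch loop checks."""

    def __init__(self):
        self.value = False

    def stop(self):
        self.value = True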
def summarize_results(accession, stop_flag=None):
    # Early bail
    if stop_flag is not None and stop_flag.value:
        print(f"🛑 Skipping {accession} before starting.")
        return []
    # try cache first
    cached = check_known_output(accession)
    if cached:
        print(f"✅ Using cached result for {accession}")
        return [[
            cached["Sample ID"] or "unknown",
            cached["Predicted Country"] or "unknown",
            cached["Country Explanation"] or "unknown",
            cached["Predicted Sample Type"] or "unknown",
            cached["Sample Type Explanation"] or "unknown",
            cached["Sources"] or "No Links",
            cached["Time cost"]
        ]]
    # only run when nothing is in the cache
    try:
        print("try gemini pipeline: ", accession)
        # ✅ Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        data = sheet.get_all_values()
        if not data:
            print("⚠️ Google Sheet 'known_samples' is empty.")
            return []
        save_df = pd.DataFrame(data[1:], columns=data[0])
        print("before pipeline, len of save df: ", len(save_df))
        outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
        if stop_flag is not None and stop_flag.value:
            print(f"🛑 Skipped {accession} mid-pipeline.")
            return []
        # Example of the structure returned by the pipeline:
        # outputs = {'KU131308': {'isolate': 'BRU18',
        #     'country': {'brunei': ['ncbi',
        #         'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'sample_type': {'modern':
        #         ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'query_cost': 9.754999999999999e-05,
        #     'time_cost': '24.776 seconds',
        #     'source': ['https://doi.org/10.1007/s00439-015-1620-z',
        #                'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
        #                'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
    except Exception as e:
        return []  # , f"Error: {e}", f"Error: {e}", f"Error: {e}"
    if accession not in outputs:
        print("no accession in output ", accession)
        return []  # , "Accession not found in results.", "Accession not found in results.", "Accession not found in results."

    row_score = []
    rows = []
    save_rows = []
    for key in outputs:
        pred_country, pred_sample, country_explanation, sample_explanation = "unknown", "unknown", "unknown", "unknown"
        for section, results in outputs[key].items():
            if section == "country" or section == "sample_type":
                pred_output = []  # "\n".join(list(results.keys()))
                output_explanation = ""
                for result, content in results.items():
                    if len(result) == 0:
                        result = "unknown"
                    if len(content) == 0:
                        output_explanation = "unknown"
                    else:
                        output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
                    pred_output.append(result)
                pred_output = "\n".join(pred_output)
                if section == "country":
                    pred_country, country_explanation = pred_output, output_explanation
                elif section == "sample_type":
                    pred_sample, sample_explanation = pred_output, output_explanation
        if outputs[key]["isolate"].lower() != "unknown":
            label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
        else:
            label = key
        if len(outputs[key]["source"]) == 0:
            outputs[key]["source"] = ["No Links"]
        row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Time cost": outputs[key]["time_cost"]
        }
        # row_score.append(row)
        rows.append(list(row.values()))
        save_row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Query_cost": outputs[key]["query_cost"] or "",
            "Time cost": outputs[key]["time_cost"] or "",
            "file_chunk": outputs[key]["file_chunk"] or "",
            "file_all_output": outputs[key]["file_all_output"] or ""
        }
        # row_score.append(row)
        save_rows.append(list(save_row.values()))

    # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
    # summary_lines = [f"### 🧭 Location Summary:\n"]
    # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    # summary = "\n".join(summary_lines)

    # save the new running sample to known excel file
    # try:
    #     df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost", "Time cost"])
    #     if os.path.exists(KNOWN_OUTPUT_PATH):
    #         df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
    #         df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
    #     else:
    #         df_combined = df_new
    #     df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output: {e}")

    # try:
    #     df_new = pd.DataFrame(save_rows, columns=[
    #         "Sample ID", "Predicted Country", "Country Explanation",
    #         "Predicted Sample Type", "Sample Type Explanation",
    #         "Sources", "Query_cost", "Time cost"
    #     ])
    #     # ✅ Google Sheets API setup
    #     creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
    #     scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
    #     creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
    #     client = gspread.authorize(creds)
    #     # ✅ Open the known_samples sheet
    #     spreadsheet = client.open("known_samples")  # Replace with your sheet name
    #     sheet = spreadsheet.sheet1
    #     # ✅ Read old data
    #     existing_data = sheet.get_all_values()
    #     if existing_data:
    #         df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
    #     else:
    #         df_old = pd.DataFrame(columns=df_new.columns)
    #     # ✅ Combine and remove duplicates
    #     df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
    #     # ✅ Clear and write back
    #     sheet.clear()
    #     sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output to Google Sheets: {e}")

    try:
        # Prepare as DataFrame
        df_new = pd.DataFrame(save_rows, columns=[
            "Sample ID", "Predicted Country", "Country Explanation",
            "Predicted Sample Type", "Sample Type Explanation",
            "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
        ])
        # ✅ Setup Google Sheets
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        # ✅ Read existing data
        existing_data = sheet.get_all_values()
        if existing_data:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
        else:
            df_old = pd.DataFrame(columns=[
                "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
                "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
                "Query_cost", "Sample Type Explanation", "Sources", "Time cost",
                "file_chunk", "file_all_output"
            ])
        # ✅ Index by Sample ID
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        # ✅ Update only matching fields
        update_columns = [
            "Predicted Country", "Predicted Sample Type", "Country Explanation",
            "Sample Type Explanation", "Sources", "Query_cost", "Time cost",
            "file_chunk", "file_all_output"
        ]
        for idx, row in df_new.iterrows():
            if idx not in df_old.index:
                df_old.loc[idx] = ""  # new row, fill empty first
            for col in update_columns:
                if pd.notna(row[col]) and row[col] != "":
                    df_old.at[idx, col] = row[col]
        # ✅ Reset and write back
        df_old.reset_index(inplace=True)
        sheet.clear()
        sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
        print("✅ Match results saved to known_samples.")
    except Exception as e:
        print(f"❌ Failed to update known_samples: {e}")

    return rows  # , summary, labelAncient_Modern, explain_label
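# For reference, this is the flat row layout that summarize_results returns and
# that save_to_excel below expects. The list simply mirrors the dict keys used
# above and is included only as documentation; nothing else in this file reads it.
RESULT_ROW_COLUMNS = [
    "Sample ID", "Predicted Country", "Country Explanation",
    "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost",
]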
# save the batch input in excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     with pd.ExcelWriter(filename) as writer:
#         # Save table
#         df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df.to_excel(writer, sheet_name="Detailed Results", index=False)
#         try:
#             df_old = pd.read_excel(filename)
#         except:
#             df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#         # if os.path.exists(filename):
#         #     df_old = pd.read_excel(filename)
#         #     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#         # else:
#         #     df_combined = df_new
#         df_combined.to_excel(filename, index=False)
#         # # Save summary
#         # summary_df = pd.DataFrame({"Summary": [summary_text]})
#         # summary_df.to_excel(writer, sheet_name="Summary", index=False)
#         # # Save flag
#         # flag_df = pd.DataFrame({"Flag": [flag_text]})
#         # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)

# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     df_new = pd.DataFrame(all_rows, columns=[
#         "Sample ID", "Predicted Country", "Country Explanation",
#         "Predicted Sample Type", "Sample Type Explanation",
#         "Sources", "Time cost"
#     ])
#     try:
#         if os.path.exists(filename):
#             df_old = pd.read_excel(filename)
#         else:
#             df_old = pd.DataFrame(columns=df_new.columns)
#     except Exception as e:
#         print(f"⚠️ Warning reading old Excel file: {e}")
#         df_old = pd.DataFrame(columns=df_new.columns)
#     #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
#     df_old.set_index("Sample ID", inplace=True)
#     df_new.set_index("Sample ID", inplace=True)
#     df_old.update(df_new)  # <-- update matching rows in df_old with df_new content
#     df_combined = df_old.reset_index()
#     try:
#         df_combined.to_excel(filename, index=False)
#     except Exception as e:
#         print(f"❌ Failed to write Excel file {filename}: {e}")

def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
    df_new = pd.DataFrame(all_rows, columns=[
        "Sample ID", "Predicted Country", "Country Explanation",
        "Predicted Sample Type", "Sample Type Explanation",
        "Sources", "Time cost"
    ])
    if is_resume and os.path.exists(filename):
        try:
            df_old = pd.read_excel(filename)
        except Exception as e:
            print(f"⚠️ Warning reading old Excel file: {e}")
            df_old = pd.DataFrame(columns=df_new.columns)
        # Set index and update existing rows
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        df_old.update(df_new)
        df_combined = df_old.reset_index()
    else:
        # If not resuming or the file doesn't exist, just use the new rows
        df_combined = df_new
    try:
        df_combined.to_excel(filename, index=False)
    except Exception as e:
        print(f"❌ Failed to write Excel file {filename}: {e}")

# save the batch input in JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
    output_dict = {
        "Detailed_Results": all_rows  # <-- make sure this is a plain list, not a DataFrame
        # "Summary_Text": summary_text,
        # "Ancient_Modern_Flag": flag_text
    }
    # If all_rows is a DataFrame, convert it
    if isinstance(all_rows, pd.DataFrame):
        output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)

# save the batch input in Text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        detailed_results = all_rows  # assume a list of row dicts
    output = ""
    # output += ",".join(list(detailed_results[0].keys())) + "\n\n"
    output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
    for r in detailed_results:
        output += ",".join([str(v) for v in r.values()]) + "\n\n"
    with open(filename, "w") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        # f.write("\n=== Summary ===\n")
        # f.write(summary_text + "\n")
        # f.write("\n=== Ancient/Modern Flag ===\n")
        # f.write(flag_text + "\n")
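# save_batch_output below receives the results table as an HTML string (from the
# Gradio UI) and parses it back into a DataFrame with pandas. A minimal sketch of
# that round-trip, using a made-up two-column table purely for illustration:
#
#     html = ("<table><tr><th>Sample ID</th><th>Predicted Country</th></tr>"
#             "<tr><td>KU131308</td><td>brunei</td></tr></table>")
#     df = pd.read_html(StringIO(html))[0]   # read_html returns a list of tables
#     df.to_dict(orient="records")           # -> [{'Sample ID': 'KU131308', 'Predicted Country': 'brunei'}]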
def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
    tmp_dir = tempfile.mkdtemp()
    # html_table = all_rows.value  # assuming this is stored somewhere
    # Parse back to DataFrame
    # all_rows = pd.read_html(all_rows)[0]  # [0] because read_html returns a list
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)
    if output_type == "Excel":
        file_path = f"{tmp_dir}/batch_output.xlsx"
        save_to_excel(all_rows, summary_text, flag_text, file_path)
    elif output_type == "JSON":
        file_path = f"{tmp_dir}/batch_output.json"
        save_to_json(all_rows, summary_text, flag_text, file_path)
        print("Done with JSON")
    elif output_type == "TXT":
        file_path = f"{tmp_dir}/batch_output.txt"
        save_to_txt(all_rows, summary_text, flag_text, file_path)
    else:
        return gr.update(visible=False)  # invalid option
    return gr.update(value=file_path, visible=True)

# save cost by checking the known outputs
# def check_known_output(accession):
#     if not os.path.exists(KNOWN_OUTPUT_PATH):
#         return None
#     try:
#         df = pd.read_excel(KNOWN_OUTPUT_PATH)
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()  # Return the cached row
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples: {e}")
#     return None

# def check_known_output(accession):
#     try:
#         # ✅ Load credentials from Hugging Face secret
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         # ✅ Open the known_samples sheet
#         spreadsheet = client.open("known_samples")  # Replace with your sheet name
#         sheet = spreadsheet.sheet1
#         # ✅ Read all rows
#         data = sheet.get_all_values()
#         if not data:
#             return None
#         df = pd.DataFrame(data[1:], columns=data[0])  # Skip header row
#         # ✅ Normalize accession pattern
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
#     return None

def check_known_output(accession):
    print("inside check known output function")
    try:
        # ✅ Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        data = sheet.get_all_values()
        if not data:
            print("⚠️ Google Sheet 'known_samples' is empty.")
            return None
        df = pd.DataFrame(data[1:], columns=data[0])
        if "Sample ID" not in df.columns:
            print("❌ Column 'Sample ID' not found in Google Sheet.")
            return None
        match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
        if match:
            accession = match.group(0)
        matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
        if not matched.empty:
            # return matched.iloc[0].to_dict()
            row = matched.iloc[0]
            country = row.get("Predicted Country", "").strip().lower()
            sample_type = row.get("Predicted Sample Type", "").strip().lower()
            if country and country != "unknown" and sample_type and sample_type != "unknown":
                return row.to_dict()
            else:
                print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
                return None
        else:
            print(f"🔍 Accession {accession} not found in known_samples.")
            return None
    except Exception as e:
        import traceback
        print("❌ Exception occurred during check_known_output:")
        traceback.print_exc()
        return None
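# Note on the lookup above: Sample IDs are stored in the sheet in their labelled
# form (e.g. "KU131308(Isolate: BRU18)", as built in summarize_results), so
# check_known_output first normalizes the query with the accession pattern
# \b[A-Z]{2,4}\d{4,} and then matches by substring. A small worked example
# (values assumed for illustration):
#
#     re.search(r"\b[A-Z]{2,4}\d{4,}", "KU131308(Isolate: BRU18)").group(0)
#     # -> 'KU131308'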
def hash_user_id(user_input):
    return hashlib.sha256(user_input.encode()).hexdigest()

# ✅ Load and save usage count
# def load_user_usage():
#     if not os.path.exists(USER_USAGE_TRACK_FILE):
#         return {}
#     try:
#         with open(USER_USAGE_TRACK_FILE, "r") as f:
#             content = f.read().strip()
#             if not content:
#                 return {}  # file is empty
#             return json.loads(content)
#     except (json.JSONDecodeError, ValueError):
#         print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
#         return {}  # fallback to empty dict

# def load_user_usage():
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         data = sheet.get_all_records()  # Assumes columns: email, usage_count
#         usage = {}
#         for row in data:
#             email = row.get("email", "").strip().lower()
#             count = int(row.get("usage_count", 0))
#             if email:
#                 usage[email] = count
#         return usage
#     except Exception as e:
#         print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
#         return {}

# def load_user_usage():
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
#         if not found:
#             return {}  # not found, start fresh
#         #file_id = found[0]["id"]
#         file_id = found
#         content = pipeline.download_drive_file_content(file_id)
#         return json.loads(content.strip()) if content.strip() else {}
#     except Exception as e:
#         print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
#         return {}
def load_user_usage():
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        sheet = client.open("user_usage_log").sheet1
        data = sheet.get_all_values()
        print("data: ", data)
        if not data or len(data) < 2:
            print("⚠️ Sheet is empty or missing rows.")
            return {}, {}
        print("🧪 Raw header row from sheet:", data[0])
        print("🧪 Character codes in each header:")
        for h in data[0]:
            print([ord(c) for c in h])
        headers = [h.strip().lower() for h in data[0]]
        if "email" not in headers or "usage_count" not in headers:
            print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
            return {}, {}
        permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
        df = pd.DataFrame(data[1:], columns=headers)
        usage = {}
        permitted = {}
        for _, row in df.iterrows():
            email = row.get("email", "").strip().lower()
            try:
                # count = int(row.get("usage_count", 0))
                try:
                    count = int(float(row.get("usage_count", 0)))
                except Exception:
                    print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
                    count = 0
                if email:
                    usage[email] = count
                    if permitted_index is not None:
                        try:
                            permitted_count = int(float(row.get("permitted_samples", 50)))
                            permitted[email] = permitted_count
                        except Exception:
                            permitted[email] = 50
            except ValueError:
                print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
        return usage, permitted
    except Exception as e:
        print(f"❌ Error in load_user_usage: {e}")
        return {}, {}

# def save_user_usage(usage):
#     with open(USER_USAGE_TRACK_FILE, "w") as f:
#         json.dump(usage, f, indent=2)

# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         sheet.clear()  # clear old contents first
#         # Write header + rows
#         rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
#         sheet.update(rows)
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")

# def save_user_usage(usage_dict):
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         import tempfile
#         tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
#         print("💾 Saving this usage dict:", usage_dict)
#         with open(tmp_path, "w") as f:
#             json.dump(usage_dict, f, indent=2)
#         pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
#     except Exception as e:
#         print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")

# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         spreadsheet = client.open("user_usage_log")
#         sheet = spreadsheet.sheet1
#         # Step 1: Convert new usage to DataFrame
#         df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
#         df_new["email"] = df_new["email"].str.strip().str.lower()
#         # Step 2: Load existing data
#         existing_data = sheet.get_all_values()
#         print("🧪 Sheet existing_data:", existing_data)
#         # Try to load old data
#         if existing_data and len(existing_data[0]) >= 1:
#             df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
#             # Fix missing columns
#             if "email" not in df_old.columns:
#                 df_old["email"] = ""
#             if "usage_count" not in df_old.columns:
#                 df_old["usage_count"] = 0
#             df_old["email"] = df_old["email"].str.strip().str.lower()
#             df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
#         else:
#             df_old = pd.DataFrame(columns=["email", "usage_count"])
#         # Step 3: Merge
#         df_combined = pd.concat([df_old, df_new], ignore_index=True)
#         df_combined = df_combined.groupby("email", as_index=False).sum()
#         # Step 4: Write back
#         sheet.clear()
#         sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
#         print("✅ Saved user usage to user_usage_log sheet.")
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")
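# Expected layout of the "user_usage_log" sheet that load_user_usage reads and
# save_user_usage below rewrites. The header names come from the code above; the
# permitted_samples column is optional and defaults to 50 per user. The data row
# shown is a made-up example, not real content:
#
#     email            | usage_count | permitted_samples
#     user@example.com | 12          | 50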
def save_user_usage(usage_dict):
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("user_usage_log")
        sheet = spreadsheet.sheet1
        # Build new df
        df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
        df_new["email"] = df_new["email"].str.strip().str.lower()
        df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
        # Read existing data
        existing_data = sheet.get_all_values()
        if existing_data and len(existing_data[0]) >= 2:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
            df_old["email"] = df_old["email"].str.strip().str.lower()
            df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
        else:
            df_old = pd.DataFrame(columns=["email", "usage_count"])
        # ✅ Overwrite specific emails only
        df_old = df_old.set_index("email")
        for email, count in usage_dict.items():
            email = email.strip().lower()
            df_old.loc[email, "usage_count"] = count
        df_old = df_old.reset_index()
        # Save
        sheet.clear()
        sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
        print("✅ Saved user usage to user_usage_log sheet.")
    except Exception as e:
        print(f"❌ Failed to save user usage to Google Sheets: {e}")

# def increment_usage(user_id, num_samples=1):
#     usage = load_user_usage()
#     if user_id not in usage:
#         usage[user_id] = 0
#     usage[user_id] += num_samples
#     save_user_usage(usage)
#     return usage[user_id]

# def increment_usage(email: str, count: int):
#     usage = load_user_usage()
#     email_key = email.strip().lower()
#     usage[email_key] = usage.get(email_key, 0) + count
#     save_user_usage(usage)
#     return usage[email_key]

def increment_usage(email: str, count: int = 1):
    usage, permitted = load_user_usage()
    email_key = email.strip().lower()
    # usage[email_key] = usage.get(email_key, 0) + count
    current = usage.get(email_key, 0)
    new_value = current + count
    max_allowed = permitted.get(email_key) or 50
    usage[email_key] = max(current, new_value)  # ✅ Prevent overwrite with a lower value
    print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
    print("max allow is: ", max_allowed)
    save_user_usage(usage)
    return usage[email_key], max_allowed
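# Sketch of the usage-accounting flow that summarize_batch below performs once a
# run finishes: the email is hashed, then increment_usage adds the number of
# processed rows and returns both the new total and the per-user cap. The email
# address and counts here are placeholders, not real data:
#
#     user_hash = hash_user_id("user@example.com")     # sha256 hex digest
#     total, max_allowed = increment_usage(user_hash, 3)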
# run the batch
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="", stop_flag=None,
                    output_file_path=None, limited_acc=50, yield_callback=None):
    if user_email:
        limited_acc += 10
    accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        # return [], "", "", f"Error: {error}"
        return [], f"Error: {error}", 0, "", ""
    if resume_file:
        accessions = get_incomplete_accessions(resume_file)
    tmp_dir = tempfile.mkdtemp()
    if not output_file_path:
        if resume_file:
            output_file_path = os.path.join(tmp_dir, resume_file)
        else:
            output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")

    all_rows = []
    # all_summaries = []
    # all_flags = []
    progress_lines = []
    warning = ""
    if len(accessions) > limited_acc:
        accessions = accessions[:limited_acc]
        warning = f"Your number of accessions exceeds the limit of {limited_acc}; only the first {limited_acc} accessions will be handled."

    for i, acc in enumerate(accessions):
        if stop_flag and stop_flag.value:
            line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
            print("🛑 User requested stop.")
            break
        print(f"[{i+1}/{len(accessions)}] Processing {acc}")
        try:
            # rows, summary, label, explain = summarize_results(acc)
            rows = summarize_results(acc)
            all_rows.extend(rows)
            # all_summaries.append(f"**{acc}**\n{summary}")
            # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
            # save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
            save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
            line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
        except Exception as e:
            print(f"❌ Failed to process {acc}: {e}")
            continue
            # all_summaries.append(f"**{acc}**: Failed - {e}")
            # progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
        limited_acc -= 1

    """for row in all_rows:
        source_column = row[2]  # Assuming the "Source" is in the 3rd column (index 2)
        if source_column.startswith("http"):  # Check if the source is a URL
            # Wrap it with HTML anchor tags to make it clickable
            row[2] = f'{source_column}'"""

    if not warning:
        warning = f"You have {limited_acc} accession(s) left in your quota."
    if user_email.strip():
        user_hash = hash_user_id(user_email)
        total_queries, _ = increment_usage(user_hash, len(all_rows))
    else:
        total_queries = 0
    if yield_callback:
        yield_callback("✅ Finished!")
    # summary_text = "\n\n---\n\n".join(all_summaries)
    # flag_text = "\n\n---\n\n".join(all_flags)
    # return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
    # return all_rows, gr.update(visible=True), gr.update(visible=False)
    return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
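# Minimal manual smoke test, offered as a sketch rather than part of the Gradio
# wiring (which presumably lives elsewhere in this project). It assumes the
# GCP_CREDS_JSON secret is set and the pipeline/model modules are importable;
# the accession below is an example value.
if __name__ == "__main__":
    rows, out_path, total, progress, warn = summarize_batch(
        raw_text="KU131308",
        user_email="",          # anonymous run: usage is not recorded
        limited_acc=1,
        yield_callback=print,   # stream progress lines to stdout
    )
    print(f"Wrote {len(rows)} row(s) to {out_path}")
    print(warn)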