VyLala committed on
Commit
fb8d818
·
verified ·
1 Parent(s): bdd4893

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +47 -10
pipeline.py CHANGED
@@ -187,7 +187,7 @@ def unique_preserve_order(seq):
187
  seen = set()
188
  return [x for x in seq if not (x in seen or seen.add(x))]
189
  # Main execution
190
- def pipeline_with_gemini(accessions):
191
  # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
192
  # there can be one accession number in the accessions
193
  # Prices are per 1,000 tokens
@@ -213,6 +213,10 @@ def pipeline_with_gemini(accessions):
213
  "query_cost":total_cost_title,
214
  "time_cost":None,
215
  "source":links}
 
 
 
 
216
  meta = mtdna_classifier.fetch_ncbi_metadata(acc)
217
  country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"], meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
218
  acc_score["isolate"] = iso
@@ -350,7 +354,15 @@ def pipeline_with_gemini(accessions):
350
  print("tem link before filtering: ", tem_links)
351
  # filter the quality link
352
  print("saveLinkFolder as sample folder id: ", sample_folder_id)
353
- links = smart_fallback.filter_links_by_metadata(tem_links, saveLinkFolder=sample_folder_id, accession=acc)
 
 
 
 
 
 
 
 
354
  print("this is links: ",links)
355
  links = unique_preserve_order(links)
356
  acc_score["source"] = links
@@ -419,15 +431,26 @@ def pipeline_with_gemini(accessions):
419
  final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
420
  if len(final_input_link) > 1000 *1000:
421
  final_input_link = final_input_link[:100000]
422
- if len(data_preprocess.normalize_for_overlap(all_output)) < 1000*1000:
423
- success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link))
 
 
424
  if success:
425
  all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
426
  print("yes succeed")
427
  else:
 
 
428
  all_output += final_input_link
429
  print("len final input: ", len(final_input_link))
430
  print("basic fall back")
 
 
 
 
 
 
 
431
  print("len all output after: ", len(all_output))
432
  #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
433
 
@@ -552,12 +575,19 @@ def pipeline_with_gemini(accessions):
552
  model.call_llm_api, chunk=chunk, all_output=all_output)
553
  print("country using ai: ", country)
554
  print("sample type using ai: ", sample_type)
 
 
 
 
 
 
555
  if len(country) == 0: country = "unknown"
556
  if len(sample_type) == 0: sample_type = "unknown"
557
- if country_explanation: country_explanation = "-"+country_explanation
558
  else: country_explanation = ""
559
- if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
560
  else: sample_type_explanation = ""
 
561
  if method_used == "unknown": method_used = ""
562
  if country.lower() != "unknown":
563
  stand_country = standardize_location.smart_country_lookup(country.lower())
@@ -592,8 +622,9 @@ def pipeline_with_gemini(accessions):
592
  else:
593
  if len(method_used + sample_type_explanation)> 0:
594
  acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
 
595
  # last resort: combine all information to give all output otherwise unknown
596
- if len(acc_score["country"]) == 0 or len(acc_score["sample_type"]) == 0:
597
  text = ""
598
  for key in meta_expand:
599
  text += str(key) + ": " + meta_expand[key] + "\n"
@@ -612,10 +643,15 @@ def pipeline_with_gemini(accessions):
612
  print("sample type: ", sample_type)
613
  if len(country) == 0: country = "unknown"
614
  if len(sample_type) == 0: sample_type = "unknown"
615
- if country_explanation: country_explanation = "-"+country_explanation
 
 
 
 
616
  else: country_explanation = ""
617
- if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
618
  else: sample_type_explanation = ""
 
619
  if method_used == "unknown": method_used = ""
620
  if country.lower() != "unknown":
621
  stand_country = standardize_location.smart_country_lookup(country.lower())
@@ -640,8 +676,9 @@ def pipeline_with_gemini(accessions):
640
  else:
641
  if len(method_used + sample_type_explanation)> 0:
642
  acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
 
643
  end = time.time()
644
- total_cost_title += total_query_cost
645
  acc_score["query_cost"] = f"{total_cost_title:.6f}"
646
  elapsed = end - start
647
  acc_score["time_cost"] = f"{elapsed:.3f} seconds"
 
187
  seen = set()
188
  return [x for x in seq if not (x in seen or seen.add(x))]
189
  # Main execution
190
+ def pipeline_with_gemini(accessions,niche_cases=None):
191
  # output: country, sample_type, ethnic, location, money_cost, time_cost, explain
192
  # there can be one accession number in the accessions
193
  # Prices are per 1,000 tokens
 
213
  "query_cost":total_cost_title,
214
  "time_cost":None,
215
  "source":links}
216
+ if niche_cases:
217
+ for niche in niche_cases:
218
+ acc_score[niche] = {}
219
+
220
  meta = mtdna_classifier.fetch_ncbi_metadata(acc)
221
  country, spe_loc, ethnic, sample_type, col_date, iso, title, doi, pudID, features = meta["country"], meta["specific_location"], meta["ethnicity"], meta["sample_type"], meta["collection_date"], meta["isolate"], meta["title"], meta["doi"], meta["pubmed_id"], meta["all_features"]
222
  acc_score["isolate"] = iso
 
354
  print("tem link before filtering: ", tem_links)
355
  # filter the quality link
356
  print("saveLinkFolder as sample folder id: ", sample_folder_id)
357
+ print("start the smart filter link")
358
+ success_process, output_process = run_with_timeout(smart_fallback.filter_links_by_metadata,args=(tem_links,sample_folder_id),kwargs={"accession":acc},timeout=100)
359
+ if success_process:
360
+ links = output_process
361
+ print("yes succeed for smart filter link")
362
+ else:
363
+ print("no suceed, fallback to all tem links")
364
+ links = tem_links
365
+ #links = smart_fallback.filter_links_by_metadata(tem_links, saveLinkFolder=sample_folder_id, accession=acc)
366
  print("this is links: ",links)
367
  links = unique_preserve_order(links)
368
  acc_score["source"] = links
 
431
  final_input_link = data_preprocess.normalize_for_overlap(final_input_link)
432
  if len(final_input_link) > 1000 *1000:
433
  final_input_link = final_input_link[:100000]
434
+ if len(data_preprocess.normalize_for_overlap(all_output)) < int(100000) and len(final_input_link)<100000:
435
+ print("Running merge_texts_skipping_overlap with timeout")
436
+ success, the_output = run_with_timeout(data_preprocess.merge_texts_skipping_overlap,args=(all_output, final_input_link),timeout=30)
437
+ print("Returned from timeout logic")
438
  if success:
439
  all_output = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
440
  print("yes succeed")
441
  else:
442
+ print("len all output: ", len(all_output))
443
+ print("len final input link: ", len(final_input_link))
444
  all_output += final_input_link
445
  print("len final input: ", len(final_input_link))
446
  print("basic fall back")
447
+ else:
448
+ print("both/either all output or final link too large more than 100000")
449
+ print("len all output: ", len(all_output))
450
+ print("len final input link: ", len(final_input_link))
451
+ all_output += final_input_link
452
+ print("len final input: ", len(final_input_link))
453
+ print("basic fall back")
454
  print("len all output after: ", len(all_output))
455
  #country_pro, chunk, all_output = data_preprocess.process_inputToken(links, saveLinkFolder, accession=accession, isolate=isolate)
456
 
 
575
  model.call_llm_api, chunk=chunk, all_output=all_output)
576
  print("country using ai: ", country)
577
  print("sample type using ai: ", sample_type)
578
+ # if len(country) == 0: country = "unknown"
579
+ # if len(sample_type) == 0: sample_type = "unknown"
580
+ # if country_explanation: country_explanation = "-"+country_explanation
581
+ # else: country_explanation = ""
582
+ # if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
583
+ # else: sample_type_explanation = ""
584
  if len(country) == 0: country = "unknown"
585
  if len(sample_type) == 0: sample_type = "unknown"
586
+ if country_explanation and country_explanation!="unknown": country_explanation = "-"+country_explanation
587
  else: country_explanation = ""
588
+ if sample_type_explanation and sample_type_explanation!="unknown": sample_type_explanation = "-"+sample_type_explanation
589
  else: sample_type_explanation = ""
590
+
591
  if method_used == "unknown": method_used = ""
592
  if country.lower() != "unknown":
593
  stand_country = standardize_location.smart_country_lookup(country.lower())
 
622
  else:
623
  if len(method_used + sample_type_explanation)> 0:
624
  acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
625
+ total_cost_title += total_query_cost
626
  # last resort: combine all information to give all output otherwise unknown
627
+ if len(acc_score["country"]) == 0 or len(acc_score["sample_type"]) == 0 or acc_score["country"] == "unknown" or acc_score["sample_type"] == "unknown":
628
  text = ""
629
  for key in meta_expand:
630
  text += str(key) + ": " + meta_expand[key] + "\n"
 
643
  print("sample type: ", sample_type)
644
  if len(country) == 0: country = "unknown"
645
  if len(sample_type) == 0: sample_type = "unknown"
646
+ # if country_explanation: country_explanation = "-"+country_explanation
647
+ # else: country_explanation = ""
648
+ # if sample_type_explanation: sample_type_explanation = "-"+sample_type_explanation
649
+ # else: sample_type_explanation = ""
650
+ if country_explanation and country_explanation!="unknown": country_explanation = "-"+country_explanation
651
  else: country_explanation = ""
652
+ if sample_type_explanation and sample_type_explanation!="unknown": sample_type_explanation = "-"+sample_type_explanation
653
  else: sample_type_explanation = ""
654
+
655
  if method_used == "unknown": method_used = ""
656
  if country.lower() != "unknown":
657
  stand_country = standardize_location.smart_country_lookup(country.lower())
 
676
  else:
677
  if len(method_used + sample_type_explanation)> 0:
678
  acc_score["sample_type"][sample_type.lower()] = [method_used + sample_type_explanation]
679
+ total_cost_title += total_query_cost
680
  end = time.time()
681
+ #total_cost_title += total_query_cost
682
  acc_score["query_cost"] = f"{total_cost_title:.6f}"
683
  elapsed = end - start
684
  acc_score["time_cost"] = f"{elapsed:.3f} seconds"