### Prepare to BLAST all of our sequences against UniProt
import pandas as pd
import os
import subprocess
import time
import re
import pickle
import numpy as np

from fuson_plm.utils.logging import log_update, open_logfile
from fuson_plm.utils.embedding import redump_pickle_dictionary
from fuson_plm.data.blast.plot import group_difference_plot, group_swiss_and_ht_plot, group_box_plot, group_pos_id_plot
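
# NOTE: this script assumes a local NCBI BLAST+ install unpacked at
# ./ncbi-blast-2.16.0+/ with the target database (e.g. swissprot) downloaded
# into ncbi-blast-2.16.0+/<database>/; see the PATH/BLASTDB setup in run_blast.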

def prepare_blast_inputs():
    log_update("\nPreparing BLAST Inputs. Logging every 1000 sequences... ")
    # make directory for the BLAST input files
    os.makedirs("blast_inputs", exist_ok=True)
    
    # read the fuson database
    fuson_db = pd.read_csv('../fuson_db.csv')
    
    # make dictionary mapping sequences to seqids (for naming input files)
    fuson_db_dict = dict(zip(fuson_db['aa_seq'],fuson_db['seq_id']))
    
    # convert the database into fasta format
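    # each sequence is written as a minimal two-line FASTA record, e.g.
    # (illustrative seq_id and truncated sequence):
    #   >seq1
    #   MFSMRIV...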
    new_fa_files_created = 0
    old_fa_files_found = 0
    total_seqs_processed=0
    for i, (seq, seqid) in enumerate(fuson_db_dict.items()):
        total_seqs_processed+=1
        # if the path already exists, skip
        if os.path.exists(f"blast_inputs/{seqid}.fa"):
            old_fa_files_found+=1
        else:
            new_fa_files_created+=1
            with open(f"blast_inputs/{seqid}.txt", 'w') as f:
                fasta_lines = '>' + seqid + '\n' + seq
                f.write(fasta_lines)
            # rename it to .fa
            os.rename(f"blast_inputs/{seqid}.txt", f"blast_inputs/{seqid}.fa")

        if i%1000==0:
            log_update(f"\t\t{i}\t{seqid}:{seq}")
    
    log_update("\tFinished preparing BLAST Inputs (results in blast_inputs folder)")
    log_update(f"\t\tSequences processed: {total_seqs_processed}/{len(fuson_db)} seqs in FusOn-DB\n\t\tFasta files found: {old_fa_files_found}\n\t\tNew fasta files created: {new_fa_files_created}")

def run_blast(blast_inputs_dir, database="swissprot",n=1,interval=2000):
    """
    Run BLAST on all files in blast_inputs_dir
    """  
    # Must change the PATH variable to include the BLAST executables 
    os.environ['PATH'] += ":./ncbi-blast-2.16.0+/bin"   
    os.environ['BLASTDB'] = f"ncbi-blast-2.16.0+/{database}"
    
    # make directory for outputs
    os.makedirs("blast_outputs", exist_ok=True)
    os.makedirs(f"blast_outputs/{database}", exist_ok=True)
    already_blasted = os.listdir(f"blast_outputs/{database}")
    blast_input_files = os.listdir(blast_inputs_dir)
    # Sort the list using a custom key to extract the numeric part
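    # e.g. ['seq1.fa', 'seq10.fa', 'seq2.fa'] -> ['seq1.fa', 'seq2.fa', 'seq10.fa']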
    blast_input_files = sorted(blast_input_files, key=lambda x: int(re.search(r'\d+', x).group()))
    
    # print how many we've already blasted
    log_update(f"Running BLAST.\n\t{len(blast_input_files)} input files\n\t{len(already_blasted)} already blasted\n") 
    
    tot_seqs_processed = 0
    total_blast_time = 0
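    # process batch n of size `interval`: with interval=2000, n=1 covers
    # files[0:2000], n=2 covers files[2000:4000], and so on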
    
    start_i = interval*(n-1)
    end_i = interval*n
    if end_i>len(blast_input_files): end_i = len(blast_input_files)
    for i, blast_input_file in enumerate(blast_input_files[start_i:end_i]):
        tot_seqs_processed+=1
        # blast_input_file is of the format seqid.fa
        seqid = blast_input_file.split('.fa')[0]
        input_path = f"blast_inputs/{blast_input_file}"
        output_path = f"blast_outputs/{database}/{seqid}_{database}_results.out"
        
        if os.path.exists(output_path):
            log_update(f"\t{i+1}.\tAlready blasted {seqid}")
            continue
        
        # Construct the command as a list of arguments
        command = [
            "ncbi-blast-2.16.0+/bin/blastp",
            "-db", database,
            "-query", input_path,
            "-out", output_path
        ]

        # Run the command, and time it
        blast_start_time = time.time()
        result = subprocess.run(command, capture_output=True, text=True)
        blast_end_time = time.time()
        blast_seq_time = blast_end_time-blast_start_time
        total_blast_time+=blast_seq_time

        # Check if there was an error
        if result.returncode != 0:
            log_update(f"\t{i+1}.\tError running BLAST for {seqid}: {result.stderr} ({blast_seq_time:.2f}s)")
        else:
            log_update(f"\t{i+1}.\tBLAST search completed for {seqid} ({blast_seq_time:.2f}s)")
    
    log_update(f"\tFinished processing {tot_seqs_processed} sequences ({total_blast_time:.2f}s)")
 
def remove_incomplete_blasts(database="swissprot"):
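    # A finished blastp text report ends with a parameter footer whose final
    # line contains 'Window for multiple hits:'; outputs from interrupted runs
    # lack this marker, so it serves as a completion check.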
    incomplete_list = []
    for fname in os.listdir(f"blast_outputs/{database}"):
        complete=False
        with open(f"blast_outputs/{database}/{fname}", "r") as f:
            lines = f.readlines()
            if len(lines)>1 and "Window for multiple hits:" in lines[-1]:
                complete=True
            if not complete:
                incomplete_list.append(fname)

    log_update(f"\t{len(incomplete_list)} BLAST files are incomplete (due to BLAST errors). Deleting them so they can be rerun.")
    # remove all these files
    for fname in incomplete_list:
        os.remove(f"blast_outputs/{database}/{fname}")
        
def find_nomatch_blasts(fuson_ht_db, database="swissprot"):
    no_match_list = []
    for fname in os.listdir(f"blast_outputs/{database}"):
        match=True
        with open(f"blast_outputs/{database}/{fname}", "r") as f:
            lines = f.readlines()
            if len(lines)>1 and "No hits found" in lines[28]:   # default blastp text output reports 'No hits found' at this fixed header offset when nothing aligned
                match=False
            if not match:
                no_match_list.append(fname)

    log_update(f"\t{len(no_match_list)} sequence IDs had no match in the BLAST database {database}")
    # write no match list to a file in blast_outputs
    with open(f"blast_outputs/{database}_no_match.txt","w") as f:
        for i, fname in enumerate(no_match_list):
            if i!=len(no_match_list)-1:
                f.write(f"{fname}\n")
            else:
                f.write(f"{fname}")
    
    # write a subset of fuson_ht_db containing these sequences as well 
    no_match_ids = [x.split('_')[0] for x in no_match_list]
    subset = fuson_ht_db.loc[
        fuson_ht_db['seq_id'].isin(no_match_ids)
    ].reset_index(drop=True)
    subset.to_csv(f"blast_outputs/{database}_no_match.csv",index=False)
    
    return no_match_ids
    
def make_fuson_ht_db(path_to_fuson_db="../fuson_db.csv", path_to_unimap="../head_tail_data/htgenes_uniprotids.csv",savepath="fuson_ht_db.csv"):
    """
    Make a version of the fuson_db that has all the heads and tails for each of the genes. This makes it easier to analyze BLAST results.
    """
    if os.path.exists(savepath):
        df = pd.read_csv(savepath)
        return df
    
    # read both of the databases
    fuson_db = pd.read_csv(path_to_fuson_db)
    ht_db = pd.read_csv(path_to_unimap)
    
    # Make it such that each row of fuson_db just has ONE head and ONE tail
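    # e.g. an (illustrative) row with fusiongenes 'EWSR1::FLI1,EWSR1::ERG'
    # explodes into two rows: one with hgene=EWSR1/tgene=FLI1 and one with
    # hgene=EWSR1/tgene=ERG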
    fuson_ht_db = fuson_db.copy(deep=True)
    fuson_ht_db['fusiongenes'] = fuson_ht_db['fusiongenes'].apply(lambda x: x.split(','))
    fuson_ht_db = fuson_ht_db.explode('fusiongenes')
    fuson_ht_db['hgene'] = fuson_ht_db['fusiongenes'].str.split('::',expand=True)[0]
    fuson_ht_db['tgene'] = fuson_ht_db['fusiongenes'].str.split('::',expand=True)[1]

    # Merge on head, then merge on tail
    fuson_ht_db = pd.merge(             # merge on head
        fuson_ht_db,
        ht_db.rename(columns={
            'Gene': 'hgene',
            'UniProtID': 'hgUniProt',
            'Reviewed': 'hgUniProtReviewed'
        }),
        on='hgene',
        how='left'
    )
    fuson_ht_db = pd.merge(             # merge on tail
        fuson_ht_db,
        ht_db.rename(columns={
            'Gene': 'tgene',
            'UniProtID': 'tgUniProt',
            'Reviewed': 'tgUniProtReviewed'
        }),
        on='tgene',
        how='left'
    )
    
    # Make sure we haven't lost anything
    tot_og_seqids = len(fuson_db['seq_id'].unique())
    tot_final_seqids = len(fuson_ht_db['seq_id'].unique())
    log_update(f"\tTotal sequence IDs in combined database = {tot_final_seqids}. Matches expected: {tot_final_seqids==tot_og_seqids}")
    # Each fusion should have as many rows as fusion-gene entries (commas + 1) in its fusiongenes string
    fuson_db['n_commas'] = fuson_db['fusiongenes'].str.count(',') + 1
    seqid_rows_map = dict(zip(fuson_db['seq_id'],fuson_db['n_commas']))
    vc = fuson_ht_db['seq_id'].value_counts().reset_index()
    vc['expected_count'] = vc['index'].map(seqid_rows_map)
    log_update(f"\tEach seq_id has the expected number of head-tail combos: {(vc['expected_count']==vc['seq_id']).all()}")
    
    log_update(f"\tPreview of combined database:")
    prev = fuson_ht_db.head(10).copy()  # copy so the truncated preview doesn't touch fuson_ht_db before saving
    prev['aa_seq'] = prev['aa_seq'].apply(lambda x: x[0:10]+'...')
    log_update(prev.to_string(index=False))
    fuson_ht_db.to_csv(savepath, index=False)
    return fuson_ht_db

def format_dict(d, indent=0):
    """
    Recursively formats a dictionary for display purposes.
    
    Args:
        d (dict): The dictionary to format.
        indent (int): The current level of indentation.
    
    Returns:
        str: A formatted string representing the dictionary.
    """
    formatted_str = ""
    # Iterate through each key-value pair in the dictionary
    for key, value in d.items():
        # Create the current indentation
        current_indent = " " * (indent * 4)
        # Add the key
        formatted_str += f"{current_indent}{repr(key)}: "
        
        # Check the type of the value
        if isinstance(value, dict):
            # If dictionary, call format_dict recursively
            formatted_str += "{\n" + format_dict(value, indent + 1) + current_indent + "},\n"
        elif isinstance(value, list):
            # If list, convert it to a formatted string
            formatted_str += f"[{', '.join(repr(item) for item in value)}],\n"
        elif isinstance(value, str):
            # If string, enclose in quotes
            formatted_str += f"'{value}',\n"
        elif value is None:
            # If None, display as 'None'
            formatted_str += "None,\n"
        else:
            formatted_str += f"{repr(value)},\n"
    
    return formatted_str

def parse_blast_output(file_path, head_ids, tail_ids):
    """
    Args:
        - file_path: /path/to/blast/output
        - head_ids: list of all UniProt IDs for the head protein
        - tail_ids: list of all UniProt IDs for the tail protein
    """
    target_ids = list(set(head_ids + tail_ids))    # combined list of head and tail IDs to simplify the bookkeeping below
    with open(file_path, 'r') as file:
        best_data = {tid: None for tid in target_ids}   # stores the best alignment for each ID we care about
        current_data = {tid: {} for tid in target_ids}  # stores the current data for each ID we care about (most recent alignment we read)
        best_score = {tid: -float('inf') for tid in target_ids} # stores the best score for each ID we care about
        capture = {tid: False for tid in target_ids}    # whether we are currently processing this ID
        replace_best = {tid: False for tid in target_ids}   # whether we should replace the best_data with the current_data for this ID
        isoform_dict = {tid: None for tid in target_ids}    # dictionary of isoforms for each target ID
        
        # variables that will only be used for getting the best alignment
        alignment_count = 0
        cur_id = None
        on_best_alignment=False

        # Iterate through lines
        for line in file:
            line = line.strip()
            # if NEW ID (not necessarily a new alignment! there can be multiple alignments under one >)
            if line.startswith('>'):
                found_tid_in_header=False   # assume we have not found a target ID we are looking for
                alignment_count+=1
                if alignment_count==1:  # the first alignment listed is the best-scoring one
                    on_best_alignment=True
                else:
                    on_best_alignment = False
                    
                ## We may have just finished processing an ID. Check for the one that currently has capture set to True
                just_captured = None
                total_captured = 0
                for k, v in capture.items():
                    if v:
                        total_captured+=1
                        just_captured = k
                # we should never be capturing more than one thing at a time. make sure of this
                assert total_captured<2
                if just_captured is not None:
                    if replace_best[just_captured]:   # if we just finished an alignment for the just_captured ID, and it's the best one, put it in
                        best_data[just_captured] = current_data[just_captured].copy()
                        replace_best[just_captured] = False     # we just did the replacement, so reset it 
                    
                # Check if the line contains any of the target IDs. 
                # This means EITHER [UniProtID] or [UniProtID.Isoform] or [UniProtID-Isoform] is in the line
                for tid in target_ids:
                    pattern = fr">{tid}([.-]\d+)? "    # for ID P02671, would match ">P02671 ", ">P02671.2 " and ">P02671-2 "
                    if re.search(pattern, line):    # if this ID matches
                        isoform_dict[tid] = None    # set it to None, update it if we need to 
                        if "." in line: # look for isoform denoted by . if there is one, otherwise it'll stay as None
                            isoform = int(line.split(".")[1].split(" ")[0])
                            isoform_dict[tid] = isoform
                            #print(f"\t\tID = {tid} (is a head or tail), isoform={isoform}")
                        elif "-" in line: # look for isoform denoted by - if there is one, otherwise it'll stay as None
                            isoform = int(line.split("-")[1].split(" ")[0])
                            isoform_dict[tid] = isoform
                            #print(f"\t\tID = {tid} (is a head or tail), isoform={isoform}")
                        capture[tid] = True
                        current_data[tid] = {'header': line}
                        found_tid_in_header=True   # we've found the tid that's in this line, so no need to check the others
                    else:
                        capture[tid] = False
                
                if on_best_alignment:   # if this is the best alignment
                    if not found_tid_in_header:    # if this header is not one of our target IDs
                        cur_id_full = line.split('>')[1].split(' ')[0]
                        cur_id, isoform = cur_id_full, None
                        isoform_dict[cur_id] = None # change this if we need
                        if "." in cur_id_full:  # if there's a dot, it's an isoform. 
                            cur_id = cur_id_full.split(".")[0]
                            isoform = int(cur_id_full.split(".")[1])
                            isoform_dict[cur_id] = isoform
                            #log_update(f"\t\tID = {cur_id} (best alignment, not a head or tail), isoform={isoform}")
                            #log_update(f"\t\t\tFull line: {line}")  # so we can see the gene name. does it make sense? 
                        elif "-" in cur_id_full:  # if there's a -, it's an isoform. 
                            cur_id = cur_id_full.split("-")[0]
                            isoform = int(cur_id_full.split("-")[1])
                            isoform_dict[cur_id] = isoform
                            #log_update(f"\t\tID = {cur_id} (best alignment, not a head or tail), isoform={isoform}")
                            #log_update(f"\t\t\tFull line: {line}")  # so we can see the gene name. does it make sense? 
                        # add this id to all the dictionaries
                        best_data[cur_id] = None
                        current_data[cur_id] = {}
                        best_score[cur_id] = -float('inf')
                        capture[cur_id] = False
                        replace_best[cur_id] = False
                        
                            
            for tid in target_ids:
                if capture[tid]:    # if we're currently on an alignment for a tid we care about
                    if 'Score =' in line:
                        if replace_best[tid]:   # if we're replacing the best alignment with this one, within the same ID, do it 
                            best_data[tid] = current_data[tid].copy()
                            # now reset the variable! 
                            replace_best[tid] = False

                        score_value = float(line.split()[2])  # Assuming "Score = 1053 bits (2723)" format
                        current_data[tid] = {}  # Reset current_data for this ID
                        current_data[tid]['Isoform'] = isoform_dict[tid]
                        current_data[tid]['Score'] = score_value
                        current_data[tid]['Expect'] = line.split('Expect =')[1].split(', Method')[0].strip()
                        current_data[tid]['Query_Aligned'] = []
                        current_data[tid]['Subject_Aligned'] = []
                        # Set the ID as a head or tail, or neither (though 'neither' shouldn't occur here)
                        if tid in head_ids:
                            current_data[tid]['H_or_T'] = 'Head' 
                            if tid in tail_ids:
                                current_data[tid]['H_or_T'] = 'Head,Tail'
                        elif tid in tail_ids:
                            current_data[tid]['H_or_T'] = 'Tail'
                        else:
                            current_data[tid]['H_or_T'] = np.nan
                            
                        current_data[tid]['Best'] = on_best_alignment
                        if score_value > best_score[tid]:   # if this is the best score we have for an alignment of this protein
                            best_score[tid] = score_value
                            replace_best[tid] = True
                        else:
                            replace_best[tid] = False

                    if 'Identities =' in line:
                        idents = line.split(', ')
                        current_data[tid]['Identities'] = idents[0].split('=')[1].strip()
                        current_data[tid]['Positives'] = idents[1].split('=')[1].strip()
                        current_data[tid]['Gaps'] = idents[2].split('=')[1].strip()
                    if line.startswith('Query'):
                        parts = line.split()
                        if 'Query_Start' not in current_data[tid]:
                            current_data[tid]['Query_Start'] = int(parts[1])
                        current_data[tid]['Query_End'] = int(parts[3])
                        current_data[tid]['Query_Aligned'].append(parts[2])
                    if line.startswith('Sbjct'):
                        parts = line.split()
                        if 'Sbjct_Start' not in current_data[tid]:
                            current_data[tid]['Sbjct_Start'] = int(parts[1])
                        current_data[tid]['Sbjct_End'] = int(parts[3])
                        current_data[tid]['Subject_Aligned'].append(parts[2])
                        
            # if we're on the best alignment and it's not one of our target_ids, still process it the same way
            if on_best_alignment:
                if not found_tid_in_header:
                    if 'Score =' in line:
                        if replace_best[cur_id]:   # if we're replacing the best alignment with this one, within the same ID, do it 
                            best_data[cur_id] = current_data[cur_id].copy()
                            # now reset the variable! 
                            replace_best[cur_id] = False
                            
                        score_value = float(line.split()[2])  # Assuming "Score = 1053 bits (2723)" format
                        current_data[cur_id] = {}  # Reset current_data for this ID
                        current_data[cur_id]['Isoform'] = isoform_dict[cur_id]
                        current_data[cur_id]['Score'] = score_value
                        current_data[cur_id]['Expect'] = line.split('Expect =')[1].split(', Method')[0].strip()
                        current_data[cur_id]['Query_Aligned'] = []
                        current_data[cur_id]['Subject_Aligned'] = []
                        # Set the ID as a head or tail, or neither
                        if cur_id in head_ids:
                            current_data[cur_id]['H_or_T'] = 'Head' 
                            if cur_id in tail_ids:
                                current_data[cur_id]['H_or_T'] = 'Head,Tail'
                        elif cur_id in tail_ids:
                            current_data[cur_id]['H_or_T'] = 'Tail'
                        else:
                            current_data[cur_id]['H_or_T'] = np.nan
                            
                        current_data[cur_id]['Best'] = True
                        if score_value > best_score[cur_id]:   # if this is the best score we have for an alignment of this protein
                            best_score[cur_id] = score_value
                            replace_best[cur_id] = True
                        else:
                            replace_best[cur_id] = False

                    if 'Identities =' in line:
                        idents = line.split(', ')
                        current_data[cur_id]['Identities'] = idents[0].split('=')[1].strip()
                        current_data[cur_id]['Positives'] = idents[1].split('=')[1].strip()
                        current_data[cur_id]['Gaps'] = idents[2].split('=')[1].strip()
                    if line.startswith('Query'):
                        parts = line.split()
                        if 'Query_Start' not in current_data[cur_id]:
                            current_data[cur_id]['Query_Start'] = int(parts[1])
                        current_data[cur_id]['Query_End'] = int(parts[3])
                        current_data[cur_id]['Query_Aligned'].append(parts[2])
                    if line.startswith('Sbjct'):
                        parts = line.split()
                        if 'Sbjct_Start' not in current_data[cur_id]:
                            current_data[cur_id]['Sbjct_Start'] = int(parts[1])
                        current_data[cur_id]['Sbjct_End'] = int(parts[3])
                        current_data[cur_id]['Subject_Aligned'].append(parts[2])
                    
        # add cur_id to target_ids if it's not None
        if cur_id is not None:
            target_ids += [cur_id]
            
        # Check at the end of the file if the last scores are the best 
        for tid in target_ids:
            if replace_best[tid]:
                best_data[tid] = current_data[tid].copy()

        # Combine sequences into single strings for the best data for each ID
        for tid in target_ids:
            #print(tid)
            if best_data[tid]:
                #print(f"there is a best alignment for {tid}")
                #print(f"best: {best_data[tid]}")
                #print(f"current: {current_data[tid]}")
                best_data[tid]['Query_Aligned'] = ''.join(best_data[tid]['Query_Aligned'])
                best_data[tid]['Subject_Aligned'] = ''.join(best_data[tid]['Subject_Aligned'])

    return best_data

def parse_all_blast_results(fuson_ht_db, database="swissprot"):
    """
    Analyze the BLAST outputs for each fusion protein against UniProt. 
    Use the fuson_ht_db to look for the heads and tails that we expect; expected IDs with no alignment simply remain None in the parsed results.
    """
    output_file=f"blast_outputs/{database}_blast_output_analyzed.pkl"
    all_seq_ids = fuson_ht_db['seq_id'].unique().tolist()
    all_seq_ids = sorted(all_seq_ids, key=lambda x: int(re.search(r'\d+', x).group()))  # sort by the number. seq1, seq2, ...
    
    prior_results = {}
    if os.path.exists(output_file):
        with open(output_file, "rb") as f:
            prior_results = pickle.load(f)
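    # note: results are appended below as one pickled {seq_id: ...} dict per
    # sequence ('ab+' mode); redump_pickle_dictionary is assumed to consolidate
    # the appended dicts into a single dictionary so the single pickle.load
    # above recovers all prior results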
    
    # Iterate through seq_ids 
    total_parse_time = 0
    tot_seqs_processed = 0
    for seq_id in all_seq_ids:
        try: 
            tot_seqs_processed+=1
            # If we've already processed it, skip 
            if seq_id in prior_results: 
                log_update(f"\tAlready processed {seq_id} blast results. Continuing")
                continue
            
            file_path = f"blast_outputs/{database}/{seq_id}_{database}_results.out"
            
            aa_seq = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['aa_seq'].tolist()[0]
            
            # Remember, fuson_ht_db has all the IDs for ALL the different head and tail gene identifiers. 
            fusion_genes = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['fusiongenes'].tolist()
            
            ##### Process heads
            head_ids = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['hgUniProt'].dropna().tolist()
            head_reviewed, head_reviewed_dict = "", {}
            if len(head_ids)>0: # if we found head IDs, we can process them and figure out if they're reviewed
                head_ids = ",".join(head_ids).split(",")
                head_reviewed = fuson_ht_db.loc[
                    fuson_ht_db['seq_id']==seq_id
                ]['hgUniProtReviewed'].dropna().tolist()
                head_reviewed = list("".join(head_reviewed))
                
                head_reviewed_dict = dict(zip(head_ids, head_reviewed))
                head_ids = list(head_reviewed_dict.keys())      # there may be some duplicates, so separate them out again
                head_reviewed = list(head_reviewed_dict.values())
            
            head_genes = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['hgene'].unique().tolist()
            
            ##### Process tails - same logic
            tail_ids = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['tgUniProt'].dropna().tolist()
            tail_reviewed, tail_reviewed_dict = "", {}
            if len(tail_ids)>0: # if we found tail IDs, we can process them and figure out if they're reviewed
                tail_ids = ",".join(tail_ids).split(",")
                tail_reviewed = fuson_ht_db.loc[
                    fuson_ht_db['seq_id']==seq_id
                ]['tgUniProtReviewed'].dropna().tolist()
                tail_reviewed = list("".join(tail_reviewed))
                
                tail_reviewed_dict = dict(zip(tail_ids, tail_reviewed))
                tail_ids = list(tail_reviewed_dict.keys())      # there may be some duplicates, so separate them out again
                tail_reviewed = list(tail_reviewed_dict.values())
            
            tail_genes = fuson_ht_db.loc[
                fuson_ht_db['seq_id']==seq_id
            ]['tgene'].unique().tolist()
            
            ###### Log what we just found
            log_update(f"\tEvaluating {seq_id}, fusion genes = {fusion_genes}, len = {len(aa_seq)}...\n\t\tfile_path={file_path}")
            #log_update(f"\n\t\thead genes={head_genes}\n\t\thead_ids={head_ids}\n\t\ttail genes={tail_genes}\n\t\ttail_ids={tail_ids}")
            
            ### Do the analysis and time it
            parse_start_time = time.time()       # time it 
            blast_data = parse_blast_output(file_path, head_ids, tail_ids)
            parse_end_time = time.time()
            parse_seq_time = parse_end_time-parse_start_time
            total_parse_time+=parse_seq_time
            log_update(f"\t\tBLAST output analysis completed for {seq_id} ({parse_seq_time:.2f}s)")
            
            # Give preview of results. Logging the whole dict would be too much, so let's just see what we found
            #log_update(format_dict(blast_data,indent=3))
            n_og_reviewed_head_ids = len([x for x in head_reviewed if x=='1'])
            found_head_ids = [x for x in list(blast_data.keys()) if (blast_data[x] is not None) and (blast_data[x].get('H_or_T',None) in ['Head','Head,Tail'])]
            n_found_reviewed_head_ids = len([x for x in found_head_ids if head_reviewed_dict[x]=='1'])
            
            n_og_reviewed_tail_ids = len([x for x in tail_reviewed if x=='1'])
            found_tail_ids = [x for x in list(blast_data.keys()) if (blast_data[x] is not None) and (blast_data[x].get('H_or_T',None) in ['Tail','Head,Tail'])]
            n_found_reviewed_tail_ids = len([x for x in found_tail_ids if tail_reviewed_dict[x]=='1'])
            
            #log_update(f"\t\t{len(found_head_ids)}/{len(head_ids)} head protein UniProt IDs ({n_found_reviewed_head_ids}/{n_og_reviewed_head_ids} REVIEWED heads) had alignments")
            #log_update(f"\t\t{len(found_tail_ids)}/{len(tail_ids)} tail protein UniProt IDs ({n_found_reviewed_tail_ids}/{n_og_reviewed_tail_ids} REVIEWED tails) had alignments")
        
            # write results to pickle file
            to_pickle_dict = {seq_id: blast_data}
            with open(output_file, 'ab+') as f:
                pickle.dump(to_pickle_dict, f)
    
        except Exception as e:
            log_update(f"\t{seq_id} failed: {e}")
            # redump the pickle even if we hit an error, so that we can fix the error and continue processing results
            redump_pickle_dictionary(output_file)
            
    # Log total time
    log_update(f"\tFinished processing {tot_seqs_processed} sequences ({total_parse_time:.2f}s)")
    
    # redump the pickle
    redump_pickle_dictionary(output_file)

def analyze_blast_results(fuson_ht_db, database="swissprot"):
    blast_results_path=f"blast_outputs/{database}_blast_output_analyzed.pkl"
    stats_df_savepath = f"blast_outputs/{database}_blast_stats.csv"
    top_alignments_df_savepath = f"blast_outputs/{database}_top_alignments.csv"
    
    stats_df, top_alignments_df = None, None
    if os.path.exists(stats_df_savepath) and os.path.exists(top_alignments_df_savepath):
        stats_df = pd.read_csv(stats_df_savepath)
        top_alignments_df = pd.read_csv(top_alignments_df_savepath, dtype={'top_hg_UniProt_isoform':'str',
                                                                            'top_tg_UniProt_isoform': 'str',
                                                                            'top_UniProt_isoform': 'str'})
        
    else:
        with open(blast_results_path, "rb") as f:
            results = pickle.load(f)
            
        # analyze the results
        # first, basic stats. How many of them have at least one head or tail alignment??
        seqid_stats = {}
        top_alignments_dict = {}
        for seq_id in list(results.keys()):
            seqid_stats[seq_id] = {
                'hgAlignments': 0,
                'tgAlignments': 0,
                'totalAlignments': 0,
                'best_hgScore': 0,
                'best_tgScore': 0,
                'best_Score': 0
            }
            top_alignments_dict[seq_id] = {
                'top_hg_UniProtID': None,
                'top_hg_UniProt_isoform': None,
                'top_hg_UniProt_fus_indices': None,
                'top_tg_UniProtID': None,
                'top_tg_UniProt_isoform': None,
                'top_tg_UniProt_fus_indices': None,
                'top_UniProtID': None,
                'top_UniProt_isoform': None,
                'top_UniProt_fus_indices': None
            }
            for uniprot, d in results[seq_id].items():
                if d is not None:
                    isoform = d['Isoform']
                    # set up the indices string
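                    # e.g. Query_Start=12, Query_End=245 -> '12,245'; a missing
                    # start or end becomes '' and strip(',') drops the dangling comma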
                    query_start = d['Query_Start']
                    if (query_start is None) or (type(query_start)==float and np.isnan(query_start)):
                        query_start = ''
                    else:
                        query_start = int(query_start)
                    query_end = d['Query_End']
                    if (query_end is None) or (type(query_end)==float and np.isnan(query_end)):
                        query_end = ''
                    else:
                        query_end = int(query_end)
                    fus_indices = f"{query_start},{query_end}".strip(",")
                    
                    if d['H_or_T'] in ['Head', 'Head,Tail']:
                        seqid_stats[seq_id]['hgAlignments'] +=1
                        if d['Score'] > seqid_stats[seq_id]['best_hgScore']:
                            seqid_stats[seq_id]['best_hgScore'] = d['Score']
                            if type(uniprot)==float or uniprot is None:
                                top_alignments_dict[seq_id]['top_hg_UniProtID'] = ''
                            else:
                                top_alignments_dict[seq_id]['top_hg_UniProtID'] = uniprot
                            if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
                                top_alignments_dict[seq_id]['top_hg_UniProt_isoform'] = ''
                            else:
                                top_alignments_dict[seq_id]['top_hg_UniProt_isoform'] = str(int(isoform))
                            
                            top_alignments_dict[seq_id]['top_hg_UniProt_fus_indices'] = fus_indices
                            
                    if d['H_or_T'] in ['Tail','Head,Tail']:
                        seqid_stats[seq_id]['tgAlignments'] +=1
                        if d['Score'] > seqid_stats[seq_id]['best_tgScore']:
                            seqid_stats[seq_id]['best_tgScore'] = d['Score']
                            if type(uniprot)==float or uniprot is None:
                                top_alignments_dict[seq_id]['top_tg_UniProtID'] = ''
                            else:
                                top_alignments_dict[seq_id]['top_tg_UniProtID'] = uniprot
                            if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
                                top_alignments_dict[seq_id]['top_tg_UniProt_isoform'] = ''
                            else:
                                top_alignments_dict[seq_id]['top_tg_UniProt_isoform'] = str(int(isoform))
                            
                            top_alignments_dict[seq_id]['top_tg_UniProt_fus_indices'] = fus_indices
                    # increment total no matter what type of alignment it is
                    seqid_stats[seq_id]['totalAlignments']+=1
                    #if d['Score'] > seqid_stats[seq_id]['best_Score']:
                    if d['Best']==True: # rely on the 'Best' flag set during parsing rather than re-deriving it from scores
                        seqid_stats[seq_id]['best_Score'] = d['Score']
                        if type(uniprot)==float or uniprot is None:
                            top_alignments_dict[seq_id]['top_UniProtID'] = ''
                        else:
                            top_alignments_dict[seq_id]['top_UniProtID'] = uniprot
                        if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
                            top_alignments_dict[seq_id]['top_UniProt_isoform'] = ''
                        else:
                            top_alignments_dict[seq_id]['top_UniProt_isoform'] = str(int(isoform))
                        
                        top_alignments_dict[seq_id]['top_UniProt_fus_indices'] = fus_indices
                        # now get positives and identities
                        if 'Identities' not in d: log_update(f"\tMissing Identities for {seq_id} {uniprot}: keys={list(d.keys())}")
                        identities = d['Identities']
                        identities = int(identities.split('/')[0])
                        positives = d['Positives']
                        positives = int(positives.split('/')[0])
                        top_alignments_dict[seq_id]['top_UniProt_nIdentities'] = identities
                        top_alignments_dict[seq_id]['top_UniProt_nPositives'] = positives
                        
        
        stats_df = pd.DataFrame.from_dict(seqid_stats, orient='index').reset_index().rename(columns={'index':'seq_id'})
        stats_df['h_or_t_alignment'] = (stats_df['hgAlignments']>0) | (stats_df['tgAlignments']>0)
        stats_df['h_and_t_alignment'] = (stats_df['hgAlignments']>0) & (stats_df['tgAlignments']>0)
        stats_df.to_csv(stats_df_savepath,index=False)
        
        top_alignments_df = pd.DataFrame.from_dict(top_alignments_dict, orient='index').reset_index().rename(columns={'index':'seq_id'})
        # add in the sequence length so we can get percentages
        fusion_id_seq_dict = dict(zip(fuson_ht_db['seq_id'],fuson_ht_db['aa_seq']))
        assert len(fusion_id_seq_dict) == len(fuson_ht_db['seq_id'].unique()) == len(fuson_ht_db['aa_seq'].unique())
        top_alignments_df['aa_seq_len'] = top_alignments_df['seq_id'].map(fusion_id_seq_dict).str.len()
        
        top_alignments_df.to_csv(top_alignments_df_savepath,index=False)
    # also, find which sequences have no match at all
    no_match_list1 = find_nomatch_blasts(fuson_ht_db, database=database)
    
    log_update(stats_df.head(10).to_string())
    # how many have at least one head or tail?
    log_update(f"Total sequences: {len(stats_df)}")
    log_update(f"Sequences with >=1 head alignment: {len(stats_df.loc[stats_df['hgAlignments']>0])}")
    log_update(f"Sequences with >=1 tail alignment: {len(stats_df.loc[stats_df['tgAlignments']>0])}")
    log_update(f"Sequences with >=1 head OR tail alignment: {len(stats_df.loc[stats_df['h_or_t_alignment']])}")
    log_update(f"Sequences with >=1 head AND tail alignment: {len(stats_df.loc[stats_df['h_and_t_alignment']])}")
    log_update(f"Sequences with ANY alignment: {len(stats_df.loc[stats_df['totalAlignments']>0])}")
    
    top_alignments_df = top_alignments_df.replace({None: ''})
    log_update(f"Preview of top alignments for {database} search:\n{top_alignments_df.head(10).to_string(index=False)}")
    top_alignments_df['hiso'] = top_alignments_df['top_hg_UniProtID']+'-'+top_alignments_df['top_hg_UniProt_isoform']
    top_alignments_df['tiso'] = top_alignments_df['top_tg_UniProtID']+'-'+top_alignments_df['top_tg_UniProt_isoform']
    top_alignments_df['biso'] = top_alignments_df['top_UniProtID']+'-'+top_alignments_df['top_UniProt_isoform']
    top_hgs = set([x.strip('-') for x in top_alignments_df['hiso'].tolist()])   # if things don't have isoforms they'll just end in -
    top_tgs = set([x.strip('-') for x in top_alignments_df['tiso'].tolist()])
    top_bgs = set([x.strip('-') for x in top_alignments_df['biso'].tolist()])
    top_gs = top_hgs | top_tgs | top_bgs
    log_update(f"\nTotal unique head proteins (including isoform) producing top head alignments: {len(top_hgs)}")
    log_update(f"\nTotal unique tail proteins (including isoform) producing top tail alignments: {len(top_tgs)}")
    log_update(f"\nTotal unique proteins (including isoform) - head, tail, or neither - producing top alignments: {len(top_gs)}")
    
    
    
    return stats_df, top_alignments_df

def compare_database_blasts(fuson_ht_db, swissprot_blast_stats, fusion_hts_blast_stats, make_new_plots=True):
    # compare per-sequence BLAST stats between the swissprot and fusion head-tail searches
    # cols = seq_id  hgAlignments  tgAlignments  totalAlignments  best_hgScore  best_tgScore  best_Score  h_or_t_alignment  h_and_t_alignment
    
    # distinguish the columns
    og_cols = list(swissprot_blast_stats.columns)[1::]
    for c in og_cols:
        if c!='seq_id':
            swissprot_blast_stats = swissprot_blast_stats.rename(columns={c: f"swiss_{c}"})
    for c in og_cols:
        if c!='seq_id':
            fusion_hts_blast_stats = fusion_hts_blast_stats.rename(columns={c: f"hts_{c}"})
            
    # merge
    merged = pd.merge(swissprot_blast_stats,
                      fusion_hts_blast_stats,
                      on='seq_id',
                      how='outer')
    diff_cols = og_cols[0:-2]
    differences = pd.DataFrame(columns=diff_cols)
    log_update(f"Making volcano plots of the differences between fusion head-tail BLAST and swissprot BLAST in the following columns:\n\t{','.join(diff_cols)}")
    for c in diff_cols:
        differences[c] = merged[f"hts_{c}"] - merged[f"swiss_{c}"]

    # make some box plots of differences 
    # Generate volcano plots for each column
    if make_new_plots:
        os.makedirs("figures",exist_ok=True)
        os.makedirs("figures/database_comparison",exist_ok=True)
        os.makedirs("figures/database_comparison/differences",exist_ok=True)
        os.makedirs("figures/database_comparison/values",exist_ok=True)
        os.makedirs("figures/database_comparison/box",exist_ok=True)
        
        group_difference_plot(differences)
        group_swiss_and_ht_plot(merged.drop(columns=['seq_id']), diff_cols)
        group_box_plot(merged.drop(columns=['seq_id']), diff_cols)
        
def fasta_to_dataframe(fasta_file):
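    # NOTE: this assumes a strictly two-line-per-record FASTA (one '>' header
    # line followed by one unwrapped sequence line), matching the files
    # written by prepare_blast_inputs above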
    # Read the file into a DataFrame with a single column
    df = pd.read_fwf(fasta_file, header=None, colspecs=[(0, None)], names=['content'])

    # Select even and odd lines using pandas slicing
    ids = df.iloc[::2].reset_index(drop=True)  # Even-indexed lines (IDs)
    sequences = df.iloc[1::2].reset_index(drop=True)  # Odd-indexed lines (sequences)

    # Combine into a new DataFrame
    fasta_df = pd.DataFrame({'ID': ids['content'], 'Sequence': sequences['content']})
    fasta_df['ID'] = fasta_df['ID'].str.split('>',expand=True)[1]
    fasta_df['Sequence'] = fasta_df['Sequence'].str.strip().str.strip('\n')
    
    # print a preview of this 
    temp = fasta_df.head(10).copy()  # copy so the truncated preview doesn't modify fasta_df
    temp['Sequence'] = temp['Sequence'].apply(lambda x: x[0:10]+'...')
    log_update(f"Preview of head/tail fasta sequences in a dataframe:\n{temp.to_string(index=False)}")
    
    return fasta_df
    
def get_ht_uniprot_query(swissprot_top_alignments_df):
    '''
    Use swissprot_top_alignments_df to curate all the unique UniProt IDs (ID.Isoform) that created top head and tail alignments
    '''
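    # e.g. (illustrative) a top head hit of P02671 isoform 2 is recorded as
    # 'P02671.2'; entries that reduce to a bare '.' (no ID at all) are
    # filtered out below via len(x) > 1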
    swissprot_top_alignments_df['top_hg_full'] = swissprot_top_alignments_df['top_hg_UniProtID']+'.'+swissprot_top_alignments_df['top_hg_UniProt_isoform']
    swissprot_top_alignments_df['top_tg_full'] = swissprot_top_alignments_df['top_tg_UniProtID']+'.'+swissprot_top_alignments_df['top_tg_UniProt_isoform']
    
    unique_heads = swissprot_top_alignments_df.loc[
        swissprot_top_alignments_df['top_hg_UniProtID'].notna()
    ]['top_hg_full'].unique().tolist()

    unique_tails = swissprot_top_alignments_df.loc[
        swissprot_top_alignments_df['top_tg_UniProtID'].notna()
    ]['top_tg_full'].unique().tolist()
    
    unique_ht = set(unique_heads).union(set(unique_tails))
    unique_ht = list(unique_ht)
    unique_ht = [x for x in unique_ht if len(x)>1]    # not just "."

    with open("blast_outputs/ht_uniprot_query.txt", "w") as f:
        for i, ht in enumerate(unique_ht):
            if i!= len(unique_ht)-1:
                f.write(f"{ht}\n")
            else:
                f.write(f"{ht}")
                
def main():
    # Later, add argparse back in here and change where the log goes and what happens depending on what the user decides
    # May need to separate BLAST prep from the actual BLAST run for the manuscript, but worry about this later
    with open_logfile("fusion_blast_log.txt"):
        # Start by preparing BLAST inputs
        prepare_blast_inputs()
    
        # Then run BLAST
        run_blast("blast_inputs",database="swissprot")
        
        ###### Analyze BLAST results
        # Make database with head and tail info for each fusion, so we know what to expect
        fuson_ht_db = make_fuson_ht_db(savepath="fuson_ht_db.csv")
        
        #parse_all_blast_results(fuson_ht_db, database="swissprot")
        swissprot_blast_stats, swissprot_top_alignments_df = analyze_blast_results(fuson_ht_db,database="swissprot")

        swissprot_top_alignments_df = pd.read_csv("blast_outputs/swissprot_top_alignments.csv")
        get_ht_uniprot_query(swissprot_top_alignments_df)
        os.makedirs("figures/top_blast_visuals",exist_ok=True)
        group_pos_id_plot(swissprot_top_alignments_df)
        
if __name__ == '__main__':
    main()