## Imports
import pandas as pd
import numpy as np
import os
import sys
import pickle
from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
from fuson_plm.utils.logging import open_logfile, log_update
from fuson_plm.utils.data_cleaning import clean_rows_and_cols, check_columns_for_listlike, check_item_for_listlike, find_delimiters, find_invalid_chars
from fuson_plm.data.config import CLEAN
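# Pipeline overview (see main() below): clean FusionPDB and FOdb, merge them into fuson_db.csv,
# flag benchmark sequences, then write UniProt ID-mapping inputs for the head/tail genes and
# combine the mapping results returned from UniProt.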

def clean_fusionpdb(fusionpdb: pd.DataFrame, tcga_codes, delimiters, valid_aas) -> pd.DataFrame:
    """
    Return a cleaned version of the raw FusionPDB database, downloaded from FusionPDB website "Level 1" link

    Args:
        fusionpdb (pd.DataFrame): The raw FusionPDB database
        tcga_codes (dict): mapping from TCGA acronyms to full cancer names
        delimiters: delimiters to check for in list-like columns
        valid_aas: set of valid amino acid characters

    Returns:
        pd.DataFrame: A cleaned version of the raw FusionPDB database with no duplicate sequences.

            Columns:
            - `aa_seq`: amino acid sequence of fusion oncoprotein. each is unique.
            - `n_fusiongenes`: total number of fusion genes with this amino acid sequence.
            - `fusiongenes`: comma-separated list of fusion genes (hgene::tgene) for this sequence. e.g., "MINK1::SPNS3,UBE2G1::SPNS3"
            - `cancers`: comma-separated list of cancer types for this sequence. e.g., "breast invasive carcinoma,stomach adenocarcinoma"
            - `primary_sources`: comma-separated list of sources FusionPDB pulled the data from
            - `secondary_source`: always "FusionPDB"
    """
    # Process and clean FusionPDB database
    log_update("Cleaning FusionPDB raw data")

    # FusionPDB is downloaded with no column labels. Fill in column labels here.
    log_update(f"\tfilling in column names...")
    fusionpdb = fusionpdb.rename(columns={
        0: 'ORF_type',
        1: 'hgene_ens',
        2: 'tgene_ens',
        3: '', # no data in this column
        4: 'primary_source', # database FusionPDB pulled from
        5: 'cancer',
        6: 'database_id',
        7: 'hgene',
        8: 'hgene_chr',
        9: 'hgene_bp',
        10: 'hgene_strand',
        11: 'tgene',
        12: 'tgene_chr',
        13: 'tgene_bp',
        14: 'tgene_strand',
        15: 'bp_dna_transcript',
        16: 'dna_transcript',
        17: 'aa_seq_len',
        18: 'aa_seq',
        19: 'predicted_start_dna_transcript',
        20: 'predicted_end_dna_transcript'
    })

    # Clean rows and columns
    fusionpdb = clean_rows_and_cols(fusionpdb)

    # Check for list-like qualities in the columns we plan to keep
    cols_of_interest =  ['hgene','tgene','cancer','aa_seq','primary_source']
    listlike_dict = check_columns_for_listlike(fusionpdb, cols_of_interest, delimiters)

    # Add a new column for fusiongene, which combines hgene::tgene. e.g., EWS::FLI1
    log_update("\tadding a column for fusiongene = hgene::tgene")
    fusionpdb['fusiongene'] = (fusionpdb['hgene'] + '::' + fusionpdb['tgene']).astype(str)

    # Make 'cancer' column type string to ease downstream processing
    log_update("\tcleaning the cancer column...")
    # turn '.' and nan entries into empty string
    fusionpdb = fusionpdb.replace('.',np.nan)
    fusionpdb['cancer'] = fusionpdb['cancer'].astype(str).replace('nan','')
    log_update("\t\tconverting cancer acronyms into full cancer names...")
    fusionpdb['cancer'] = fusionpdb['cancer'].apply(lambda x: tcga_codes[x].lower() if x in tcga_codes else x.lower())
    log_update("\t\tconverting all lists into comma-separated...")
    fusionpdb['cancer'] = fusionpdb['cancer'].str.replace(';',',')
    fusionpdb['cancer'] = fusionpdb['cancer'].str.replace(', ', ',')
    fusionpdb['cancer'] = fusionpdb['cancer'].str.strip()
    fusionpdb['cancer'] = fusionpdb['cancer'].str.strip(',')
    log_update(f"\t\tchecking for delimiters in the cleaned column...")
    check_columns_for_listlike(fusionpdb, ['cancer'], delimiters)
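    # Note: the acronym lookup above is per-cell, so a cell holding exactly one TCGA code is
    # expanded (e.g., 'BRCA' -> 'breast invasive carcinoma', assuming TCGA_CODES maps it so),
    # while a multi-entry cell is only lowercased and re-delimited with commas.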
    
    # Now that we've dealt with listlike instances, make dictionary of hgene and tgene to their ensembl strings 
    log_update("\tcreating dictionary of head and tail genes mapped to Ensembl IDs, to be used later for aquiring UniProtAcc for head and tail genes (needed for BLAST analysis)")
    hgene_to_ensembl_dict = fusionpdb.groupby('hgene').agg(
        {
            'hgene_ens': lambda x: ','.join(set(x))
        }
    ).reset_index()
    hgene_to_ensembl_dict = dict(zip(hgene_to_ensembl_dict['hgene'],hgene_to_ensembl_dict['hgene_ens']))
    tgene_to_ensembl_dict = fusionpdb.groupby('tgene').agg(
        {
            'tgene_ens': lambda x: ','.join(set(x))
        }
    ).reset_index()
    tgene_to_ensembl_dict = dict(zip(tgene_to_ensembl_dict['tgene'],tgene_to_ensembl_dict['tgene_ens']))
    # now, we might have some of the same heads and tails being mapped to different things
    all_keys = set(hgene_to_ensembl_dict.keys()).union(set(tgene_to_ensembl_dict.keys()))
    gene_to_ensembl_dict = {}
    for k in all_keys:
        ens = hgene_to_ensembl_dict.get(k,'') + ',' + tgene_to_ensembl_dict.get(k,'')
        ens = ','.join(set(ens.strip(',').split(',')))
        gene_to_ensembl_dict[k] = ens
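    # Illustrative example (gene name and IDs hypothetical): if the head dict has
    # {'GENE1': 'ENSG01'} and the tail dict has {'GENE1': 'ENSG01,ENSG02'}, the union above
    # yields gene_to_ensembl_dict['GENE1'] = 'ENSG01,ENSG02' (order arbitrary; sets are unordered).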
    os.makedirs("head_tail_data",exist_ok=True)
    with open(f"head_tail_data/gene_to_ensembl_dict.pkl", "wb") as f:
        pickle.dump(gene_to_ensembl_dict, f)
    total_unique_ens_ids = list(gene_to_ensembl_dict.values())
    total_unique_ens_ids = set(",".join(total_unique_ens_ids).split(","))
    log_update(f"\t\tTotal unique head/tail genes: {len(gene_to_ensembl_dict)}\n\t\tTotal unique ensembl ids: {len(total_unique_ens_ids)}")
    
    # To deal with duplicate sequences, group FusionPDB by sequence and concatenate fusion gene names, cancer types, and primary source
    log_update(f"\tchecking FusionPDB for duplicate protein sequences...\n\t\toriginal size: {len(fusionpdb)}")
    duplicates = fusionpdb[fusionpdb.duplicated('aa_seq')]['aa_seq'].unique().tolist()
    n_fgenes_with_duplicates = len(fusionpdb[fusionpdb['aa_seq'].isin(duplicates)]['fusiongene'].unique())
    n_rows_with_duplicates = len(fusionpdb[fusionpdb['aa_seq'].isin(duplicates)])
    log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
    log_update(f"\tgrouping FusionPDB by amino acid sequence...")
    # Merge step
    fusionpdb = pd.merge(
        fusionpdb.groupby('aa_seq').agg({
            'fusiongene': lambda x: x.nunique()}).reset_index().rename(columns={'fusiongene':'n_fusiongenes'}),
        fusionpdb.groupby('aa_seq').agg({
            'fusiongene': lambda x: ','.join(x),
            'cancer': lambda x: ','.join(x),
            'primary_source': lambda x: ','.join(x)}).reset_index().rename(columns={'fusiongene':'fusiongenes', 'cancer': 'cancers', 'primary_source':'primary_sources'}),
        on='aa_seq'
    )
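    # The merge above joins two aggregations of the same table on aa_seq: one counts distinct
    # fusion genes per sequence (n_fusiongenes); the other concatenates the per-row annotations
    # into comma-separated strings, which are deduplicated just below.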
    # Turn each aggregated column into sorted, comma-separated list
    fusionpdb['fusiongenes'] = fusionpdb['fusiongenes'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')
    fusionpdb['cancers'] = fusionpdb['cancers'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')
    fusionpdb['primary_sources'] = fusionpdb['primary_sources'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')
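    # Illustrative example (values hypothetical): an aggregated cell 'B::C,A::C,B::C' becomes
    # 'A::C,B::C' (duplicates dropped, entries sorted alphabetically, stray commas stripped).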

    # Count and display sequences with >1 fusion gene
    duplicates = fusionpdb.loc[fusionpdb['n_fusiongenes']>1]['aa_seq'].tolist()
    log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
    log_update(f"\t\treorganized database contains {len(fusionpdb)} unique oncofusion sequences")

    # Find invalid amino acids for each sequence and log_update the results
    fusionpdb['invalid_chars'] = fusionpdb['aa_seq'].apply(lambda x: find_invalid_chars(x, valid_aas))
    all_invalid_chars = set().union(*fusionpdb['invalid_chars'])
    log_update(f"\tchecking for invalid characters...\n\t\tset of all invalid characters discovered within FusionPDB: {all_invalid_chars}")

    # Filter out any sequences with invalid amino acids
    fusionpdb = fusionpdb[fusionpdb['invalid_chars'].str.len()==0].reset_index(drop=True).drop(columns=['invalid_chars'])
    log_update(f"\tremoving invalid characters...\n\t\tremaining sequences with valid AAs only: {len(fusionpdb)}")

    # Add a column for secondary source - FusionPDB.
    fusionpdb['secondary_source'] = 'FusionPDB'

    # Final checks of database cleanliness
    log_update(f"\tperforming final checks on cleaned FusionPDB...")
    duplicates = len(fusionpdb.loc[fusionpdb['aa_seq'].duplicated()]['aa_seq'].tolist())
    log_update(f"\t\t{duplicates} duplicate sequences")
    invalids=0
    for x in all_invalid_chars:
        invalids += len(fusionpdb.loc[fusionpdb['aa_seq'].str.contains(x)])
    log_update(f"\t\t{invalids} proteins containing invalid chracters")
    all_unique_seqs = len(fusionpdb)==len(fusionpdb['aa_seq'].unique())
    log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")

    return fusionpdb

def clean_fodb(fodb: pd.DataFrame, fodb_codes, delimiters, valid_aas) -> pd.DataFrame:
    """
    Cleans the FOdb database
    
    Args:
        fodb (pd.DataFrame): raw FOdb.
        fodb_codes (dict): mapping from FOdb cancer acronyms to full cancer names
        delimiters: delimiters to check for in list-like columns
        valid_aas: set of valid amino acid characters

    Returns:
        pd.DataFrame: a cleaned version of FOdb with no duplicate sequences.

        Columns:
        - `aa_seq`: amino acid sequence of fusion oncoprotein. each is unique.
        - `n_fusiongenes`: total number of fusion genes with this amino acid sequence.
        - `fusiongenes`: comma-separated list of fusion genes (hgene::tgene) for this sequence. e.g., "MINK1::SPNS3,UBE2G1::SPNS3"
        - `cancers`: comma-separated list of cancer types for this sequence. e.g., "breast invasive carcinoma,stomach adenocarcinoma"
        - `primary_sources`: comma-separated list of sources FOdb pulled the data from
        - `secondary_source`: always "FOdb"
    """
    
    log_update("Cleaning FOdb raw data")

    fodb['FO_Name'] = fodb['FO_Name'].apply(lambda x: x.split("_")[0]+"::"+x.split("_")[1])
    fodb = fodb.rename(columns={'Sequence_Source': 'primary_source', 'FO_Name': 'fusiongene', 'AA_Sequence': 'aa_seq'})

    # Clean rows and columns
    fodb = clean_rows_and_cols(fodb)
    
    # HEY1::NCOA2 has a "-" on the end by mistake. Replace this with '' for benchmarking purposes
    special_seq = "MKRAHPEYSSSDSELDETIEVEKESADENGNLSSALGSMSPTTSSQILARKRRRGIIEKRRRDRINNSLSELRRLVPSAFEKQGSAKLEKAEILQMTVDHLKMLHTAGGKAFNNPRPGQLGRLLPNQNLPLDITLQSPTGAGPFPPIRNSSPYSVIPQPGMMGNQGMIGNQGNLGNSSTGMIGNSASRPTMPSGEWAPQSSAVRVTCAATTSAMNRPVQGGMIRNPAASIPMRPSSQPGQRQTLQSQVMNIGPSELEMNMGGPQYSQQQAPPNQTAPWPESILPIDQASFASQNRQPFGSSPDDLLCPHPAAESPSDEGALLDQLYLALRNFDGLEEIDRALGIPELVSQSQAVDPEQFSSQDSNIMLEQKAPVFPQQYASQAQMAQGSYSPMQDPNFHTMGQRPSYATLRMQPRPGLRPTGLVQNQPNQLRLQLQHRLQAQQNRQPLMNQISNVSNVNLTLRPGVPTQAPINAQMLAQRQREILNQHLRQRQMHQQQQVQQRTLMMRGQGLNMTPSMVAPSGIPATMSNPRIPQANAQQFPFPPNYGISQQPDPGFTGATTPQSPLMSPRMAHTQSPMMQQSQANPAYQAPSDINGWAQGNMGGNSMFSQQSPPHFGQQANTSMYSNNMNINVSMATNTGGMSSMNQMTGQISMTSVTSVPTSGLSSMGPEQVNDPALRGGNLFPNQLPGMDMIKQEGDTTRKYC-"
    special_seq_name = "HEY1::NCOA2"
    fodb.loc[
        (fodb['fusiongene']==special_seq_name) & 
        (fodb['aa_seq']==special_seq), 'aa_seq'
    ] = special_seq.replace('-','')

    # filter out anything remaining with invalid characters
    fodb['invalid_chars'] = fodb['aa_seq'].apply(lambda x: find_invalid_chars(x, valid_aas))
    all_invalid_chars = set().union(*fodb['invalid_chars'])
    log_update(f"\tchecking for invalid characters...\n\t\tset of all invalid characters discovered within FOdb: {all_invalid_chars}")
    
    fodb = fodb[fodb['invalid_chars'].str.len()==0].reset_index(drop=True).drop(columns=['invalid_chars'])
    log_update(f"\tremoving invalid characters...\n\t\tremaining sequences with valid AAs only: {len(fodb)}")

    # aggregate the cancer data - if there's a 1 in the column, add it to the list of affected cancers
    # acronym -> cancer conversions are based on Supplementary Table 3 of the FOdb paper (Tripathi et al. 2023)
    log_update(f"\taggregating cancer data from {len(fodb.columns)-4} individual cancer columns into one...")
    log_update(f"\t\tchanging cancer names from acronyms to full")
    cancers = list(fodb.columns)[4::]
    fodb['cancers'] = ['']*len(fodb)
    for cancer in cancers:
        mapped_cancer = fodb_codes[cancer].lower() if cancer in fodb_codes else cancer
        fodb['cancers'] = fodb.apply(
            lambda row: row['cancers'] + f'{mapped_cancer},' if row[cancer] == 1 else row['cancers'],
            axis=1
        )
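    # Illustrative example (acronyms hypothetical): a row with 1s in indicator columns 'AML' and
    # 'BRCA' accumulates 'acute myeloid leukemia,breast invasive carcinoma,' here, assuming
    # FODB_CODES maps both; the trailing comma is stripped below.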
    fodb['cancers'] = fodb['cancers'].str.strip(',').replace('nan','')
    fodb = fodb.drop(columns=['Patient_Count']+cancers) 
    
    # Check for list-like qualities in the columns we plan to keep
    cols_of_interest =  ['primary_source','fusiongene','aa_seq','cancers']
    listlike_dict = check_columns_for_listlike(fodb, cols_of_interest, delimiters)
    
    # To deal with duplicate sequences, group fodb by sequence and concatenate fusion gene names, cancer types, and primary source
    log_update(f"\tchecking fodb for duplicate protein sequences...\n\t\toriginal size: {len(fodb)}")
    duplicates = fodb[fodb.duplicated('aa_seq')]['aa_seq'].unique().tolist()
    n_fgenes_with_duplicates = len(fodb[fodb['aa_seq'].isin(duplicates)]['fusiongene'].unique())
    n_rows_with_duplicates = len(fodb[fodb['aa_seq'].isin(duplicates)])
    log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
    log_update(f"\tgrouping fodb by amino acid sequence...")
    # Merge step
    fodb = pd.merge(
        fodb.groupby('aa_seq').agg({
            'fusiongene': lambda x: x.nunique()}).reset_index().rename(columns={'fusiongene':'n_fusiongenes'}),
        fodb.groupby('aa_seq').agg({
            'fusiongene': lambda x: ','.join(x),
            'cancers': lambda x: ','.join(x),
            'primary_source': lambda x: ','.join(x)}).reset_index().rename(columns={'fusiongene':'fusiongenes', 'primary_source':'primary_sources'}),
        on='aa_seq'
    )
    # Turn each aggregated column into sorted, comma-separated list
    fodb['fusiongenes'] = fodb['fusiongenes'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')
    fodb['cancers'] = fodb['cancers'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')
    fodb['primary_sources'] = fodb['primary_sources'].apply(lambda x: ','.join(sorted(set(x.split(','))))).str.strip(',')

    # Count and display sequences with >1 fusion gene
    duplicates = fodb.loc[fodb['n_fusiongenes']>1]['aa_seq'].tolist()
    log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
    log_update(f"\t\treorganized database contains {len(fodb)} unique oncofusion sequences")

    # Add secondary source column because FOdb is the secondary source here. 
    fodb['secondary_source'] = 'FOdb'

    # Final checks of database cleanliness
    log_update(f"\tperforming final checks on cleaned FOdb...")
    duplicates = len(fodb.loc[fodb['aa_seq'].duplicated()]['aa_seq'].tolist())
    log_update(f"\t\t{duplicates} duplicate sequences")
    invalids=0
    for x in all_invalid_chars:
        invalids += len(fodb.loc[fodb['aa_seq'].str.contains(x)])
    log_update(f"\t\t{invalids} proteins containing invalid chracters")
    all_unique_seqs = len(fodb)==len(fodb['aa_seq'].unique())
    log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")

    return fodb

def create_fuson_db(fusionpdb: pd.DataFrame, fodb: pd.DataFrame) -> pd.DataFrame:
    """
    Merges cleaned FusionPDB and FOdb to create fuson_db (the full set of fusion sequences for training/benchmarking FusOn-pLM)
    
    Args:
        fusionpdb (pd.DataFrame): cleaned FusionPDB, as returned by clean_fusionpdb
        fodb (pd.DataFrame): cleaned FOdb, as returned by clean_fodb

    Returns:
        pd.DataFrame: fuson_db, with one row per unique sequence and columns for
            fusiongenes, cancers, primary_sources, secondary_sources, length, and seq_id.
    """
    log_update("Creating the merged database...")

    log_update("\tconcatenating cleaned FusionPDb and cleaned FOdb...")
    fuson_db = pd.concat(
        [
            fusionpdb.rename(columns={'secondary_source':'secondary_sources'}),
            fodb.rename(columns={'secondary_source':'secondary_sources'})
        ]
    )

    # Handle duplicate amino acid sequences
    log_update(f"\tchecking merged database for duplicate protein sequences...\n\t\toriginal size: {len(fuson_db)}")
    duplicates = fuson_db[fuson_db.duplicated('aa_seq')]['aa_seq'].unique().tolist()
    n_fgenes_with_duplicates = len(fuson_db[fuson_db['aa_seq'].isin(duplicates)]['fusiongenes'].unique())
    n_rows_with_duplicates = len(fuson_db[fuson_db['aa_seq'].isin(duplicates)])
    log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
    log_update(f"\tgrouping database by amino acid sequence...")

    fuson_db = fuson_db.groupby('aa_seq').agg(
        {
            'fusiongenes': lambda x: ','.join(x),
            'cancers': lambda x: ','.join(x),
            'primary_sources': lambda x: ','.join(x),
            'secondary_sources': lambda x: ','.join(x)
        }
    ).reset_index()
    duplicates = fuson_db.loc[fuson_db['fusiongenes'].str.count(',')>0]['aa_seq'].tolist()
    log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
    log_update(f"\t\treorganized database contains {len(fuson_db)} unique oncofusion sequences")

    # Reduce each aggregated column to a sorted, comma-separated list of unique, non-empty entries
    for column in fuson_db.columns[1:]:
        fuson_db[column] = fuson_db[column].apply(lambda x: ','.join(sorted(set(
            y for y in x.split(',') if len(y) > 0))))
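    # Illustrative example: a merged cell 'FOdb,FusionPDB,FOdb' becomes 'FOdb,FusionPDB', and
    # empty fragments left by blank entries (e.g., 'glioma,,glioma' -> 'glioma') are dropped.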
    
    # Add a column for length
    log_update(f"\tadding a column for length...")
    fuson_db['length'] = fuson_db['aa_seq'].str.len()

    # Sort by fusiongenes, then length
    log_update(f"\tsorting by fusion gene name, then length...")
    fuson_db = fuson_db.sort_values(by=['fusiongenes','length'],ascending=[True,True]).reset_index(drop=True)

    # Add a seq_id column: seq1, seq2, ..., seqn
    log_update(f"\tadding sequence ids: seq1, seq2, ..., seqn")
    fuson_db['seq_id'] = ['seq'+str(i+1) for i in range(len(fuson_db))]

    # Final checks of database cleanliness
    log_update(f"\tperforming final checks on fuson_db...")
    duplicates = len(fuson_db.loc[fuson_db['aa_seq'].duplicated()]['aa_seq'].tolist())
    log_update(f"\t\t{duplicates} duplicate sequences")
    all_unique_seqs = len(fuson_db)==len(fuson_db['aa_seq'].unique())
    log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")
    
    return fuson_db

def head_tail_mappings(fuson_db):
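    """
    Combines the UniProt ID-mapping results for head and tail genes (mapped both by gene name
    and by Ensembl ID) into one gene -> UniProt accession table, logs how many fusion sequences
    in fuson_db have mappable heads/tails, and saves the table to head_tail_data/htgenes_uniprotids.csv.
    """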
    log_update("\nGenes and Ensembl IDs corresponding to the head and tail proteins have been mapped on UniProt. Now, combining these results.")

    # Read the ensembl map, gene name map, and dictionary from gene --> ensembl ids
    ensembl_map = pd.read_csv("head_tail_data/ensembl_ht_idmap.txt",sep="\t")
    name_map = pd.read_csv("head_tail_data/genename_ht_idmap.txt",sep="\t")    
    with open("head_tail_data/gene_to_ensembl_dict.pkl", "rb") as f:
        gene_ens_dict = pickle.load(f)

    log_update(f"\tCheck: ensembl map and gene name map have same columns: {set(ensembl_map.columns)==set(name_map.columns)}")
    log_update(f"\t\tColumns = {list(ensembl_map.columns)}")

    # Prepare to merge 
    log_update(f"\tMerging the ensembl map and gene name map:")
    ensembl_map = ensembl_map.rename(columns={'From': 'ensembl_id'})    # mapped from ensembl ids
    name_map = name_map.rename(columns={'From': 'htgene'})              # mapped from head or tail genes
    name_map['ensembl_id'] = name_map['htgene'].map(gene_ens_dict)      # add ensembl id column based on head and tail genes
    name_map['ensembl_id'] = name_map['ensembl_id'].apply(lambda x: x.split(',') if type(x)==str else x)   # split into a list when there are multiple matches
    log_update(f"\t\tLength of gene-based map before exploding ensembl_id column: {len(name_map)}")
    name_map = name_map.explode('ensembl_id')       # explode so each ensembl id is its own line 
    log_update(f"\t\tLength of gene-based map after exploding ensembl_id column: {len(name_map)}")
    log_update(f"\t\tLength of ensembl-based map: {len(ensembl_map)}")
    unimap = pd.merge(name_map[['htgene','ensembl_id','Entry','Reviewed']],
                    ensembl_map[['ensembl_id','Entry','Reviewed']],
                    on=['ensembl_id','Entry','Reviewed'],
                    how='outer'
                    )
    unimap['Reviewed'] = unimap['Reviewed'].apply(lambda x: '1' if x=='reviewed' else '0' if x=='unreviewed' else 'N')  # N for nan
    log_update(f"\t\tLength of merge: {len(unimap)}. Merge preview:")
    log_update(unimap.head())
    unimap = unimap.drop_duplicates(['htgene','Entry','Reviewed']).reset_index(drop=True)
    log_update(f"\t\tLength of merge after dropping rows where only ensembl_id changed: {len(unimap)}. Merge preview: ")
    log_update(unimap.head())
    unimap = unimap.groupby('htgene').agg(
        {
            'Entry': lambda x: ','.join(x),
            'Reviewed': lambda x: ''.join(x)
        }
    ).reset_index()
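    # Illustrative example (accessions hypothetical): a gene mapped to one reviewed and one
    # unreviewed entry now reads Entry='P12345,Q67890' with Reviewed='10', i.e. the i-th
    # character of Reviewed flags the i-th comma-separated UniProt ID.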
    unimap = unimap.rename(columns={
        'htgene': 'Gene',
        'Entry': 'UniProtID',
    })
    log_update(f"\t\tLength of merge after grouping by gene name: {len(unimap)}. Merge preview:")
    log_update(unimap.head())

    # Which fusion proteins have unmappable heads and/or tails?
    log_update(f"\tChecking which fusion proteins have unmappable heads and/or tails:")
    temp = fuson_db.copy(deep=True)
    temp['fusiongenes'] = temp['fusiongenes'].apply(lambda x: x.split(','))
    temp = temp.explode('fusiongenes')
    temp['hgene'] = temp['fusiongenes'].str.split('::',expand=True)[0]
    temp['tgene'] = temp['fusiongenes'].str.split('::',expand=True)[1]

    # See which gene IDs weren't covered
    log_update(f"\tChecking which gene IDs were not mapped by either method")
    all_geneids = temp['hgene'].tolist() +temp['tgene'].tolist()
    all_geneids = list(set(all_geneids))
    all_mapped_genes = unimap['Gene'].unique().tolist()
    unmapped_geneids = set(all_geneids) - set(all_mapped_genes)
    log_update(f"\t\t{len(all_mapped_genes)}/{len(all_geneids)} were mapped\n\t\t{len(unmapped_geneids)}/{len(all_geneids)} were unmapped")
    log_update(f"\t\tUnmapped geneids: {','.join(unmapped_geneids)}")
    
    # Find the ok ones and print
    ok_seqs =  temp.loc[
        (temp['hgene'].isin(all_mapped_genes)) |    # head gene was found, OR
        (temp['tgene'].isin(all_mapped_genes))      # tail gene was found
    ]['seq_id'].unique().tolist()
    ok_seqsh =  temp.loc[
        (temp['hgene'].isin(all_mapped_genes))      # head gene was found
    ]['seq_id'].unique().tolist()
    ok_seqst =  temp.loc[
        (temp['tgene'].isin(all_mapped_genes))      # tail gene was found
    ]['seq_id'].unique().tolist()
    ok_seqsboth =  temp.loc[
        (temp['hgene'].isin(all_mapped_genes)) &    # head gene was found, AND
        (temp['tgene'].isin(all_mapped_genes))      # tail gene was found
    ]['seq_id'].unique().tolist()
    
    log_update(f"\tTotal fusion sequence ids: {len(temp['seq_id'].unique())}")
    log_update(f"\tFusion sequences with at least 1 mapped constituent:\
        \n\t\tMapped head: {len(ok_seqsh)}\
            \n\t\tMapped tail: {len(ok_seqst)}\
                \n\t\tMapped head or tail: {len(ok_seqs)}\
                    \n\t\tMapped head AND tail: {len(ok_seqsboth)}")
    
    # Now look at the bad side
    atleast_1_lost = temp.loc[
        ((temp['hgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqsh))) |   # head not found in row, AND head not found for seq_id - OR
        ((temp['tgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqst)))     # tail not found in row, AND tail not found for seq_id
    ]['seq_id'].unique().tolist()
    atleast_1_losth = temp.loc[
        (temp['hgene'].isin(unmapped_geneids)) &                # head not found in this row AND
        ~(temp['seq_id'].isin(ok_seqsh))                        # head not found for this seq id
    ]['seq_id'].unique().tolist()
    atleast_1_lostt = temp.loc[
        (temp['tgene'].isin(unmapped_geneids)) &                # tail not found in this row AND
        ~(temp['seq_id'].isin(ok_seqst))                        # tail not found for this seq id
    ]['seq_id'].unique().tolist()
    both_lost = temp.loc[
        ((temp['hgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqsh))) &   # there's no head, and this seq id has no head - AND
        ((temp['tgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqst)))     # there's no tail, and this seq id has no tail
    ]['seq_id'].unique().tolist()
    log_update(f"\tFusion sequences with at least 1 unmapped constituent:")
    log_update(f"\t\tUnmapped head: {len(atleast_1_losth)}\
        \n\t\tUnmapped tail: {len(atleast_1_lostt)}\
            \n\t\tUnmapped head or tail: {len(atleast_1_lost)}\
                \n\t\tUnmapped head AND tail: {len(both_lost)}")
    log_update(f"\tseq_ids with at least 1 unmapped part: {atleast_1_lost}")
    
    assert len(ok_seqsboth)+ len(atleast_1_lost) == len(temp['seq_id'].unique())
    log_update(f"\tFusions with H&T covered plus Fusions with H|T lost = total = {len(ok_seqsboth)}+ {len(atleast_1_lost)} = {len(ok_seqsboth)+ len(atleast_1_lost)} = {len(temp['seq_id'].unique())}")
 
    ### Save the unimap 
    unimap.to_csv('head_tail_data/htgenes_uniprotids.csv',index=False)

def assemble_uniprot_query(path_to_gene_ens_dict="head_tail_data/gene_to_ensembl_dict.pkl",path_to_fuson_db="fuson_db.csv"):
    """
    To analyze the BLAST results effectively, we must know which UniProt accessions we *expect* to see for each fusion oncoprotein.
    We will try to map each FO to its head and tail accessions by querying the UniProt ID mapping service with gene names and Ensembl IDs.

    This method will create two input lists for UniProt:
        - head_tail_genes.txt: list of all unique head and tail gene names
        - head_tail_ens.txt: list of all unique head and tail Ensembl IDs
    """
    log_update("\nMaking inputs for UniProt ID map, to find accessions for head and tail genes")
    if not os.path.exists(path_to_gene_ens_dict):
        raise FileNotFoundError(f"File {path_to_gene_ens_dict} does not exist")
    
    with open(path_to_gene_ens_dict, "rb") as f:
        gene_ens_dict = pickle.load(f)
    
    all_htgenes_temp = list(gene_ens_dict.keys())
    all_ens = list(gene_ens_dict.values())
    all_ens = list(set(",".join(all_ens).split(",")))
    log_update(f"\tTotal unique head and tail genes, only accounting for FusionPDB: {len(all_htgenes_temp)}")
    
    # need to add the remaining head/tail genes that appear in fuson_db (i.e., from FOdb)
    fuson_db = pd.read_csv(path_to_fuson_db)
    fuson_db['fusiongenes'] = fuson_db['fusiongenes'].apply(lambda x: x.split(','))
    fuson_db = fuson_db.explode('fusiongenes')
    fuson_db['hgene'] = fuson_db['fusiongenes'].str.split('::',expand=True)[0]
    fuson_db['tgene'] = fuson_db['fusiongenes'].str.split('::',expand=True)[1]
    fuson_htgenes = fuson_db['hgene'].tolist() + fuson_db['tgene'].tolist()
    fuson_htgenes = set(fuson_htgenes)
    all_htgenes = list(set(all_htgenes_temp).union(fuson_htgenes))
    
    log_update(f"\tTotal unique head and tail genes after adding FOdb: {len(all_htgenes)}")
    log_update(f"\tTotal unique ensembl IDs: {len(all_ens)}")
    # go through each and write a file
    input_dir = "head_tail_data/uniprot_idmap_inputs"
    os.makedirs(input_dir,exist_ok=True)
    
    if os.path.exists(f"{input_dir}/head_tail_genes.txt"):
        log_update("\nAlready assembled UniProt ID mapping input for head and tail genes. Continuing")
    else:
        with open(f"{input_dir}/head_tail_genes.txt", "w") as f:
            for i, gene in enumerate(all_htgenes):
                if i!=len(all_htgenes)-1:
                    f.write(f"{gene}\n")
                else:
                    f.write(f"{gene}")
    
    if os.path.exists(f"{input_dir}/head_tail_ens.txt"):
        log_update("\nAlready assembled UniProt ID mapping input for head and tail ensembl IDs. Continuing")
    else:
        with open(f"{input_dir}/head_tail_ens.txt", "w") as f:
            for i, ens in enumerate(all_ens):
                if i!=len(all_ens)-1:
                    f.write(f"{ens}\n")
                else:
                    f.write(f"{ens}")

def main():
    # Define global variables from config.CLEAN
    FODB_PATH = CLEAN.FODB_PATH
    FODB_PUNCTA_PATH = CLEAN.FODB_PUNCTA_PATH
    FUSIONPDB_PATH = CLEAN.FUSIONPDB_PATH
    LOG_PATH = "data_cleaning_log.txt"
    SAVE_CLEANED_FODB = False
    
    # Prepare the log file
    with open_logfile(LOG_PATH):
        log_update("Loaded data-cleaning configurations from config.py")
        CLEAN.print_config(indent='\t')
        
        log_update("Reading FusionPDB...")
        fusionpdb = pd.read_csv(FUSIONPDB_PATH,sep='\t',header=None)
        fusionpdb = clean_fusionpdb(fusionpdb, TCGA_CODES, DELIMITERS, VALID_AAS)

        log_update("Saving FusionPDB to FusionPDB_cleaned.csv...")
        fusionpdb.to_csv('raw_data/FusionPDB_cleaned.csv', index=False)

        # Clean FOdb, optionally save
        log_update("Reading FOdb...")
        fodb = pd.read_csv(FODB_PATH)
        fodb = clean_fodb(fodb, FODB_CODES, DELIMITERS, VALID_AAS)

        if SAVE_CLEANED_FODB:
            log_update("Saving FOdb to FOdb_cleaned.csv...")
            fodb.to_csv('FOdb_cleaned.csv', index=False)

        # Merge FusionPDB and FOdb to fuson_db
        fuson_db = create_fuson_db(fusionpdb, fodb)
        
        # Mark benchmarking sequences
        # FOdb puncta benchmark
        log_update("Adding benchmarking sequences to fuson_db...")
        fodb_puncta = pd.read_csv(FODB_PUNCTA_PATH)
        
        # handle the mistake sequence - take the "-" off the end
        special_seq = "MKRAHPEYSSSDSELDETIEVEKESADENGNLSSALGSMSPTTSSQILARKRRRGIIEKRRRDRINNSLSELRRLVPSAFEKQGSAKLEKAEILQMTVDHLKMLHTAGGKAFNNPRPGQLGRLLPNQNLPLDITLQSPTGAGPFPPIRNSSPYSVIPQPGMMGNQGMIGNQGNLGNSSTGMIGNSASRPTMPSGEWAPQSSAVRVTCAATTSAMNRPVQGGMIRNPAASIPMRPSSQPGQRQTLQSQVMNIGPSELEMNMGGPQYSQQQAPPNQTAPWPESILPIDQASFASQNRQPFGSSPDDLLCPHPAAESPSDEGALLDQLYLALRNFDGLEEIDRALGIPELVSQSQAVDPEQFSSQDSNIMLEQKAPVFPQQYASQAQMAQGSYSPMQDPNFHTMGQRPSYATLRMQPRPGLRPTGLVQNQPNQLRLQLQHRLQAQQNRQPLMNQISNVSNVNLTLRPGVPTQAPINAQMLAQRQREILNQHLRQRQMHQQQQVQQRTLMMRGQGLNMTPSMVAPSGIPATMSNPRIPQANAQQFPFPPNYGISQQPDPGFTGATTPQSPLMSPRMAHTQSPMMQQSQANPAYQAPSDINGWAQGNMGGNSMFSQQSPPHFGQQANTSMYSNNMNINVSMATNTGGMSSMNQMTGQISMTSVTSVPTSGLSSMGPEQVNDPALRGGNLFPNQLPGMDMIKQEGDTTRKYC-"
        special_seq_name = "HEY1_NCOA2"
        fodb_puncta.loc[
            (fodb_puncta['FO_Name']==special_seq_name) & 
            (fodb_puncta['AAseq']==special_seq), 'AAseq'
        ] = special_seq.replace('-','')
        
        fodb_puncta_sequences = fodb_puncta['AAseq'].unique().tolist()
        benchmark_sequences = dict(zip(fodb_puncta_sequences, ['Puncta']*len(fodb_puncta_sequences)))
        log_update(f"\tRead FOdb puncta data and isolated {len(benchmark_sequences)} sequences for puncta benchmark")
        # Biological discovery benchmark
        benchmark_sequences2 = fuson_db.loc[
            (fuson_db['fusiongenes'].str.contains('EWSR1::FLI1')) | 
            (fuson_db['fusiongenes'].str.contains('PAX3::FOXO1')) | 
            (fuson_db['fusiongenes'].str.contains('BCR::ABL1')) | 
            (fuson_db['fusiongenes'].str.contains('EML4::ALK'))  
        ]['aa_seq'].unique().tolist()
        benchmark_sequences2 = dict(zip(benchmark_sequences2, ['Biological Discovery']*len(benchmark_sequences2)))
        log_update(f"\tIsolated all EWSR1::FLI1, PAX3::FOXO1, BCR::ABL1, and EML4::ALK sequences ({len(benchmark_sequences2)} total) for biological benchmarks...")
        
        for k, v in benchmark_sequences2.items():
            if k in benchmark_sequences:
                benchmark_sequences[k] = benchmark_sequences[k] + ',' + v
            else:
                benchmark_sequences[k] = v
        
        log_update(f"\tTotal unique benchmark sequences: {len(benchmark_sequences)}")     
        # Add benchmark column
        log_update("\tAdding benchmark column...")
        fuson_db['benchmark'] = fuson_db['aa_seq'].apply(lambda x: benchmark_sequences[x] if x in benchmark_sequences else np.nan)
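        # e.g., a sequence in both benchmark sets gets benchmark = 'Puncta,Biological Discovery';
        # sequences in neither set get NaN.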
    
        # Save fuson_db
        log_update("\nWriting final database to fuson_db.csv...")
        fuson_db.to_csv('fuson_db.csv', index=False)
        log_update("Cleaning complete.")
        
        # Assemble head tail queries for UniProt 
        assemble_uniprot_query()
        
        # Do the head tail mappings
        head_tail_mappings(fuson_db)
        
if __name__ == '__main__':
    main()