root committed on
Commit 1e6a1f0 · 1 Parent(s): 9a73cb0

uploading data folder

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fuson_plm/data/README.md +91 -0
  2. fuson_plm/data/__init__.py +0 -0
  3. fuson_plm/data/__pycache__/__init__.cpython-310.pyc +0 -0
  4. fuson_plm/data/__pycache__/clean.cpython-310.pyc +0 -0
  5. fuson_plm/data/__pycache__/cluster.cpython-310.pyc +0 -0
  6. fuson_plm/data/__pycache__/config.cpython-310.pyc +0 -0
  7. fuson_plm/data/__pycache__/split_vis.cpython-310.pyc +0 -0
  8. fuson_plm/data/blast/README.md +113 -0
  9. fuson_plm/data/blast/__pycache__/blast_fusions.cpython-310.pyc +0 -0
  10. fuson_plm/data/blast/__pycache__/plot.cpython-310.pyc +0 -0
  11. fuson_plm/data/blast/blast_fusions.py +838 -0
  12. fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl +3 -0
  13. fuson_plm/data/blast/blast_outputs/ht_uniprot_query.txt +3 -0
  14. fuson_plm/data/blast/blast_outputs/swissprot_blast_output_analyzed.pkl +3 -0
  15. fuson_plm/data/blast/blast_outputs/swissprot_blast_stats.csv +3 -0
  16. fuson_plm/data/blast/blast_outputs/swissprot_no_match.csv +3 -0
  17. fuson_plm/data/blast/blast_outputs/swissprot_no_match.txt +3 -0
  18. fuson_plm/data/blast/blast_outputs/swissprot_top_alignments.csv +3 -0
  19. fuson_plm/data/blast/extract_blast_seqs.py +62 -0
  20. fuson_plm/data/blast/figures/identities_hist.png +0 -0
  21. fuson_plm/data/blast/fusion_blast_log.txt +3 -0
  22. fuson_plm/data/blast/fuson_ht_db.csv +3 -0
  23. fuson_plm/data/blast/plot.py +75 -0
  24. fuson_plm/data/clean.py +594 -0
  25. fuson_plm/data/cluster.py +50 -0
  26. fuson_plm/data/clustering/input.fasta +3 -0
  27. fuson_plm/data/clustering/mmseqs_full_results.csv +3 -0
  28. fuson_plm/data/clustering_log.txt +3 -0
  29. fuson_plm/data/config.py +34 -0
  30. fuson_plm/data/data_cleaning_log.txt +3 -0
  31. fuson_plm/data/fuson_db.csv +3 -0
  32. fuson_plm/data/head_tail_data/ensembl_ht_idmap.txt +3 -0
  33. fuson_plm/data/head_tail_data/gene_to_ensembl_dict.pkl +3 -0
  34. fuson_plm/data/head_tail_data/genename_ht_idmap.txt +3 -0
  35. fuson_plm/data/head_tail_data/htgenes_uniprotids.csv +3 -0
  36. fuson_plm/data/head_tail_data/isoform_fasta_id_output_formatted.fasta +3 -0
  37. fuson_plm/data/head_tail_data/uniprot_idmap_inputs/head_tail_ens.txt +3 -0
  38. fuson_plm/data/head_tail_data/uniprot_idmap_inputs/head_tail_genes.txt +3 -0
  39. fuson_plm/data/raw_data/FOdb_SD5.csv +3 -0
  40. fuson_plm/data/raw_data/FOdb_all.csv +3 -0
  41. fuson_plm/data/raw_data/FOdb_puncta.csv +3 -0
  42. fuson_plm/data/raw_data/FusionPDB.txt +3 -0
  43. fuson_plm/data/raw_data/FusionPDB_cleaned.csv +3 -0
  44. fuson_plm/data/split.py +120 -0
  45. fuson_plm/data/split_vis.py +333 -0
  46. fuson_plm/data/splits/combined_plot.png +0 -0
  47. fuson_plm/data/splits/test_cluster_split.csv +3 -0
  48. fuson_plm/data/splits/test_df.csv +3 -0
  49. fuson_plm/data/splits/train_cluster_split.csv +3 -0
  50. fuson_plm/data/splits/train_df.csv +3 -0
fuson_plm/data/README.md ADDED
@@ -0,0 +1,91 @@
+ ## Training Data Curation and Processing
+
+ The `data` folder and its subfolders hold all raw and processed data used to assemble FusOn-DB, as well as all processing scripts. Additional benchmarking datasets can be found in the `benchmarking` folder.
+
+ ### From raw data to train/val/test splits and head/tail data
+ This section outlines the pipeline for converting the raw FusionPDB and FOdb datasets into the train/val/test splits used in FusOn-pLM. The process includes data cleaning, clustering, and splitting. During cleaning, we also extracted data about the heads and tails of each fusion oncoprotein.
+
+ ```
+ data/
+ ├── clustering/
+ │   ├── input.fasta
+ │   └── mmseqs_full_results.csv
+ ├── head_tail_data/
+ │   └── uniprot_idmap_inputs/
+ ├── raw_data/
+ │   ├── FOdb_all.csv
+ │   ├── FOdb_puncta.csv
+ │   ├── FOdb_SD5.csv
+ │   ├── FusionPDB_cleaned.csv
+ │   ├── FusionPDB.txt
+ │   └── gene_to_ensembl_dict.pkl
+ ├── splits/
+ │   ├── combined_plot.png
+ │   ├── train_df.csv
+ │   ├── train_cluster_split.csv
+ │   ├── val_df.csv
+ │   ├── val_cluster_split.csv
+ │   ├── test_df.csv
+ │   └── test_cluster_split.csv
+ ├── clean.py
+ ├── cluster.py
+ ├── config.py
+ ├── split.py
+ ├── split_vis.py
+ ├── data_cleaning_log.txt
+ ├── clustering_log.txt
+ ├── splitting_log.txt
+ └── fuson_db.csv
+ ```
+ - **`clean.py`**: script for cleaning the datasets in `raw_data`. Print statements in this code produce `data_cleaning_log.txt`.
+ - **`cluster.py`**: script for clustering the processed data in `fuson_db.csv`. Print statements in this code produce `clustering_log.txt`.
+ - **`config.py`**: configs for the cleaning, clustering, and splitting scripts (a minimal sketch follows this list).
+ - **`split.py`**: script for splitting the data, post-clustering. Print statements in this code produce `splitting_log.txt`.
+ - **`split_vis.py`**: script with code for the plots in `splits/combined_plot.png`, which describe the content of the train, validation, and test splits (length distribution, Shannon entropy, amino acid frequencies, and cluster sizes).
+
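+ The sketch below illustrates the clustering-related settings in `config.py`. Only `CLUSTER.PATH_TO_MMSEQS` is named elsewhere in this README; the other attribute names are hypothetical, though their values match the MMseqs2 command and length cutoff described below.
+
+ ```python
+ # Hypothetical sketch of the CLUSTER settings in config.py.
+ # Only PATH_TO_MMSEQS is referenced in this README; the other names are illustrative.
+ class CLUSTER:
+     PATH_TO_MMSEQS = "/*/FusOn-pLM/fuson_plm/mmseqs2/bin/mmseqs"  # your MMseqs2 installation
+     MIN_SEQ_ID = 0.3    # passed as --min-seq-id 0.3
+     COVERAGE = 0.8      # passed as -c 0.8
+     COV_MODE = 0        # passed as --cov-mode 0
+     MAX_SEQ_LEN = 2000  # only sequences of length <= 2000 are clustered
+ ```
+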
+ #### Usage
+ To repeat our cleaning, clustering, and splitting process, proceed as follows.
+ 1. Install MMseqs2 at `/*/FusOn-pLM/fuson_plm/mmseqs2` according to these instructions: https://github.com/soedinglab/MMseqs2. Make sure that in `config.py`, `CLUSTER.PATH_TO_MMSEQS` points to your mmseqs installation.
+ 2. Run the cleaning script:
+ ```bash
+ python clean.py
+ ```
+
+ This script will create the following files:
+ - **`fuson_db.csv`**: FusOn-DB, our full database of 44,414 fusion oncoproteins.
+ - **`raw_data/FusionPDB_cleaned.csv`**: a processed version of the FusionPDB database with the following columns: `aa_seq`, `n_fusiongenes`, `fusiongenes`, `cancers`, `primary_sources`, `secondary_source`.
+ - **`head_tail_data/uniprot_idmap_inputs/head_tail_ens.txt`** and **`head_tail_data/uniprot_idmap_inputs/head_tail_genes.txt`**: all unique Ensembl IDs and gene symbols for all unique head/tail proteins corresponding to any fusion oncoprotein in FusOn-DB. These were submitted to the UniProt ID-mapping tool to create **`head_tail_data/ensembl_ht_idmap.txt`** and **`head_tail_data/genename_ht_idmap.txt`**, respectively.
+ - **`head_tail_data/uniprot_idmap_inputs/gene_to_ensembl_dict.pkl`**: a dictionary mapping each unique gene symbol to a comma-separated list of its associated Ensembl IDs, according to FusionPDB.
+ - **`head_tail_data/uniprot_idmap_inputs/htgenes_uniprotids.csv`**: a file with each unique gene symbol (`Gene`), a comma-separated list of all associated UniProt IDs (`UniProtID`), and a concatenated string of 1s and 0s indicating whether each ID in the `UniProtID` column is reviewed (`Reviewed`).
+   - For example, a `Reviewed` value of "100" means the first ID in the `UniProtID` column of the same row is reviewed (1) and the second and third are not (0); the snippet below shows one way to decode this.
+
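+ A minimal sketch of decoding the `Reviewed` flags with pandas (column names as documented above; reading `Reviewed` as a string preserves leading zeros):
+
+ ```python
+ import pandas as pd
+
+ ht = pd.read_csv("head_tail_data/htgenes_uniprotids.csv", dtype={"Reviewed": str})
+ row = ht.iloc[0]
+ ids = row["UniProtID"].split(",")   # one UniProt ID per position
+ flags = list(row["Reviewed"])       # '1' = reviewed, '0' = unreviewed
+ reviewed_ids = [uid for uid, flag in zip(ids, flags) if flag == "1"]
+ print(row["Gene"], reviewed_ids)
+ ```
+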
+ 3. Run the clustering script:
+ ```bash
+ python cluster.py
+ ```
+
+ The command this script passes to the clustering software is:
+ ```bash
+ mmseqs easy-cluster clustering/input.fasta clustering/raw_output/mmseqs clustering/raw_output --min-seq-id 0.3 -c 0.8 --cov-mode 0
+ ```
+
+ This script will cluster all sequences of length 2000 or shorter (see `config.py`) and create the following files:
+ - **`clustering/input.fasta`**: the input file used by MMseqs2 to cluster the fusion oncoprotein sequences. Headers are our assigned sequence IDs (found in the `seq_id` column of `fuson_db.csv`).
+ - **`clustering/mmseqs_full_results.csv`**: clustering results (a summary sketch follows this list). Columns:
+   - `representative seq_id`: the seq_id of the sequence representing this cluster
+   - `member seq_id`: the seq_id of a member of the cluster
+   - `representative seq`: the amino acid sequence of the cluster representative (`representative seq_id`)
+   - `member seq`: the amino acid sequence of the cluster member
+
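+ A minimal sketch of summarizing the clustering output, assuming the column names listed above:
+
+ ```python
+ import pandas as pd
+
+ clusters = pd.read_csv("clustering/mmseqs_full_results.csv")
+ sizes = clusters.groupby("representative seq_id")["member seq_id"].nunique()
+ print(f"{sizes.size} clusters; largest has {sizes.max()} members; "
+       f"{(sizes == 1).sum()} singletons")
+ ```
+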
+ 4. Run the splitting script:
+ ```bash
+ python split.py
+ ```
+
+ This script will create the following files:
+ - **`splits/train_cluster_split.csv`**, **`splits/val_cluster_split.csv`**, **`splits/test_cluster_split.csv`**: the subsets of `clustering/mmseqs_full_results.csv` that have been partitioned into the train, validation, and test sets, respectively.
+ - **`splits/train_df.csv`**, **`splits/val_df.csv`**, **`splits/test_df.csv`**: the train, validation, and test splits used to train FusOn-pLM. Columns: `sequence`, `member length`.
+ - **`splits/combined_plot.png`**: plot displaying the composition of the train, validation, and test splits.
+
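+ Because whole clusters are assigned to a single split, no cluster representative should appear in more than one of these files. A quick sanity check, assuming the split files above:
+
+ ```python
+ import pandas as pd
+
+ reps = [set(pd.read_csv(f"splits/{name}_cluster_split.csv")["representative seq_id"])
+         for name in ("train", "val", "test")]
+ assert not (reps[0] & reps[1]) and not (reps[0] & reps[2]) and not (reps[1] & reps[2])
+ print("train/val/test clusters are disjoint")
+ ```
+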
+ ### BLAST
+ We ran BLAST to get the best alignment of each sequence in FusOn-DB to a protein in SwissProt. See the README in the `blast` folder for more details.
fuson_plm/data/__init__.py ADDED
File without changes
fuson_plm/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes).
 
fuson_plm/data/__pycache__/clean.cpython-310.pyc ADDED
Binary file (15.1 kB).
 
fuson_plm/data/__pycache__/cluster.cpython-310.pyc ADDED
Binary file (6.02 kB).
 
fuson_plm/data/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.11 kB).
 
fuson_plm/data/__pycache__/split_vis.cpython-310.pyc ADDED
Binary file (10.1 kB).
 
fuson_plm/data/blast/README.md ADDED
@@ -0,0 +1,113 @@
+ We ran local BLAST to get the best alignment of each fusion oncoprotein sequence to a protein in SwissProt.
+
+ ### Downloading BLAST Executables and Database
+ First, we downloaded the BLAST executables by entering the following in a terminal (if you don't have a Linux system, find the correct download for your system at https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST):
+ ```
+ wget https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/ncbi-blast-2.16.0+-x64-linux.tar.gz
+ tar -zxvf ncbi-blast-2.16.0+-x64-linux.tar.gz
+ rm ncbi-blast-2.16.0+-x64-linux.tar.gz
+
+ cd ncbi-blast-2.16.0+
+ mkdir swissprot
+ cd swissprot
+ perl ../bin/update_blastdb.pl --passive --decompress swissprot
+
+ chmod +x "blast/ncbi-blast-2.16.0+/bin/blastp"
+ sudo chmod -R 755 FusOn-pLM/fuson_plm/data/blast/ncbi-blast-2.16.0+
+ ```
+
+ ### Running BLAST
+ The directory is structured as follows:
+ ```
+ data/
+ └── blast/
+     ├── blast_outputs/
+     │   ├── swissprot_blast_output_analyzed.pkl
+     │   ├── swissprot_blast_stats.csv
+     │   ├── swissprot_no_match.csv
+     │   ├── swissprot_no_match.txt
+     │   ├── swissprot_top_alignments.csv
+     │   ├── best_htg_alignments_swissprot_seqs.pkl
+     │   └── ht_uniprot_query.txt
+     ├── figures/
+     │   └── identities_hist.png
+     ├── blast_fusions.py
+     ├── extract_blast_seqs.py
+     ├── plot.py
+     ├── fusion_blast_log.txt
+     └── fuson_ht_db.csv
+ ```
+
+ - **`blast_fusions.py`**: script that will prepare FusOn-DB for BLAST, run BLAST against SwissProt (given you've installed the BLAST software properly), extract top alignments, calculate statistics on the BLAST results, and make results plots. Print statements in this script create the log file `fusion_blast_log.txt`.
+ - **`extract_blast_seqs.py`**: script that will extract sequences of all the head/tail proteins that formed the best alignment during BLAST, directly from the SwissProt BLAST database. Creates the file `blast_outputs/best_htg_alignments_swissprot_seqs.pkl`.
+ - **`plot.py`**: script to make the plot found at `figures/identities_hist.png`. This plot displays the maximum % identity of each fusion oncoprotein sequence with a SwissProt sequence, based on BLAST. This plot is also automatically created by `blast_fusions.py`.
+ - **`fuson_ht_db.csv`**: database that merges FusOn-DB (`/*/FusOn-pLM/fuson_plm/data/fuson_db.csv`) with `/*/FusOn-pLM/fuson_plm/data/head_tail_data/htgenes_uniprotids.csv`, which simplifies the process of analyzing BLAST results. In FusOn-DB, certain amino acid sequences are associated with multiple fusion oncoproteins, whose names are comma-separated in the `fusiongenes` column. In `fuson_ht_db.csv`, the `fusiongenes` column is exploded such that each row has only one fusion gene (a sketch of this step follows this list). Therefore, this database has more rows than FusOn-DB, and some duplicate sequences.
+
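+ A minimal sketch of the explode step used to build `fuson_ht_db.csv`, mirroring `make_fuson_ht_db` in `blast_fusions.py` (run from the `blast` folder):
+
+ ```python
+ import pandas as pd
+
+ fuson_db = pd.read_csv("../fuson_db.csv")
+ ht = fuson_db.copy()
+ ht["fusiongenes"] = ht["fusiongenes"].str.split(",")
+ ht = ht.explode("fusiongenes")  # one fusion gene per row
+ ht["hgene"] = ht["fusiongenes"].str.split("::", expand=True)[0]
+ ht["tgene"] = ht["fusiongenes"].str.split("::", expand=True)[1]
+ ```
+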
+ To run the BLAST search and analysis, we recommend using nohup, as the process will take a long time.
+
+ ```bash
+ nohup python blast_fusions.py > blastrun.out 2> blastrun.err &
+ ```
+
+ ### Understanding the output files
+
+ Here, we will break down each file in the `blast/blast_outputs` directory.
+
+ - **`best_htg_alignments_swissprot_seqs.pkl`**: a dictionary where the keys are UniProt IDs, "."-concatenated to their isoform (e.g. "Q8NFI3.1"), and the values are the amino acid sequences corresponding to those isoforms. The sequences were pulled directly from the SwissProt BLAST database.
+ - **`ht_uniprot_query.txt`**: a list of all head and tail proteins producing top SwissProt alignments, in the format described above (e.g. "Q8NFI3.1"). Used to query the SwissProt database and create the `best_htg_alignments_swissprot_seqs.pkl` file.
+ - **`swissprot_blast_output_analyzed.pkl`**: dictionary that summarizes key BLAST results for each fusion protein. The keys are seq_ids, each corresponding to a fusion oncoprotein sequence in FusOn-DB. The values are dictionaries holding BLAST results for that seq_id. Each UniProt ID corresponding to a known head or tail (stored in `fuson_ht_db.csv`) is checked for an alignment. If there is no alignment, the value is None (e.g. `swissprot_blast_output_analyzed['seq18']['F8WED0']` is `None`). If there is an alignment, we store the Isoform, Score, Expect, Query_Aligned, Subject_Aligned, H_or_T (whether this ID is for the head or tail protein), Best (whether this is the best, i.e. highest-scoring, alignment to this fusion oncoprotein), Identities, Positives, Gaps, Query_Start, Query_End, Sbjct_Start, and Sbjct_End. If the best alignment is not a known head or tail, this alignment is also stored. Below is the example dictionary for seq18; a short loading example follows it.
+
+ ```python
+ swissprot_blast_output_analyzed['seq18'] =
+ {
+     "F8WED0": None,
+     "Q9Y2X3": {
+         "Isoform": 1,
+         "Score": 452.0,
+         "Expect": "6e-148",
+         "Query_Aligned": "AGTGSLLNLAKHAASTVQILGAEKALFRALKSRRDTPKYGLIYHASLVGQTSPKHKGKISRMLAAKTVLAIRYDAFGEDSSSAMGVENRAKLEARLRTLEDRGIRKISGTGKALAKTEKYEHKSEVKTYDPSGDSTLPTCSKKRKIEQVDKEDEITEKKAKKAKIKVKVEEEEEEKVAEEEETSVKKKKKRGKKKHIKEEPLSEEEPCTSTAIASPEKKKKKKKKRENED",
+         "Subject_Aligned": "AHAGSLLNLAKHAASTVQILGAEKALFRALKSRRDTPKYGLIYHASLVGQTSPKHKGKISRMLAAKTVLAIRYDAFGEDSSSAMGVENRAKLEARLRTLEDRGIRKISGTGKALAKTEKYEHKSEVKTYDPSGDSTLPTCSKKRKIEQVDKEDEITEKKAKKAKIKVKVEEEEEEKVAEEEETSVKKKKKRGKKKHIKEEPLSEEEPCTSTAIASPEKKKKKKKKRENED",
+         "H_or_T": "Tail",
+         "Best": False,
+         "Identities": "228/230 (99%)",
+         "Positives": "228/230 (99%)",
+         "Gaps": "0/230 (0%)",
+         "Query_Start": 754,
+         "Query_End": 983,
+         "Sbjct_Start": 300,
+         "Sbjct_End": 529,
+     },
+     "L0R804": None,
+     "A0A096LP60": None,
+     "A0A096LNZ0": None,
+     "H7BZ72": None,
+     "A0A096LP25": None,
+     "Q9BUD9": None,
+     "B7ZLC4": None,
+     "Q2M2I8": {
+         "Isoform": 3,
+         "Score": 1558.0,
+         "Expect": "0.0",
+         "Query_Aligned": "MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTG",
+         "Subject_Aligned": "MKKFFDSRREQGGSGLGSGSSGGGGSTSGLGSGYIGRVFGIGRQQVTVDEVLAEGGFAIVFLVRTSNGMKCALKRMFVNNEHDLQVCKREIQIMRDLSGHKNIVGYIDSSINNVSSGDVWEVLILMDFCRGGQVVNLMNQRLQTGFTENEVLQIFCDTCEAVARLHQCKTPIIHRDLKVENILLHDRGHYVLCDFGSATNKFQNPQTEGVNAVEDEIKKYTTLSYRAPEMVNLYSGKIITTKADIWALGCLLYKLCYFTLPFGESQVAICDGNFTIPDNSRYSQDMHCLIRYMLEPDPDKRPDIYQVSYFSFKLLKKECPIPNVQNSPIPAKLPEPVKASEAAAKKTQPKARLTDPIPTTETSIAPRQRPKAGQTQPNPGILPIQPALTPRKRATVQPPPQAAGSSNQPGLLASVPQPKPQAPPSQPLPQTQAKQPQAPPTPQQTPSTQAQGLPAQAQATPQHQQQLFLKQQQQQQQPPPAQQQPAGTFYQQQQAQTQQFQAVHPATQKPAIAQFPVVSQGGSQQQLMQNFYQQQQQQQQQQQQQQLATALHQQQLMTQQAALQQKPTMAAGQQPQPQPAAAPQPAPAQEPAIQAPVRQQPKVQTTPPPAVQGQKVGSLTPPSSPKTQRAGHRRILSDVTHSAVFGVPASKSTQLLQAAAAEASLNKSKSATTTPSGSPRTSQQNVYNPSEGSTWNPFDDDNFSKLTAEELLNKDFAKLGEGKHPEKLGGSAESLIPGFQSTQGDAFATTSFSAGTA",
+         "H_or_T": "Head",
+         "Best": True,
+         "Identities": "756/757 (99%)",
+         "Positives": "756/757 (99%)",
+         "Gaps": "0/757 (0%)",
+         "Query_Start": 1,
+         "Query_End": 757,
+         "Sbjct_Start": 1,
+         "Sbjct_End": 757,
+     },
+     "E9PG46": None,
+ }
+ ```
+
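+ A minimal sketch of loading this pickle and pulling a sequence's best alignment (assumes the consolidated pickle is a single dictionary keyed by seq_id, as described above):
+
+ ```python
+ import pickle
+
+ with open("blast_outputs/swissprot_blast_output_analyzed.pkl", "rb") as f:
+     results = pickle.load(f)
+
+ aligns = results["seq18"]
+ best = next((a for a in aligns.values() if a is not None and a.get("Best")), None)
+ if best is not None:
+     print(best["H_or_T"], best["Score"], best["Identities"])
+ ```
+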
+ - **`swissprot_blast_stats.csv`**: a database summarizing the BLAST scores across all fusion oncoproteins. Columns are: `seq_id`, `hgAlignments`, `tgAlignments`, `totalAlignments`, `best_hgScore`, `best_tgScore`, `best_Score`, `h_or_t_alignment`, `h_and_t_alignment`.
+   - `h_or_t_alignment` is True if either the head or tail has an alignment returned by BLAST. `h_and_t_alignment` is True if both the head and tail have an alignment returned by BLAST.
+ - **`swissprot_no_match.txt`**: names of the BLAST output files that said "No hits found".
+ - **`swissprot_no_match.csv`**: more information on the fusion oncoproteins indicated in `swissprot_no_match.txt`.
+ - **`swissprot_top_alignments.csv`**: a database summarizing the most important information acquired by BLAST across all fusion oncoproteins. Columns are: `seq_id`, `top_hg_UniProtID`, `top_hg_UniProt_isoform`, `top_hg_UniProt_fus_indices`, `top_tg_UniProtID`, `top_tg_UniProt_isoform`, `top_tg_UniProt_fus_indices`, `top_UniProtID`, `top_UniProt_isoform`, `top_UniProt_fus_indices`, `top_UniProt_nIdentities`, `top_UniProt_nPositives`, `aa_seq_len`.
+   - All indices (e.g. `top_hg_UniProt_fus_indices`) are 1-indexed.
+   - This database can be used to estimate breakpoints using the `top_hg_UniProt_fus_indices` and `top_tg_UniProt_fus_indices` columns. For example, if `top_hg_UniProt_fus_indices` is "1,300" and `top_tg_UniProt_fus_indices` is "301,546", then residues 1-300 of the fusion protein aligned with the head protein indicated in `top_hg_UniProtID` and `top_hg_UniProt_isoform`, and residues 301-546 of the fusion protein aligned with the tail protein indicated in `top_tg_UniProtID` and `top_tg_UniProt_isoform`. The breakpoint is between residues 300 and 301 (see the sketch after this list).
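+
+ A minimal sketch of that breakpoint estimate, assuming the `swissprot_top_alignments.csv` columns listed above:
+
+ ```python
+ import pandas as pd
+
+ top = pd.read_csv("blast_outputs/swissprot_top_alignments.csv")
+ row = top.dropna(subset=["top_hg_UniProt_fus_indices", "top_tg_UniProt_fus_indices"]).iloc[0]
+ hg_start, hg_end = map(int, row["top_hg_UniProt_fus_indices"].split(","))
+ tg_start, tg_end = map(int, row["top_tg_UniProt_fus_indices"].split(","))
+ # Head aligns to residues hg_start..hg_end and tail to tg_start..tg_end (1-indexed),
+ # so the estimated breakpoint falls between hg_end and tg_start.
+ print(f"{row['seq_id']}: breakpoint between residues {hg_end} and {tg_start}")
+ ```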
fuson_plm/data/blast/__pycache__/blast_fusions.cpython-310.pyc ADDED
Binary file (25.1 kB).
 
fuson_plm/data/blast/__pycache__/plot.cpython-310.pyc ADDED
Binary file (9.08 kB).
 
fuson_plm/data/blast/blast_fusions.py ADDED
@@ -0,0 +1,838 @@
+ ### Prepare to BLAST all of our sequences against UniProt
+ import pandas as pd
+ import os
+ import subprocess
+ import time
+ import re
+ import pickle
+ import numpy as np
+
+ from fuson_plm.utils.logging import log_update, open_logfile
+ from fuson_plm.utils.embedding import redump_pickle_dictionary
+ from fuson_plm.data.blast.plot import group_difference_plot, group_swiss_and_ht_plot, group_box_plot, group_pos_id_plot
+
+ def prepare_blast_inputs():
+     log_update("\nPreparing BLAST Inputs. Logging every 1000 sequences... ")
+     # make directory for input and output
+     os.makedirs("blast_inputs", exist_ok=True)
+
+     # read the fuson database
+     fuson_db = pd.read_csv('../fuson_db.csv')
+
+     # make dictionary mapping sequences to seq_ids (for naming input files)
+     fuson_db_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
+
+     # convert the database into fasta format
+     new_fa_files_created = 0
+     old_fa_files_found = 0
+     total_seqs_processed = 0
+     for i, (seq, seqid) in enumerate(fuson_db_dict.items()):
+         total_seqs_processed += 1
+         # if the path already exists, skip
+         if os.path.exists(f"blast_inputs/{seqid}.fa"):
+             old_fa_files_found += 1
+         else:
+             new_fa_files_created += 1
+             with open(f"blast_inputs/{seqid}.txt", 'w') as f:
+                 fasta_lines = '>' + seqid + '\n' + seq
+                 f.write(fasta_lines)
+             # rename it to .fa
+             os.rename(f"blast_inputs/{seqid}.txt", f"blast_inputs/{seqid}.fa")
+
+         if i % 1000 == 0:
+             log_update(f"\t\t{i}\t{seqid}:{seq}")
+
+     log_update("\tFinished preparing BLAST Inputs (results in blast_inputs folder)")
+     log_update(f"\t\tSequences processed: {total_seqs_processed}/{len(fuson_db)} seqs in FusOn-DB\n\t\tFasta files found: {old_fa_files_found}\n\t\tNew fasta files created: {new_fa_files_created}")
+
+ def run_blast(blast_inputs_dir, database="swissprot", n=1, interval=2000):
+     """
+     Run BLAST on all files in blast_inputs_dir.
+     n and interval select a batch: files [interval*(n-1), interval*n) are processed.
+     """
+     # Must change the PATH variable to include the BLAST executables
+     os.environ['PATH'] += ":./ncbi-blast-2.16.0+/bin"
+     os.environ['BLASTDB'] = f"ncbi-blast-2.16.0+/{database}"
+
+     # make directory for outputs
+     os.makedirs("blast_outputs", exist_ok=True)
+     os.makedirs(f"blast_outputs/{database}", exist_ok=True)
+     already_blasted = os.listdir(f"blast_outputs/{database}")
+     blast_input_files = os.listdir(blast_inputs_dir)
+     # Sort the list using a custom key to extract the numeric part
+     blast_input_files = sorted(blast_input_files, key=lambda x: int(re.search(r'\d+', x).group()))
+
+     # print how many we've already blasted
+     log_update(f"Running BLAST.\n\t{len(blast_input_files)} input files\n\t{len(already_blasted)} already blasted\n")
+
+     tot_seqs_processed = 0
+     total_blast_time = 0
+
+     start_i = interval*(n-1)
+     end_i = interval*n
+     if end_i > len(blast_input_files): end_i = len(blast_input_files)
+     for i, blast_input_file in enumerate(blast_input_files[start_i:end_i]):
+         tot_seqs_processed += 1
+         # blast_input_file is of the format seqid.fa
+         seqid = blast_input_file.split('.fa')[0]
+         input_path = f"blast_inputs/{blast_input_file}"
+         output_path = f"blast_outputs/{database}/{seqid}_{database}_results.out"
+
+         if os.path.exists(output_path):
+             log_update(f"\t{i+1}.\tAlready blasted {seqid}")
+             continue
+
+         # Construct the command as a list of arguments
+         command = [
+             "ncbi-blast-2.16.0+/bin/blastp",
+             "-db", database,
+             "-query", input_path,
+             "-out", output_path
+         ]
+
+         # Run the command, and time it
+         blast_start_time = time.time()
+         result = subprocess.run(command, capture_output=True, text=True)
+         blast_end_time = time.time()
+         blast_seq_time = blast_end_time - blast_start_time
+         total_blast_time += blast_seq_time
+
+         # Check if there was an error
+         if result.returncode != 0:
+             log_update(f"\t{i+1}.\tError running BLAST for {seqid}: {result.stderr} ({blast_seq_time:.2f}s)")
+         else:
+             log_update(f"\t{i+1}.\tBLAST search completed for {seqid} ({blast_seq_time:.2f}s)")
+
+     log_update(f"\tFinished processing {tot_seqs_processed} sequences ({total_blast_time:.2f}s)")
+
+ def remove_incomplete_blasts(database="swissprot"):
+     incomplete_list = []
+     for fname in os.listdir(f"blast_outputs/{database}"):
+         complete = False
+         with open(f"blast_outputs/{database}/{fname}", "r") as f:
+             lines = f.readlines()
+             if len(lines) > 1 and "Window for multiple hits:" in lines[-1]:
+                 complete = True
+         if not complete:
+             incomplete_list.append(fname)
+
+     log_update(f"\t{len(incomplete_list)} BLAST files are incomplete (due to BLAST errors). Deleting them. Rerun these")
+     # remove all these files
+     for fname in incomplete_list:
+         os.remove(f"blast_outputs/{database}/{fname}")
+
+ def find_nomatch_blasts(fuson_ht_db, database="swissprot"):
+     no_match_list = []
+     for fname in os.listdir(f"blast_outputs/{database}"):
+         match = True
+         with open(f"blast_outputs/{database}/{fname}", "r") as f:
+             lines = f.readlines()
+             if len(lines) > 1 and "No hits found" in lines[28]:  # it'll say no hits found if there are no hits
+                 match = False
+         if not match:
+             no_match_list.append(fname)
+
+     log_update(f"\t{len(no_match_list)} sequence IDs had no match in the BLAST database {database}")
+     # write no match list to a file in blast_outputs
+     with open(f"blast_outputs/{database}_no_match.txt", "w") as f:
+         for i, fname in enumerate(no_match_list):
+             if i != len(no_match_list)-1:
+                 f.write(f"{fname}\n")
+             else:
+                 f.write(f"{fname}")
+
+     # write a subset of fuson_ht_db containing these sequences as well
+     no_match_ids = [x.split('_')[0] for x in no_match_list]
+     subset = fuson_ht_db.loc[
+         fuson_ht_db['seq_id'].isin(no_match_ids)
+     ].reset_index(drop=True)
+     subset.to_csv(f"blast_outputs/{database}_no_match.csv", index=False)
+
+     return no_match_ids
+
+ def make_fuson_ht_db(path_to_fuson_db="../fuson_db.csv", path_to_unimap="../head_tail_data/htgenes_uniprotids.csv", savepath="fuson_ht_db.csv"):
+     """
+     Make a version of the fuson_db that has all the heads and tails for each of the genes. Will make it easier to analyze blast results
+     """
+     if os.path.exists(savepath):
+         df = pd.read_csv(savepath)
+         return df
+
+     # read both of the databases
+     fuson_db = pd.read_csv(path_to_fuson_db)
+     ht_db = pd.read_csv(path_to_unimap)
+
+     # Make it such that each row of fuson_db just has ONE head and ONE tail
+     fuson_ht_db = fuson_db.copy(deep=True)
+     fuson_ht_db['fusiongenes'] = fuson_ht_db['fusiongenes'].apply(lambda x: x.split(','))
+     fuson_ht_db = fuson_ht_db.explode('fusiongenes')
+     fuson_ht_db['hgene'] = fuson_ht_db['fusiongenes'].str.split('::', expand=True)[0]
+     fuson_ht_db['tgene'] = fuson_ht_db['fusiongenes'].str.split('::', expand=True)[1]
+
+     # Merge on head, then merge on tail
+     fuson_ht_db = pd.merge(  # merge on head
+         fuson_ht_db,
+         ht_db.rename(columns={
+             'Gene': 'hgene',
+             'UniProtID': 'hgUniProt',
+             'Reviewed': 'hgUniProtReviewed'
+         }),
+         on='hgene',
+         how='left'
+     )
+     fuson_ht_db = pd.merge(  # merge on tail
+         fuson_ht_db,
+         ht_db.rename(columns={
+             'Gene': 'tgene',
+             'UniProtID': 'tgUniProt',
+             'Reviewed': 'tgUniProtReviewed'
+         }),
+         on='tgene',
+         how='left'
+     )
+
+     # Make sure we haven't lost anything
+     tot_og_seqids = len(fuson_db['seq_id'].unique())
+     tot_final_seqids = len(fuson_ht_db['seq_id'].unique())
+     log_update(f"\tTotal sequence IDs in combined database = {tot_final_seqids}. Matches expected: {tot_final_seqids==tot_og_seqids}")
+     # Each fusion should have the same number of ROWS as it has comma-separated fusion genes (commas+1)
+     fuson_db['n_commas'] = fuson_db['fusiongenes'].str.count(',') + 1
+     seqid_rows_map = dict(zip(fuson_db['seq_id'], fuson_db['n_commas']))
+     vc = fuson_ht_db['seq_id'].value_counts().reset_index()  # note: assumes pandas<2.0 column names ('index', 'seq_id')
+     vc['expected_count'] = vc['index'].map(seqid_rows_map)
+     log_update(f"\tEach seq_id has the expected number of head-tail combos: {(vc['expected_count']==vc['seq_id']).all()}")
+
+     log_update(f"\tPreview of combined database:")
+     prev = fuson_ht_db.head(10).copy()  # copy to avoid SettingWithCopyWarning
+     prev['aa_seq'] = prev['aa_seq'].apply(lambda x: x[0:10]+'...')
+     log_update(prev.to_string(index=False))
+     fuson_ht_db.to_csv(savepath, index=False)
+     return fuson_ht_db
+
+ def format_dict(d, indent=0):
+     """
+     Recursively formats a dictionary for display purposes.
+
+     Args:
+         d (dict): The dictionary to format.
+         indent (int): The current level of indentation.
+
+     Returns:
+         str: A formatted string representing the dictionary.
+     """
+     formatted_str = ""
+     # Iterate through each key-value pair in the dictionary
+     for key, value in d.items():
+         # Create the current indentation
+         current_indent = " " * (indent * 4)
+         # Add the key
+         formatted_str += f"{current_indent}{repr(key)}: "
+
+         # Check the type of the value
+         if isinstance(value, dict):
+             # If dictionary, call format_dict recursively
+             formatted_str += "{\n" + format_dict(value, indent + 1) + current_indent + "},\n"
+         elif isinstance(value, list):
+             # If list, convert it to a formatted string
+             formatted_str += f"[{', '.join(repr(item) for item in value)}],\n"
+         elif isinstance(value, str):
+             # If string, enclose in quotes
+             formatted_str += f"'{value}',\n"
+         elif value is None:
+             # If None, display as 'None'
+             formatted_str += "None,\n"
+         else:
+             formatted_str += f"{repr(value)},\n"
+
+     return formatted_str
+
+ def parse_blast_output(file_path, head_ids, tail_ids):
+     """
+     Args:
+     - file_path: /path/to/blast/output
+     - head_ids: list of all UniProt IDs for the head protein
+     - tail_ids: list of all UniProt IDs for the tail protein
+     """
+     target_ids = list(set(head_ids + tail_ids))  # make a list to make some functions easier
+     with open(file_path, 'r') as file:
+         best_data = {tid: None for tid in target_ids}  # stores the best alignment for each ID we care about
+         current_data = {tid: {} for tid in target_ids}  # stores the current data for each ID we care about (most recent alignment we read)
+         best_score = {tid: -float('inf') for tid in target_ids}  # stores the best score for each ID we care about
+         capture = {tid: False for tid in target_ids}  # whether we are currently processing this ID
+         replace_best = {tid: False for tid in target_ids}  # whether we should replace the best_data with the current_data for this ID
+         isoform_dict = {tid: None for tid in target_ids}  # dictionary of isoforms for each target ID
+
+         # variables that will only be used for getting the best alignment
+         alignment_count = 0
+         cur_id = None
+         on_best_alignment = False
+
+         # Iterate through lines
+         for line in file:
+             line = line.strip()
+             # if NEW ID (not necessarily a new alignment! there can be multiple alignments under one >)
+             if line.startswith('>'):
+                 found_tid_in_header = False  # assume we have not found a target ID we are looking for
+                 alignment_count += 1
+                 if alignment_count == 1:  # we're on the best alignment, because the best alignment is listed first
+                     on_best_alignment = True
+                 else:
+                     on_best_alignment = False
+
+                 ## We may have just finished processing an ID. Check for the one that currently has capture set to True
+                 just_captured = None
+                 total_captured = 0
+                 for k, v in capture.items():
+                     if v:
+                         total_captured += 1
+                         just_captured = k
+                 # we should never be capturing more than one thing at a time. make sure of this
+                 assert total_captured < 2
+                 if just_captured is not None:
+                     if replace_best[just_captured]:  # if we just finished an alignment for the just_captured ID, and it's the best one, put it in
+                         best_data[just_captured] = current_data[just_captured].copy()
+                         replace_best[just_captured] = False  # we just did the replacement, so reset it
+
+                 # Check if the line contains any of the target IDs.
+                 # This means EITHER [UniProtID] or [UniProtID.Isoform] or [UniProtID-Isoform] is in the line
+                 for tid in target_ids:
+                     pattern = fr">{tid}([.-]\d+)? "  # for ID P02671, would match ">P02671 ", ">P02671.2 " and ">P02671-2 "
+                     if re.search(pattern, line):  # if this ID matches
+                         isoform_dict[tid] = None  # set it to None, update it if we need to
+                         if "." in line:  # look for isoform denoted by . if there is one, otherwise it'll stay as None
+                             isoform = int(line.split(".")[1].split(" ")[0])
+                             isoform_dict[tid] = isoform
+                             #print(f"\t\tID = {tid} (is a head or tail), isoform={isoform}")
+                         elif "-" in line:  # look for isoform denoted by - if there is one, otherwise it'll stay as None
+                             isoform = int(line.split("-")[1].split(" ")[0])
+                             isoform_dict[tid] = isoform
+                             #print(f"\t\tID = {tid} (is a head or tail), isoform={isoform}")
+                         capture[tid] = True
+                         current_data[tid] = {'header': line}
+                         found_tid_in_header = True  # we've found the tid that's in this line, so no need to check the others
+                     else:
+                         capture[tid] = False
+
+                 if on_best_alignment:  # if this is the best alignment
+                     if not(found_tid_in_header):  # if none of our target IDs is the best alignment
+                         cur_id_full = line.split('>')[1].split(' ')[0]
+                         cur_id, isoform = cur_id_full, None
+                         isoform_dict[cur_id] = None  # change this if we need to
+                         if "." in cur_id_full:  # if there's a dot, it's an isoform.
+                             cur_id = cur_id_full.split(".")[0]
+                             isoform = int(cur_id_full.split(".")[1])
+                             isoform_dict[cur_id] = isoform
+                             #log_update(f"\t\tID = {cur_id} (best alignment, not a head or tail), isoform={isoform}")
+                             #log_update(f"\t\t\tFull line: {line}")  # so we can see the gene name. does it make sense?
+                         elif "-" in cur_id_full:  # if there's a -, it's an isoform.
+                             cur_id = cur_id_full.split("-")[0]
+                             isoform = int(cur_id_full.split("-")[1])
+                             isoform_dict[cur_id] = isoform
+                             #log_update(f"\t\tID = {cur_id} (best alignment, not a head or tail), isoform={isoform}")
+                             #log_update(f"\t\t\tFull line: {line}")  # so we can see the gene name. does it make sense?
+                         # add this id to all the dictionaries
+                         best_data[cur_id] = None
+                         current_data[cur_id] = {}
+                         best_score[cur_id] = -float('inf')
+                         capture[cur_id] = False
+                         replace_best[cur_id] = False
+
+             for tid in target_ids:
+                 if capture[tid]:  # if we're currently on an alignment for a tid we care about
+                     if 'Score =' in line:
+                         if replace_best[tid]:  # if we're replacing the best alignment with this one, within the same ID, do it
+                             best_data[tid] = current_data[tid].copy()
+                             # now reset the variable!
+                             replace_best[tid] = False
+
+                         score_value = float(line.split()[2])  # Assuming "Score = 1053 bits (2723)" format
+                         current_data[tid] = {}  # Reset current_data for this ID
+                         current_data[tid]['Isoform'] = isoform_dict[tid]
+                         current_data[tid]['Score'] = score_value
+                         current_data[tid]['Expect'] = line.split('Expect =')[1].split(', Method')[0].strip()
+                         current_data[tid]['Query_Aligned'] = []
+                         current_data[tid]['Subject_Aligned'] = []
+                         # Set the ID as a head or tail, or neither (neither shouldn't happen here though)
+                         if tid in head_ids:
+                             current_data[tid]['H_or_T'] = 'Head'
+                             if tid in tail_ids:
+                                 current_data[tid]['H_or_T'] = 'Head,Tail'
+                         elif tid in tail_ids:
+                             current_data[tid]['H_or_T'] = 'Tail'
+                         else:
+                             current_data[tid]['H_or_T'] = np.nan
+
+                         current_data[tid]['Best'] = True if on_best_alignment else False
+                         if score_value > best_score[tid]:  # if this is the best score we have for an alignment of this protein
+                             best_score[tid] = score_value
+                             replace_best[tid] = True
+                         else:
+                             replace_best[tid] = False
+
+                     if 'Identities =' in line:
+                         idents = line.split(', ')
+                         current_data[tid]['Identities'] = idents[0].split('=')[1].strip()
+                         current_data[tid]['Positives'] = idents[1].split('=')[1].strip()
+                         current_data[tid]['Gaps'] = idents[2].split('=')[1].strip()
+                     if line.startswith('Query'):
+                         parts = line.split()
+                         if 'Query_Start' not in current_data[tid]:
+                             current_data[tid]['Query_Start'] = int(parts[1])
+                         current_data[tid]['Query_End'] = int(parts[3])
+                         current_data[tid]['Query_Aligned'].append(parts[2])
+                     if line.startswith('Sbjct'):
+                         parts = line.split()
+                         if 'Sbjct_Start' not in current_data[tid]:
+                             current_data[tid]['Sbjct_Start'] = int(parts[1])
+                         current_data[tid]['Sbjct_End'] = int(parts[3])
+                         current_data[tid]['Subject_Aligned'].append(parts[2])
+
+             # if we're on the best alignment and it's not one of our target_ids, still process it the same way
+             if on_best_alignment:
+                 if not(found_tid_in_header):
+                     if 'Score =' in line:
+                         if replace_best[cur_id]:  # if we're replacing the best alignment with this one, within the same ID, do it
+                             best_data[cur_id] = current_data[cur_id].copy()
+                             # now reset the variable!
+                             replace_best[cur_id] = False
+
+                         score_value = float(line.split()[2])  # Assuming "Score = 1053 bits (2723)" format
+                         current_data[cur_id] = {}  # Reset current_data for this ID
+                         current_data[cur_id]['Isoform'] = isoform_dict[cur_id]
+                         current_data[cur_id]['Score'] = score_value
+                         current_data[cur_id]['Expect'] = line.split('Expect =')[1].split(', Method')[0].strip()
+                         current_data[cur_id]['Query_Aligned'] = []
+                         current_data[cur_id]['Subject_Aligned'] = []
+                         # Set the ID as a head or tail, or neither
+                         if cur_id in head_ids:
+                             current_data[cur_id]['H_or_T'] = 'Head'
+                             if cur_id in tail_ids:
+                                 current_data[cur_id]['H_or_T'] = 'Head,Tail'
+                         elif cur_id in tail_ids:
+                             current_data[cur_id]['H_or_T'] = 'Tail'
+                         else:
+                             current_data[cur_id]['H_or_T'] = np.nan
+
+                         current_data[cur_id]['Best'] = True
+                         if score_value > best_score[cur_id]:  # if this is the best score we have for an alignment of this protein
+                             best_score[cur_id] = score_value
+                             replace_best[cur_id] = True
+                         else:
+                             replace_best[cur_id] = False
+
+                     if 'Identities =' in line:
+                         idents = line.split(', ')
+                         current_data[cur_id]['Identities'] = idents[0].split('=')[1].strip()
+                         current_data[cur_id]['Positives'] = idents[1].split('=')[1].strip()
+                         current_data[cur_id]['Gaps'] = idents[2].split('=')[1].strip()
+                     if line.startswith('Query'):
+                         parts = line.split()
+                         if 'Query_Start' not in current_data[cur_id]:
+                             current_data[cur_id]['Query_Start'] = int(parts[1])
+                         current_data[cur_id]['Query_End'] = int(parts[3])
+                         current_data[cur_id]['Query_Aligned'].append(parts[2])
+                     if line.startswith('Sbjct'):
+                         parts = line.split()
+                         if 'Sbjct_Start' not in current_data[cur_id]:
+                             current_data[cur_id]['Sbjct_Start'] = int(parts[1])
+                         current_data[cur_id]['Sbjct_End'] = int(parts[3])
+                         current_data[cur_id]['Subject_Aligned'].append(parts[2])
+
+         # add cur_id to target_ids if it's not None
+         if not(cur_id is None):
+             target_ids += [cur_id]
+
+         # Check at the end of the file if the last scores are the best
+         for tid in target_ids:
+             if replace_best[tid]:
+                 best_data[tid] = current_data[tid].copy()
+
+     # Combine sequences into single strings for the best data for each ID
+     for tid in target_ids:
+         #print(tid)
+         if best_data[tid]:
+             #print(f"there is a best alignment for {tid}")
+             #print(f"best: {best_data[tid]}")
+             #print(f"current: {current_data[tid]}")
+             best_data[tid]['Query_Aligned'] = ''.join(best_data[tid]['Query_Aligned'])
+             best_data[tid]['Subject_Aligned'] = ''.join(best_data[tid]['Subject_Aligned'])
+
+     return best_data
+
+ def parse_all_blast_results(fuson_ht_db, database="swissprot"):
+     """
+     Analyze the BLAST outputs for each fusion protein against UniProt.
+     Use the fuson_ht_db to look for the heads and tails that we expect. If they can't be found, ... ?
+     """
+     output_file = f"blast_outputs/{database}_blast_output_analyzed.pkl"
+     all_seq_ids = fuson_ht_db['seq_id'].unique().tolist()
+     all_seq_ids = sorted(all_seq_ids, key=lambda x: int(re.search(r'\d+', x).group()))  # sort by the number: seq1, seq2, ...
+
+     prior_results = {}
+     if os.path.exists(output_file):
+         with open(output_file, "rb") as f:
+             prior_results = pickle.load(f)
+
+     # Iterate through seq_ids
+     total_parse_time = 0
+     tot_seqs_processed = 0
+     for seq_id in all_seq_ids:
+         try:
+             tot_seqs_processed += 1
+             # If we've already processed it, skip
+             if seq_id in prior_results:
+                 log_update(f"\tAlready processed {seq_id} blast results. Continuing")
+                 continue
+
+             file_path = f"blast_outputs/{database}/{seq_id}_{database}_results.out"
+
+             aa_seq = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['aa_seq'].tolist()[0]
+
+             # Remember, fuson_ht_db has all the IDs for ALL the different head and tail gene identifiers.
+             fusion_genes = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['fusiongenes'].tolist()
+
+             ##### Process heads
+             head_ids = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['hgUniProt'].dropna().tolist()
+             head_reviewed, head_reviewed_dict = "", {}
+             if len(head_ids) > 0:  # if we found head IDs, we can process them and figure out if they're reviewed
+                 head_ids = ",".join(head_ids).split(",")
+                 head_reviewed = fuson_ht_db.loc[
+                     fuson_ht_db['seq_id']==seq_id
+                 ]['hgUniProtReviewed'].dropna().tolist()
+                 head_reviewed = list("".join(head_reviewed))
+
+                 head_reviewed_dict = dict(zip(head_ids, head_reviewed))
+                 head_ids = list(head_reviewed_dict.keys())  # there may be some duplicates, so separate them out again
+                 head_reviewed = list(head_reviewed_dict.values())
+
+             head_genes = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['hgene'].unique().tolist()
+
+             ##### Process tails - same logic
+             tail_ids = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['tgUniProt'].dropna().tolist()
+             tail_reviewed, tail_reviewed_dict = "", {}
+             if len(tail_ids) > 0:  # if we found tail IDs, we can process them and figure out if they're reviewed
+                 tail_ids = ",".join(tail_ids).split(",")
+                 tail_reviewed = fuson_ht_db.loc[
+                     fuson_ht_db['seq_id']==seq_id
+                 ]['tgUniProtReviewed'].dropna().tolist()
+                 tail_reviewed = list("".join(tail_reviewed))
+
+                 tail_reviewed_dict = dict(zip(tail_ids, tail_reviewed))
+                 tail_ids = list(tail_reviewed_dict.keys())  # there may be some duplicates, so separate them out again
+                 tail_reviewed = list(tail_reviewed_dict.values())
+
+             tail_genes = fuson_ht_db.loc[
+                 fuson_ht_db['seq_id']==seq_id
+             ]['tgene'].unique().tolist()
+
+             ###### Log what we just found
+             log_update(f"\tEvaluating {seq_id}, fusion genes = {fusion_genes}, len = {len(aa_seq)}...\n\t\tfile_path={file_path}")
+             #log_update(f"\n\t\thead genes={head_genes}\n\t\thead_ids={head_ids}\n\t\ttail genes={tail_genes}\n\t\ttail_ids={tail_ids}")
+
+             ### Do the analysis and time it
+             parse_start_time = time.time()  # time it
+             blast_data = parse_blast_output(file_path, head_ids, tail_ids)
+             parse_end_time = time.time()
+             parse_seq_time = parse_end_time - parse_start_time
+             total_parse_time += parse_seq_time
+             log_update(f"\t\tBLAST output analysis completed for {seq_id} ({parse_seq_time:.2f}s)")
+
+             # Give a preview of results. Logging the whole dict would be too much, so let's just see what we found
+             #log_update(format_dict(blast_data, indent=3))
+             n_og_reviewed_head_ids = len([x for x in head_reviewed if x=='1'])
+             found_head_ids = [x for x in list(blast_data.keys()) if (blast_data[x] is not None) and (blast_data[x].get('H_or_T', None) in ['Head','Head,Tail'])]
+             n_found_reviewed_head_ids = len([x for x in found_head_ids if head_reviewed_dict[x]=='1'])
+
+             n_og_reviewed_tail_ids = len([x for x in tail_reviewed if x=='1'])
+             found_tail_ids = [x for x in list(blast_data.keys()) if (blast_data[x] is not None) and (blast_data[x].get('H_or_T', None) in ['Tail','Head,Tail'])]
+             n_found_reviewed_tail_ids = len([x for x in found_tail_ids if tail_reviewed_dict[x]=='1'])
+
+             #log_update(f"\t\t{len(found_head_ids)}/{len(head_ids)} head protein UniProt IDs ({n_found_reviewed_head_ids}/{n_og_reviewed_head_ids} REVIEWED heads) had alignments")
+             #log_update(f"\t\t{len(found_tail_ids)}/{len(tail_ids)} tail protein UniProt IDs ({n_found_reviewed_tail_ids}/{n_og_reviewed_tail_ids} REVIEWED tails) had alignments")
+
+             # write results to pickle file
+             to_pickle_dict = {seq_id: blast_data}
+             with open(output_file, 'ab+') as f:
+                 pickle.dump(to_pickle_dict, f)
+
+         except:
+             log_update(f"{seq_id} failed")
+             # redump the pickle even if we hit an error, so that we can fix the error and continue processing results
+             redump_pickle_dictionary(output_file)
+
+     # Log total time
+     log_update(f"\tFinished processing {tot_seqs_processed} sequences ({total_parse_time:.2f}s)")
+
+     # redump the pickle
+     redump_pickle_dictionary(output_file)
+
+ def analyze_blast_results(fuson_ht_db, database="swissprot"):
+     blast_results_path = f"blast_outputs/{database}_blast_output_analyzed.pkl"
+     stats_df_savepath = f"blast_outputs/{database}_blast_stats.csv"
+     top_alignments_df_savepath = f"blast_outputs/{database}_top_alignments.csv"
+
+     stats_df, top_alignments_df = None, None
+     if os.path.exists(stats_df_savepath) and os.path.exists(top_alignments_df_savepath):
+         stats_df = pd.read_csv(stats_df_savepath)
+         top_alignments_df = pd.read_csv(top_alignments_df_savepath, dtype={'top_hg_UniProt_isoform': 'str',
+                                                                            'top_tg_UniProt_isoform': 'str',
+                                                                            'top_UniProt_isoform': 'str'})
+
+     else:
+         with open(blast_results_path, "rb") as f:
+             results = pickle.load(f)
+
+         # analyze the results
+         # first, basic stats. How many of them have at least one head or tail alignment?
+         seqid_stats = {}
+         top_alignments_dict = {}
+         for seq_id in list(results.keys()):
+             seqid_stats[seq_id] = {
+                 'hgAlignments': 0,
+                 'tgAlignments': 0,
+                 'totalAlignments': 0,
+                 'best_hgScore': 0,
+                 'best_tgScore': 0,
+                 'best_Score': 0
+             }
+             top_alignments_dict[seq_id] = {
+                 'top_hg_UniProtID': None,
+                 'top_hg_UniProt_isoform': None,
+                 'top_hg_UniProt_fus_indices': None,
+                 'top_tg_UniProtID': None,
+                 'top_tg_UniProt_isoform': None,
+                 'top_tg_UniProt_fus_indices': None,
+                 'top_UniProtID': None,
+                 'top_UniProt_isoform': None,
+                 'top_UniProt_fus_indices': None
+             }
+             for uniprot, d in results[seq_id].items():
+                 if not(d is None):
+                     isoform = d['Isoform']
+                     # set up the indices string
+                     query_start = d['Query_Start']
+                     if (query_start is None) or (type(query_start)==float and np.isnan(query_start)):
+                         query_start = ''
+                     else:
+                         query_start = int(query_start)
+                     query_end = d['Query_End']
+                     if (query_end is None) or (type(query_end)==float and np.isnan(query_end)):
+                         query_end = ''
+                     else:
+                         query_end = int(query_end)
+                     fus_indices = f"{query_start},{query_end}".strip(",")
+
+                     if d['H_or_T'] in ['Head', 'Head,Tail']:
+                         seqid_stats[seq_id]['hgAlignments'] += 1
+                         if d['Score'] > seqid_stats[seq_id]['best_hgScore']:
+                             seqid_stats[seq_id]['best_hgScore'] = d['Score']
+                             if type(uniprot)==float or uniprot is None:
+                                 top_alignments_dict[seq_id]['top_hg_UniProtID'] = ''
+                             else:
+                                 top_alignments_dict[seq_id]['top_hg_UniProtID'] = uniprot
+                             if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
+                                 top_alignments_dict[seq_id]['top_hg_UniProt_isoform'] = ''
+                             else:
+                                 top_alignments_dict[seq_id]['top_hg_UniProt_isoform'] = str(int(isoform))
+
+                             top_alignments_dict[seq_id]['top_hg_UniProt_fus_indices'] = fus_indices
+
+                     if d['H_or_T'] in ['Tail', 'Head,Tail']:
+                         seqid_stats[seq_id]['tgAlignments'] += 1
+                         if d['Score'] > seqid_stats[seq_id]['best_tgScore']:
+                             seqid_stats[seq_id]['best_tgScore'] = d['Score']
+                             if type(uniprot)==float or uniprot is None:
+                                 top_alignments_dict[seq_id]['top_tg_UniProtID'] = ''
+                             else:
+                                 top_alignments_dict[seq_id]['top_tg_UniProtID'] = uniprot
+                             if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
+                                 top_alignments_dict[seq_id]['top_tg_UniProt_isoform'] = ''
+                             else:
+                                 top_alignments_dict[seq_id]['top_tg_UniProt_isoform'] = str(int(isoform))
+
+                             top_alignments_dict[seq_id]['top_tg_UniProt_fus_indices'] = fus_indices
+                     # increment total no matter what type of alignment it is
+                     seqid_stats[seq_id]['totalAlignments'] += 1
+                     #if d['Score'] > seqid_stats[seq_id]['best_Score']:
+                     if d['Best']==True:  # should be indicated if this is the best!
+                         seqid_stats[seq_id]['best_Score'] = d['Score']
+                         if type(uniprot)==float or uniprot is None:
+                             top_alignments_dict[seq_id]['top_UniProtID'] = ''
+                         else:
+                             top_alignments_dict[seq_id]['top_UniProtID'] = uniprot
+                         if (type(isoform)==float and np.isnan(isoform)) or isoform is None:
+                             top_alignments_dict[seq_id]['top_UniProt_isoform'] = ''
+                         else:
+                             top_alignments_dict[seq_id]['top_UniProt_isoform'] = str(int(isoform))
+
+                         top_alignments_dict[seq_id]['top_UniProt_fus_indices'] = fus_indices
+                         # now get positives and identities
+                         if 'Identities' not in d: print(seq_id, uniprot, d.keys())
+                         identities = d['Identities']
+                         identities = int(identities.split('/')[0])
+                         positives = d['Positives']
+                         positives = int(positives.split('/')[0])
+                         top_alignments_dict[seq_id]['top_UniProt_nIdentities'] = identities
+                         top_alignments_dict[seq_id]['top_UniProt_nPositives'] = positives
+
+         stats_df = pd.DataFrame.from_dict(seqid_stats, orient='index').reset_index().rename(columns={'index': 'seq_id'})
+         stats_df['h_or_t_alignment'] = stats_df.apply(lambda row: True if (row['hgAlignments']>0 or row['tgAlignments']>0) else False, axis=1)
+         stats_df['h_and_t_alignment'] = stats_df.apply(lambda row: True if (row['hgAlignments']>0 and row['tgAlignments']>0) else False, axis=1)
+         stats_df.to_csv(stats_df_savepath, index=False)
+
+         top_alignments_df = pd.DataFrame.from_dict(top_alignments_dict, orient='index').reset_index().rename(columns={'index': 'seq_id'})
+         # add in the sequence length so we can get percentages
+         fusion_id_seq_dict = dict(zip(fuson_ht_db['seq_id'], fuson_ht_db['aa_seq']))
+         assert len(fusion_id_seq_dict) == len(fuson_ht_db['seq_id'].unique()) == len(fuson_ht_db['aa_seq'].unique())
+         top_alignments_df['aa_seq_len'] = top_alignments_df['seq_id'].map(fusion_id_seq_dict).str.len()
+
+         top_alignments_df.to_csv(top_alignments_df_savepath, index=False)
+     # also, find which ones have no match at all
+     no_match_list1 = find_nomatch_blasts(fuson_ht_db, database=database)
+
+     log_update(stats_df.head(10).to_string())
+     # how many have at least one head or tail?
+     log_update(f"Total sequences: {len(stats_df)}")
+     log_update(f"Sequences with >=1 head alignment: {len(stats_df.loc[stats_df['hgAlignments']>0])}")
+     log_update(f"Sequences with >=1 tail alignment: {len(stats_df.loc[stats_df['tgAlignments']>0])}")
+     log_update(f"Sequences with >=1 head OR tail alignment: {len(stats_df.loc[stats_df['h_or_t_alignment']])}")
+     log_update(f"Sequences with >=1 head AND tail alignment: {len(stats_df.loc[stats_df['h_and_t_alignment']])}")
+     log_update(f"Sequences with ANY alignment: {len(stats_df.loc[stats_df['totalAlignments']>0])}")
+
+     top_alignments_df = top_alignments_df.replace({None: ''})
+     log_update(f"Preview of top alignments for {database} search:\n{top_alignments_df.head(10).to_string(index=False)}")
+     top_alignments_df['hiso'] = top_alignments_df['top_hg_UniProtID']+'-'+top_alignments_df['top_hg_UniProt_isoform']
+     top_alignments_df['tiso'] = top_alignments_df['top_tg_UniProtID']+'-'+top_alignments_df['top_tg_UniProt_isoform']
+     top_alignments_df['biso'] = top_alignments_df['top_UniProtID']+'-'+top_alignments_df['top_UniProt_isoform']
+     top_hgs = set([x.strip('-') for x in top_alignments_df['hiso'].tolist()])  # if things don't have isoforms they'll just end in -
+     top_tgs = set([x.strip('-') for x in top_alignments_df['tiso'].tolist()])
+     top_bgs = set([x.strip('-') for x in top_alignments_df['biso'].tolist()])
+     top_gs = top_hgs | top_tgs | top_bgs
+     log_update(f"\nTotal unique head proteins (including isoform) producing top head alignments: {len(top_hgs)}")
+     log_update(f"\nTotal unique tail proteins (including isoform) producing top tail alignments: {len(top_tgs)}")
+     log_update(f"\nTotal unique proteins (including isoform) - head, tail, or neither - producing top alignments: {len(top_gs)}")
+
+     return stats_df, top_alignments_df
+
731
+ def compare_database_blasts(fuson_ht_db, swissprot_blast_stats, fusion_hts_blast_stats, make_new_plots=True):
732
+ # let's start by just returning a list of IDs that were
733
+ # cols = seq_id hgAlignments tgAlignments totalAlignments best_hgScore best_tgScore best_Score h_or_t_alignment h_and_t_alignment
734
+
735
+ # distinguish the columns
736
+ og_cols = list(swissprot_blast_stats.columns)[1::]
737
+ for c in og_cols:
738
+ if c!='seq_id':
739
+ swissprot_blast_stats = swissprot_blast_stats.rename(columns={c: f"swiss_{c}"})
740
+ for c in og_cols:
741
+ if c!='seq_id':
742
+ fusion_hts_blast_stats = fusion_hts_blast_stats.rename(columns={c: f"hts_{c}"})
743
+
744
+ # merge
745
+ merged = pd.merge(swissprot_blast_stats,
746
+ fusion_hts_blast_stats,
747
+ on='seq_id',
748
+ how='outer')
749
+ diff_cols = og_cols[0:-2]
750
+ differences = pd.DataFrame(columns=diff_cols)
751
+ log_update(f"Making volcano plots of the differences between fusion head-tail BLAST and swissprot BLAST in the following columns:\n\t{','.join(diff_cols)}")
752
+ for c in diff_cols:
753
+ differences[c] = merged[f"hts_{c}"] - merged[f"swiss_{c}"]
754
+
755
+ # make some box plots of differences
756
+ # Generate volcano plots for each column
757
+ if make_new_plots:
758
+ os.makedirs("figures",exist_ok=True)
759
+ os.makedirs("figures/database_comparison",exist_ok=True)
760
+ os.makedirs("figures/database_comparison/differences",exist_ok=True)
761
+ os.makedirs("figures/database_comparison/values",exist_ok=True)
762
+ os.makedirs("figures/database_comparison/box",exist_ok=True)
763
+
764
+ group_difference_plot(differences)
765
+ group_swiss_and_ht_plot(merged.drop(columns=['seq_id']), diff_cols)
766
+ group_box_plot(merged.drop(columns=['seq_id']), diff_cols)
767
+
768
+ def fasta_to_dataframe(fasta_file):
769
+ # Read the file into a DataFrame with a single column
770
+ df = pd.read_fwf(fasta_file, header=None, colspecs=[(0, None)], names=['content'])
771
+
772
+ # Select even and odd lines using pandas slicing
773
+ ids = df.iloc[::2].reset_index(drop=True) # Even-indexed lines (IDs)
774
+ sequences = df.iloc[1::2].reset_index(drop=True) # Odd-indexed lines (sequences)
775
+
776
+ # Combine into a new DataFrame
777
+ fasta_df = pd.DataFrame({'ID': ids['content'], 'Sequence': sequences['content']})
778
+ fasta_df['ID'] = fasta_df['ID'].str.split('>',expand=True)[1]
779
+ fasta_df['Sequence'] = fasta_df['Sequence'].str.strip().str.strip('\n')
780
+
781
+ # print a preview of this
782
+ temp = fasta_df.head(10)
783
+ temp['Sequence'] = temp['Sequence'].apply(lambda x: x[0:10]+'...')
784
+ log_update(f"Preview of head/tail fasta sequences in a dataframe:\n{temp.to_string(index=False)}")
785
+
786
+ return fasta_df
787
+
788
+ def get_ht_uniprot_query(swissprot_top_alignments_df):
789
+ '''
790
+ Use swissprot_top_alignments_df to curate all the unique UniProt IDs (ID.Isoform) that created top head and tail alignments
791
+ '''
792
+ swissprot_top_alignments_df['top_hg_full'] = swissprot_top_alignments_df['top_hg_UniProtID']+'.'+swissprot_top_alignments_df['top_hg_UniProt_isoform']
793
+ swissprot_top_alignments_df['top_tg_full'] = swissprot_top_alignments_df['top_tg_UniProtID']+'.'+swissprot_top_alignments_df['top_tg_UniProt_isoform']
794
+
795
+ unique_heads = swissprot_top_alignments_df.loc[
796
+ swissprot_top_alignments_df['top_hg_UniProtID'].notna()
797
+ ]['top_hg_full'].unique().tolist()
798
+
799
+ unique_tails = swissprot_top_alignments_df.loc[
800
+ swissprot_top_alignments_df['top_tg_UniProtID'].notna()
801
+ ]['top_tg_full'].unique().tolist()
802
+
803
+ unique_ht = set(unique_heads).union(set(unique_tails))
804
+ unique_ht = list(unique_ht)
805
+ unique_ht = [x for x in unique_ht if len(x)>1] # not just "."
806
+
807
+ with open("blast_outputs/ht_uniprot_query.txt", "w") as f:
808
+ for i, ht in enumerate(unique_ht):
809
+ if i!= len(unique_ht)-1:
810
+ f.write(f"{ht}\n")
811
+ else:
812
+ f.write(f"{ht}")
813
+
814
+ def main():
815
+ # Later, add the argparse thing back in here and change where the log is and what happens depending on wht the user decides
816
+ # May need to separate blast prep from actual blast for the manuscript, but worry about this later
817
+ with open_logfile(f"fusion_blast_log.txt"):
818
+ # Start by preparing BLAST inputs
819
+ prepare_blast_inputs()
820
+
821
+ # Then run BLAST
822
+ run_blast("blast_inputs",database="swissprot")
823
+
824
+ ###### Analyze BLAST results
825
+ # Make database with head and tail info for each fusion, so we know what to expect
826
+ fuson_ht_db = make_fuson_ht_db(savepath="fuson_ht_db.csv")
827
+
828
+ #parse_all_blast_results(fuson_ht_db, database="swissprot")
829
+ swissprot_blast_stats, swissprot_top_alignments_df = analyze_blast_results(fuson_ht_db,database="swissprot")
830
+
831
+ swissprot_top_alignments_df = pd.read_csv("blast_outputs/swissprot_top_alignments.csv")
832
+ get_ht_uniprot_query(swissprot_top_alignments_df)
833
+ os.makedirs("figures/top_blast_visuals",exist_ok=True)
834
+ group_pos_id_plot(swissprot_top_alignments_df)
835
+
836
+ if __name__ == '__main__':
837
+ main()
838
+
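Note that `fasta_to_dataframe` above assumes every FASTA record occupies exactly two lines (one header, one unwrapped sequence line). For FASTA files with wrapped sequences, a Biopython-based parse is safer; a minimal sketch with a hypothetical helper name (Biopython is already a dependency of this repo via `Bio.SeqIO` in `cluster.py`):

```python
import pandas as pd
from Bio import SeqIO

def fasta_to_dataframe_robust(fasta_file):
    # SeqIO.parse handles multi-line (wrapped) sequences transparently
    records = [(rec.id, str(rec.seq)) for rec in SeqIO.parse(fasta_file, "fasta")]
    return pd.DataFrame(records, columns=["ID", "Sequence"])
```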
fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d65107b19f8119ac3d37e2269857070be4736facdeebcdb6a7ddbcc339a5d7dc
+ size 6855252
fuson_plm/data/blast/blast_outputs/ht_uniprot_query.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d06e54181e6ef2ecd6b1bb09cb6e202f3bfcd2638e9991a9959398ded385f8a6
+ size 83285
fuson_plm/data/blast/blast_outputs/swissprot_blast_output_analyzed.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c94fdc967e9e3a84ccf84eff110b36fad3bf3966a8069fe16c1c5f74202ba4cf
+ size 96067168
fuson_plm/data/blast/blast_outputs/swissprot_blast_stats.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b4adfd714310a6c5df1490318dbb61ffa7cea6e50ca4d2c0454dd3d1747fa6e
+ size 1915092
fuson_plm/data/blast/blast_outputs/swissprot_no_match.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2ed60f354663bc37deaca18480cf7864588d782dad37997bb7917b5498e933c
+ size 43049
fuson_plm/data/blast/blast_outputs/swissprot_no_match.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4aab65af73b8ef76f9df2866a1dc4d93819ebcaafff229a5133f517df02386fd
+ size 2680
fuson_plm/data/blast/blast_outputs/swissprot_top_alignments.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e97b1341d9a878ff55f4e0940d4b9556c26a2b93ff5a52220cebb6d97150d6f
+ size 3293203
fuson_plm/data/blast/extract_blast_seqs.py ADDED
@@ -0,0 +1,62 @@
+ # Quick script to just get sequences out of the SwissProt BLAST database
+ import subprocess
+ import os
+ import pandas as pd
+ import pickle
+
+ def get_sequences_from_blastdb(database_path, entries):
+     """
+     Retrieves sequences for a list of entries from a BLAST database.
+
+     Parameters:
+     - database_path (str): Path to the BLAST database (without file extension).
+     - entries (list): List of entry IDs to query.
+
+     Returns:
+     - dict: A dictionary with entry IDs as keys and sequences as values.
+     """
+     sequences = {}
+     cwd = os.getcwd()
+     os.chdir("ncbi-blast-2.16.0+/swissprot")
+     for entry in entries:
+         try:
+             # Run blastdbcmd command to retrieve the sequence for each entry
+             result = subprocess.run(
+                 ["blastdbcmd", "-db", database_path, "-entry", entry],
+                 capture_output=True, text=True, check=True
+             )
+
+             # Store the output in the dictionary (entry ID as key, sequence as value)
+             # make sure the ID is what we think it is
+             result = result.stdout.strip()
+             entry_id = result.split(' ',1)[0].split('>')[1]
+             assert entry_id == entry
+             seq = result.split('\n',1)[1]
+             seq = seq.replace('\n','').strip()
+             sequences[entry] = seq
+
+         except subprocess.CalledProcessError as e:
+             print(f"Error retrieving entry {entry}: {e}")
+             sequences[entry] = None # Store None if there's an error for this entry
+
+     os.chdir(cwd) # restore the working directory so later relative paths still resolve
+     return sequences
+
+ def main():
+     # Query the SwissProt database for the sequences of all the head and tail genes that produced the top alignments
+     htgs = pd.read_csv("blast_outputs/ht_uniprot_query.txt",header=None)
+     htgs = list(htgs[0])
+
+     database_path = "swissprot" # Path to the BLAST database without extension
+     entries = htgs
+
+     sequences_dict = get_sequences_from_blastdb(database_path, entries)
+     with open("blast_outputs/best_htg_alignments_swissprot_seqs.pkl", "wb") as f:
+         pickle.dump(sequences_dict, f)
+
+     # Now look at the file you just wrote
+     with open("blast_outputs/best_htg_alignments_swissprot_seqs.pkl", "rb") as f:
+         d = pickle.load(f)
+
+ if __name__ == '__main__':
+     main()
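The loop above launches one `blastdbcmd` subprocess per entry. `blastdbcmd` also accepts a file of IDs via its `-entry_batch` flag, so the whole query list can be fetched in a single call; a minimal sketch with a hypothetical helper name, under the same working-directory assumptions as the script above:

```python
import subprocess

def get_sequences_batch(database_path, entry_file):
    # One blastdbcmd call for all entries listed in entry_file
    result = subprocess.run(
        ["blastdbcmd", "-db", database_path, "-entry_batch", entry_file],
        capture_output=True, text=True, check=True
    )
    sequences = {}
    for block in result.stdout.split(">")[1:]:  # each FASTA record
        header, _, body = block.partition("\n")
        sequences[header.split(" ", 1)[0]] = body.replace("\n", "")
    return sequences

# e.g., reusing the query list written by blast_fusions.py:
# seqs = get_sequences_batch("swissprot", "blast_outputs/ht_uniprot_query.txt")
```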
fuson_plm/data/blast/figures/identities_hist.png ADDED
fuson_plm/data/blast/fusion_blast_log.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c22070ae219a5f7f2bf11c2a87527fbad55d2387464a119849102ea80f84174c
+ size 9721
fuson_plm/data/blast/fuson_ht_db.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a92b06df9a78f4a969b691bc545d11923ff07145e02adacfb25fba879573a885
+ size 45861419
fuson_plm/data/blast/plot.py ADDED
@@ -0,0 +1,75 @@
+ import pandas as pd
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from fuson_plm.utils.visualizing import set_font
+
+ # Maps stat column names to friendlier plot labels
+ pos_id_label_dict = {
+     'top_UniProt_nIdentities': 'Identities',
+     'top_UniProt_nPositives': 'Positives' # just makes it easier to label these on plots
+ }
+
+ def plot_pos_or_id_pcnt_hist(data, column_name, save_path=None, ax=None):
+     """
+     column_name is 'top_UniProt_nIdentities' or 'top_UniProt_nPositives'
+     """
+     set_font()
+
+     created_fig = ax is None
+     if created_fig:
+         fig, ax = plt.subplots(figsize=(10, 7))
+
+     # Make the sample data
+     data = data[['aa_seq_len', column_name]].dropna() # only keep those with alignments
+     data[column_name] = data[column_name]*100 # scale so the ratio below is a percentage
+     data[f"{column_name} Percent Coverage"] = data[column_name] / data['aa_seq_len']
+
+     # Calculate the mean and median of the percent coverage
+     mean_coverage = data[f"{column_name} Percent Coverage"].mean()
+     median_coverage = data[f"{column_name} Percent Coverage"].median()
+
+     # Plot histogram for percent coverage
+     ax.hist(data[f"{column_name} Percent Coverage"], bins=50, edgecolor='grey', alpha=0.8, color='mediumpurple')
+
+     # Add vertical line for the mean
+     ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2)
+
+     # Add vertical line for the median
+     ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2)
+
+     # Add text label for the mean line
+     ax.text(mean_coverage, ax.get_ylim()[1] * 0.9, f'Mean: {mean_coverage:.1f}%', color='black',
+             ha='center', va='top', fontsize=40, backgroundcolor='white')
+
+     # Add text label for the median line
+     ax.text(median_coverage, ax.get_ylim()[1] * 0.8, f'Median: {median_coverage:.1f}%', color='black',
+             ha='center', va='top', fontsize=40, backgroundcolor='white')
+
+     # Labels and title
+     plt.xticks(fontsize=24)
+     plt.yticks(fontsize=24)
+     ax.set_xlabel(f"Max % {pos_id_label_dict[column_name]}", fontsize=40)
+     ax.set_ylabel("Count", fontsize=40)
+     #ax.set_title(f"{pos_id_label_dict[column_name]} Percent Coverage (n={len(data):,})", fontsize=40)
+
+     plt.tight_layout()
+
+     # Save the plot
+     if save_path is not None:
+         plt.savefig(save_path, dpi=300)
+
+     # Show the plot only if the figure was created here (no ax was provided)
+     if created_fig:
+         plt.show()
+
+ def group_pos_id_plot(data):
+     set_font()
+     plot_pos_or_id_pcnt_hist(data, 'top_UniProt_nIdentities', save_path="figures/identities_hist.png", ax=None)
+
+ def main():
+     swissprot_top_alignments_df = pd.read_csv("blast_outputs/swissprot_top_alignments.csv")
+     plot_pos_or_id_pcnt_hist(swissprot_top_alignments_df,
+                              'top_UniProt_nIdentities', save_path="figures/identities_hist.png", ax=None)
+
+ if __name__ == '__main__':
+     main()
fuson_plm/data/clean.py ADDED
@@ -0,0 +1,594 @@
+ ## Imports
+ import pandas as pd
+ import numpy as np
+ import os
+ import sys
+ import pickle
+ from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
+ from fuson_plm.utils.logging import open_logfile, log_update
+ from fuson_plm.utils.data_cleaning import clean_rows_and_cols, check_columns_for_listlike, check_item_for_listlike, find_delimiters, find_invalid_chars
+ from fuson_plm.data.config import CLEAN
+
+ def clean_fusionpdb(fusionpdb: pd.DataFrame, tcga_codes, delimiters, valid_aas) -> pd.DataFrame:
+     """
+     Return a cleaned version of the raw FusionPDB database, downloaded from the FusionPDB website "Level 1" link
+
+     Args:
+         fusionpdb (pd.DataFrame): The raw FusionPDB database
+         tcga_codes: mapping of TCGA acronyms to full cancer names
+         delimiters: delimiters to check for
+         valid_aas: set of valid amino acid characters
+
+     Returns:
+         pd.DataFrame: A cleaned version of the raw FusionPDB database with no duplicate sequences.
+
+         Columns:
+         - `aa_seq`: amino acid sequence of fusion oncoprotein. each is unique.
+         - `n_fusiongenes`: total number of fusion genes with this amino acid sequence.
+         - `fusiongenes`: comma-separated list of fusion genes (hgene::tgene) for this sequence. e.g., "MINK1::SPNS3,UBE2G1::SPNS3"
+         - `cancers`: comma-separated list of cancer types for this sequence. e.g., "breast invasive carcinoma,stomach adenocarcinoma"
+         - `primary_source`: source FusionPDB pulled the data from
+         - `secondary_source`: FusionPDB
+     """
+     # Process and clean FusionPDB database
+     log_update("Cleaning FusionPDB raw data")
+
+     # FusionPDB is downloaded with no column labels. Fill in column labels here.
+     log_update(f"\tfilling in column names...")
+     fusionpdb = fusionpdb.rename(columns={
+         0: 'ORF_type',
+         1: 'hgene_ens',
+         2: 'tgene_ens',
+         3: '', # no data in this column
+         4: 'primary_source', # database FusionPDB pulled from
+         5: 'cancer',
+         6: 'database_id',
+         7: 'hgene',
+         8: 'hgene_chr',
+         9: 'hgene_bp',
+         10: 'hgene_strand',
+         11: 'tgene',
+         12: 'tgene_chr',
+         13: 'tgene_bp',
+         14: 'tgene_strand',
+         15: 'bp_dna_transcript',
+         16: 'dna_transcript',
+         17: 'aa_seq_len',
+         18: 'aa_seq',
+         19: 'predicted_start_dna_transcript',
+         20: 'predicted_end_dna_transcript'
+     })
+
+     # Clean rows and columns
+     fusionpdb = clean_rows_and_cols(fusionpdb)
+
+     # Check for list-like qualities in the columns we plan to keep
+     cols_of_interest = ['hgene','tgene','cancer','aa_seq','primary_source']
+     listlike_dict = check_columns_for_listlike(fusionpdb, cols_of_interest, delimiters)
+
+     # Add a new column for fusiongene, which combines hgene::tgene. e.g., EWS::FLI1
+     log_update("\tadding a column for fusiongene = hgene::tgene")
+     fusionpdb['fusiongene'] = (fusionpdb['hgene'] + '::' + fusionpdb['tgene']).astype(str)
+
+     # Make 'cancer' column type string to ease downstream processing
+     log_update("\tcleaning the cancer column...")
+     # turn '.' and nan entries into empty string
+     fusionpdb = fusionpdb.replace('.',np.nan)
+     fusionpdb['cancer'] = fusionpdb['cancer'].astype(str).replace('nan','')
+     log_update("\t\tconverting cancer acronyms into full cancer names...")
+     fusionpdb['cancer'] = fusionpdb['cancer'].apply(lambda x: tcga_codes[x].lower() if x in tcga_codes else x.lower())
+     log_update("\t\tconverting all lists into comma-separated...")
+     fusionpdb['cancer'] = fusionpdb['cancer'].str.replace(';',',')
+     fusionpdb['cancer'] = fusionpdb['cancer'].str.replace(', ', ',')
+     fusionpdb['cancer'] = fusionpdb['cancer'].str.strip()
+     fusionpdb['cancer'] = fusionpdb['cancer'].str.strip(',')
+     log_update(f"\t\tchecking for delimiters in the cleaned column...")
+     check_columns_for_listlike(fusionpdb, ['cancer'], delimiters)
+
+     # Now that we've dealt with listlike instances, make dictionary of hgene and tgene to their ensembl strings
+     log_update("\tcreating dictionary of head and tail genes mapped to Ensembl IDs, to be used later for acquiring UniProtAcc for head and tail genes (needed for BLAST analysis)")
+     hgene_to_ensembl_dict = fusionpdb.groupby('hgene').agg(
+         {
+             'hgene_ens': lambda x: ','.join(set(x))
+         }
+     ).reset_index()
+     hgene_to_ensembl_dict = dict(zip(hgene_to_ensembl_dict['hgene'],hgene_to_ensembl_dict['hgene_ens']))
+     tgene_to_ensembl_dict = fusionpdb.groupby('tgene').agg(
+         {
+             'tgene_ens': lambda x: ','.join(set(x))
+         }
+     ).reset_index()
+     tgene_to_ensembl_dict = dict(zip(tgene_to_ensembl_dict['tgene'],tgene_to_ensembl_dict['tgene_ens']))
+     # now, we might have some of the same heads and tails being mapped to different things
+     all_keys = set(hgene_to_ensembl_dict.keys()).union(set(tgene_to_ensembl_dict.keys()))
+     gene_to_ensembl_dict = {}
+     for k in all_keys:
+         ens = hgene_to_ensembl_dict.get(k,'') + ',' + tgene_to_ensembl_dict.get(k,'')
+         ens = ','.join(set(list(ens.strip(',').split(','))))
+         gene_to_ensembl_dict[k] = ens
+     os.makedirs("head_tail_data",exist_ok=True)
+     with open(f"head_tail_data/gene_to_ensembl_dict.pkl", "wb") as f:
+         pickle.dump(gene_to_ensembl_dict, f)
+     total_unique_ens_ids = list(gene_to_ensembl_dict.values())
+     total_unique_ens_ids = set(",".join(total_unique_ens_ids).split(","))
+     log_update(f"\t\tTotal unique head/tail genes: {len(gene_to_ensembl_dict)}\n\t\tTotal unique ensembl ids: {len(total_unique_ens_ids)}")
+
+     # To deal with duplicate sequences, group FusionPDB by sequence and concatenate fusion gene names, cancer types, and primary source
+     log_update(f"\tchecking FusionPDB for duplicate protein sequences...\n\t\toriginal size: {len(fusionpdb)}")
+     duplicates = fusionpdb[fusionpdb.duplicated('aa_seq')]['aa_seq'].unique().tolist()
+     n_fgenes_with_duplicates = len(fusionpdb[fusionpdb['aa_seq'].isin(duplicates)]['fusiongene'].unique())
+     n_rows_with_duplicates = len(fusionpdb[fusionpdb['aa_seq'].isin(duplicates)])
+     log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
+     log_update(f"\tgrouping FusionPDB by amino acid sequence...")
+     # Merge step
+     fusionpdb = pd.merge(
+         fusionpdb.groupby('aa_seq').agg({
+             'fusiongene': lambda x: x.nunique()}).reset_index().rename(columns={'fusiongene':'n_fusiongenes'}),
+         fusionpdb.groupby('aa_seq').agg({
+             'fusiongene': lambda x: ','.join(x),
+             'cancer': lambda x: ','.join(x),
+             'primary_source': lambda x: ','.join(x)}).reset_index().rename(columns={'fusiongene':'fusiongenes', 'cancer': 'cancers', 'primary_source':'primary_sources'}),
+         on='aa_seq'
+     )
+     # Turn each aggregated column into sorted, comma-separated list
+     fusionpdb['fusiongenes'] = fusionpdb['fusiongenes'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+     fusionpdb['cancers'] = fusionpdb['cancers'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+     fusionpdb['primary_sources'] = fusionpdb['primary_sources'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+
+     # Count and display sequences with >1 fusion gene
+     duplicates = fusionpdb.loc[fusionpdb['n_fusiongenes']>1]['aa_seq'].tolist()
+     log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
+     log_update(f"\t\treorganized database contains {len(fusionpdb)} unique oncofusion sequences")
+
+     # Find invalid amino acids for each sequence and log_update the results
+     fusionpdb['invalid_chars'] = fusionpdb['aa_seq'].apply(lambda x: find_invalid_chars(x, valid_aas))
+     all_invalid_chars = set().union(*fusionpdb['invalid_chars'])
+     log_update(f"\tchecking for invalid characters...\n\t\tset of all invalid characters discovered within FusionPDB: {all_invalid_chars}")
+
+     # Filter out any sequences with invalid amino acids
+     fusionpdb = fusionpdb[fusionpdb['invalid_chars'].str.len()==0].reset_index(drop=True).drop(columns=['invalid_chars'])
+     log_update(f"\tremoving invalid characters...\n\t\tremaining sequences with valid AAs only: {len(fusionpdb)}")
+
+     # Add a column for secondary source - FusionPDB.
+     fusionpdb['secondary_source'] = ['FusionPDB']*len(fusionpdb)
+
+     # Final checks of database cleanliness
+     log_update(f"\tperforming final checks on cleaned FusionPDB...")
+     duplicates = len(fusionpdb.loc[fusionpdb['aa_seq'].duplicated()]['aa_seq'].tolist())
+     log_update(f"\t\t{duplicates} duplicate sequences")
+     invalids=0
+     for x in all_invalid_chars:
+         invalids += len(fusionpdb.loc[fusionpdb['aa_seq'].str.contains(x)])
+     log_update(f"\t\t{invalids} proteins containing invalid characters")
+     all_unique_seqs = len(fusionpdb)==len(fusionpdb['aa_seq'].unique())
+     log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")
+
+     return fusionpdb
+
+ def clean_fodb(fodb: pd.DataFrame, fodb_codes, delimiters, valid_aas) -> pd.DataFrame:
+     """
+     Cleans the FOdb database
+
+     Args:
+         fodb (pd.DataFrame): raw FOdb.
+         fodb_codes: mapping of FOdb cancer acronyms to full cancer names
+         delimiters: delimiters to check for
+         valid_aas: set of valid amino acid characters
+
+     Returns:
+         pd.DataFrame: a cleaned version of FOdb with no duplicate sequences.
+
+         Columns:
+         - `aa_seq`: amino acid sequence of fusion oncoprotein. each is unique.
+         - `n_fusiongenes`: total number of fusion genes with this amino acid sequence.
+         - `fusiongenes`: comma-separated list of fusion genes (hgene::tgene) for this sequence. e.g., "MINK1::SPNS3,UBE2G1::SPNS3"
+         - `cancers`: comma-separated list of cancer types for this sequence. e.g., "breast invasive carcinoma,stomach adenocarcinoma"
+         - `primary_source`: source FOdb pulled the data from
+         - `secondary_source`: FOdb
+     """
+     log_update("Cleaning FOdb raw data")
+
+     fodb['FO_Name'] = fodb['FO_Name'].apply(lambda x: x.split("_")[0]+"::"+x.split("_")[1])
+     fodb = fodb.rename(columns={'Sequence_Source': 'primary_source', 'FO_Name': 'fusiongene', 'AA_Sequence': 'aa_seq'})
+
+     # Clean rows and columns
+     fodb = clean_rows_and_cols(fodb)
+
+     # HEY1::NCOA2 has a "-" on the end by mistake. Replace this with '' for benchmarking purposes
+     special_seq = "MKRAHPEYSSSDSELDETIEVEKESADENGNLSSALGSMSPTTSSQILARKRRRGIIEKRRRDRINNSLSELRRLVPSAFEKQGSAKLEKAEILQMTVDHLKMLHTAGGKAFNNPRPGQLGRLLPNQNLPLDITLQSPTGAGPFPPIRNSSPYSVIPQPGMMGNQGMIGNQGNLGNSSTGMIGNSASRPTMPSGEWAPQSSAVRVTCAATTSAMNRPVQGGMIRNPAASIPMRPSSQPGQRQTLQSQVMNIGPSELEMNMGGPQYSQQQAPPNQTAPWPESILPIDQASFASQNRQPFGSSPDDLLCPHPAAESPSDEGALLDQLYLALRNFDGLEEIDRALGIPELVSQSQAVDPEQFSSQDSNIMLEQKAPVFPQQYASQAQMAQGSYSPMQDPNFHTMGQRPSYATLRMQPRPGLRPTGLVQNQPNQLRLQLQHRLQAQQNRQPLMNQISNVSNVNLTLRPGVPTQAPINAQMLAQRQREILNQHLRQRQMHQQQQVQQRTLMMRGQGLNMTPSMVAPSGIPATMSNPRIPQANAQQFPFPPNYGISQQPDPGFTGATTPQSPLMSPRMAHTQSPMMQQSQANPAYQAPSDINGWAQGNMGGNSMFSQQSPPHFGQQANTSMYSNNMNINVSMATNTGGMSSMNQMTGQISMTSVTSVPTSGLSSMGPEQVNDPALRGGNLFPNQLPGMDMIKQEGDTTRKYC-"
+     special_seq_name = "HEY1::NCOA2"
+     fodb.loc[
+         (fodb['fusiongene']==special_seq_name) &
+         (fodb['aa_seq']==special_seq), 'aa_seq'
+     ] = special_seq.replace('-','')
+
+     # filter out anything remaining with invalid characters
+     fodb['invalid_chars'] = fodb['aa_seq'].apply(lambda x: find_invalid_chars(x, valid_aas))
+     all_invalid_chars = set().union(*fodb['invalid_chars'])
+     log_update(f"\tchecking for invalid characters...\n\t\tset of all invalid characters discovered within FOdb: {all_invalid_chars}")
+
+     fodb = fodb[fodb['invalid_chars'].str.len()==0].reset_index(drop=True).drop(columns=['invalid_chars'])
+     log_update(f"\tremoving invalid characters...\n\t\tremaining sequences with valid AAs only: {len(fodb)}")
+
+     # aggregate the cancer data - if there's a 1 in the column, add it to the list of affected cancers
+     # acronym -> cancer conversions based on Supplementary Table 3 of the FOdb paper (Tripathi et al. 2023)
+     log_update(f"\taggregating cancer data from {len(fodb.columns)-4} individual cancer columns into one...")
+     log_update(f"\t\tchanging cancer names from acronyms to full")
+     cancers = list(fodb.columns)[4::]
+     fodb['cancers'] = ['']*len(fodb)
+     for cancer in cancers:
+         mapped_cancer = fodb_codes[cancer].lower() if cancer in fodb_codes else cancer
+         fodb['cancers'] = fodb.apply(
+             lambda row: row['cancers'] + f'{mapped_cancer},' if row[cancer] == 1 else row['cancers'],
+             axis=1
+         )
+     fodb['cancers'] = fodb['cancers'].str.strip(',').replace('nan','')
+     fodb = fodb.drop(columns=['Patient_Count']+cancers)
+
+     # Check for list-like qualities in the columns we plan to keep
+     cols_of_interest = ['primary_source','fusiongene','aa_seq','cancers']
+     listlike_dict = check_columns_for_listlike(fodb, cols_of_interest, delimiters)
+
+     # To deal with duplicate sequences, group FOdb by sequence and concatenate fusion gene names, cancer types, and primary source
+     log_update(f"\tchecking FOdb for duplicate protein sequences...\n\t\toriginal size: {len(fodb)}")
+     duplicates = fodb[fodb.duplicated('aa_seq')]['aa_seq'].unique().tolist()
+     n_fgenes_with_duplicates = len(fodb[fodb['aa_seq'].isin(duplicates)]['fusiongene'].unique())
+     n_rows_with_duplicates = len(fodb[fodb['aa_seq'].isin(duplicates)])
+     log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
+     log_update(f"\tgrouping FOdb by amino acid sequence...")
+     # Merge step
+     fodb = pd.merge(
+         fodb.groupby('aa_seq').agg({
+             'fusiongene': lambda x: x.nunique()}).reset_index().rename(columns={'fusiongene':'n_fusiongenes'}),
+         fodb.groupby('aa_seq').agg({
+             'fusiongene': lambda x: ','.join(x),
+             'cancers': lambda x: ','.join(x),
+             'primary_source': lambda x: ','.join(x)}).reset_index().rename(columns={'fusiongene':'fusiongenes', 'primary_source':'primary_sources'}),
+         on='aa_seq'
+     )
+     # Turn each aggregated column into sorted, comma-separated list
+     fodb['fusiongenes'] = fodb['fusiongenes'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+     fodb['cancers'] = fodb['cancers'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+     fodb['primary_sources'] = fodb['primary_sources'].apply(lambda x: (',').join(sorted(set(x.split(','))))).str.strip(',')
+
+     # Count and display sequences with >1 fusion gene
+     duplicates = fodb.loc[fodb['n_fusiongenes']>1]['aa_seq'].tolist()
+     log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
+     log_update(f"\t\treorganized database contains {len(fodb)} unique oncofusion sequences")
+
+     # Add secondary source column because FOdb is the secondary source here.
+     fodb['secondary_source'] = ['FOdb']*len(fodb)
+
+     # Final checks of database cleanliness
+     log_update(f"\tperforming final checks on cleaned FOdb...")
+     duplicates = len(fodb.loc[fodb['aa_seq'].duplicated()]['aa_seq'].tolist())
+     log_update(f"\t\t{duplicates} duplicate sequences")
+     invalids=0
+     for x in all_invalid_chars:
+         invalids += len(fodb.loc[fodb['aa_seq'].str.contains(x)])
+     log_update(f"\t\t{invalids} proteins containing invalid characters")
+     all_unique_seqs = len(fodb)==len(fodb['aa_seq'].unique())
+     log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")
+
+     return fodb
+
+ def create_fuson_db(fusionpdb: pd.DataFrame, fodb: pd.DataFrame) -> pd.DataFrame:
+     """
+     Merges cleaned FusionPDB and FOdb to create fuson_db (the full set of fusion sequences for training/benchmarking FusOn-pLM)
+
+     Args:
+         fusionpdb (pd.DataFrame): cleaned FusionPDB database
+         fodb (pd.DataFrame): cleaned FOdb database
+     """
+     log_update("Creating the merged database...")
+
+     log_update("\tconcatenating cleaned FusionPDB and cleaned FOdb...")
+     fuson_db = pd.concat(
+         [
+             fusionpdb.rename(columns={'secondary_source':'secondary_sources'}),
+             fodb.rename(columns={'secondary_source':'secondary_sources'})
+         ]
+     )
+
+     # Handle duplicate amino acid sequences
+     log_update(f"\tchecking merged database for duplicate protein sequences...\n\t\toriginal size: {len(fuson_db)}")
+     duplicates = fuson_db[fuson_db.duplicated('aa_seq')]['aa_seq'].unique().tolist()
+     n_fgenes_with_duplicates = len(fuson_db[fuson_db['aa_seq'].isin(duplicates)]['fusiongenes'].unique())
+     n_rows_with_duplicates = len(fuson_db[fuson_db['aa_seq'].isin(duplicates)])
+     log_update(f"\t\t{len(duplicates)} duplicated sequences, corresponding to {n_rows_with_duplicates} rows and {n_fgenes_with_duplicates} distinct fusiongenes")
+     log_update(f"\tgrouping database by amino acid sequence...")
+
+     fuson_db = fuson_db.groupby('aa_seq').agg(
+         {
+             'fusiongenes': lambda x: ','.join(x),
+             'cancers': lambda x: ','.join(x),
+             'primary_sources': lambda x: ','.join(x),
+             'secondary_sources': lambda x: ','.join(x)
+         }
+     ).reset_index()
+     duplicates = fuson_db.loc[fuson_db['fusiongenes'].str.count(',')>0]['aa_seq'].tolist()
+     log_update(f"\t\treorganized database contains {len(duplicates)} proteins with >1 fusion gene")
+     log_update(f"\t\treorganized database contains {len(fuson_db)} unique oncofusion sequences")
+
+     # Turn each aggregated column into a set of only the unique entries
+     for column in fuson_db.columns[1::]:
+         fuson_db[column] = fuson_db[column].apply(lambda x: (',').join(sorted(set(
+             [y for y in x.split(',') if len(y)>0]))))
+
+     # Add a column for length
+     log_update(f"\tadding a column for length...")
+     fuson_db['length'] = fuson_db['aa_seq'].apply(lambda x: len(x))
+
+     # Sort by fusiongenes, then length
+     log_update(f"\tsorting by fusion gene name, then length...")
+     fuson_db = fuson_db.sort_values(by=['fusiongenes','length'],ascending=[True,True]).reset_index(drop=True)
+
+     # Add a seq_id column: seq1, seq2, ..., seqn
+     log_update(f"\tadding sequence ids: seq1, seq2, ..., seqn")
+     fuson_db['seq_id'] = ['seq'+str(i+1) for i in range(len(fuson_db))]
+
+     # Final checks of database cleanliness
+     log_update(f"\tperforming final checks on fuson_db...")
+     duplicates = len(fuson_db.loc[fuson_db['aa_seq'].duplicated()]['aa_seq'].tolist())
+     log_update(f"\t\t{duplicates} duplicate sequences")
+     all_unique_seqs = len(fuson_db)==len(fuson_db['aa_seq'].unique())
+     log_update(f"\t\tevery row contains a unique oncofusion sequence: {all_unique_seqs}")
+
+     return fuson_db
+
+ def head_tail_mappings(fuson_db):
+     log_update("\nGenes and Ensembl IDs corresponding to the head and tail proteins have been mapped on UniProt. Now, combining these results.")
+
+     # Read the ensembl map, gene name map, and dictionary from gene --> ensembl ids
+     ensembl_map = pd.read_csv("head_tail_data/ensembl_ht_idmap.txt",sep="\t")
+     name_map = pd.read_csv("head_tail_data/genename_ht_idmap.txt",sep="\t")
+     with open("head_tail_data/gene_to_ensembl_dict.pkl", "rb") as f:
+         gene_ens_dict = pickle.load(f)
+
+     log_update(f"\tCheck: ensembl map and gene name map have same columns: {set(ensembl_map.columns)==set(name_map.columns)}")
+     log_update(f"\t\tColumns = {list(ensembl_map.columns)}")
+
+     # Prepare to merge
+     log_update(f"\tMerging the ensembl map and gene name map:")
+     ensembl_map = ensembl_map.rename(columns={'From': 'ensembl_id'}) # mapped from ensembl ids
+     name_map = name_map.rename(columns={'From': 'htgene'}) # mapped from head or tail genes
+     name_map['ensembl_id'] = name_map['htgene'].map(gene_ens_dict) # add ensembl id column based on head and tail genes
+     name_map['ensembl_id'] = name_map['ensembl_id'].apply(lambda x: x.split(',') if type(x)==str else x) # split into a list if there are multiple matches
+     log_update(f"\t\tLength of gene-based map before exploding ensembl_id column: {len(name_map)}")
+     name_map = name_map.explode('ensembl_id') # explode so each ensembl id is its own line
+     log_update(f"\t\tLength of gene-based map after exploding ensembl_id column: {len(name_map)}")
+     log_update(f"\t\tLength of ensembl-based map: {len(ensembl_map)}")
+     unimap = pd.merge(name_map[['htgene','ensembl_id','Entry','Reviewed']],
+                       ensembl_map[['ensembl_id','Entry','Reviewed']],
+                       on=['ensembl_id','Entry','Reviewed'],
+                       how='outer'
+                       )
+     unimap['Reviewed'] = unimap['Reviewed'].apply(lambda x: '1' if x=='reviewed' else '0' if x=='unreviewed' else 'N') # N for nan
+     log_update(f"\t\tLength of merge: {len(unimap)}. Merge preview:")
+     log_update(unimap.head())
+     unimap = unimap.drop_duplicates(['htgene','Entry','Reviewed']).reset_index(drop=True)
+     log_update(f"\t\tLength of merge after dropping rows where only ensembl_id changed: {len(unimap)}. Merge preview:")
+     log_update(unimap.head())
+     unimap = unimap.groupby('htgene').agg(
+         {
+             'Entry': lambda x: ','.join(x),
+             'Reviewed': lambda x: ''.join(x)
+         }
+     ).reset_index()
+     unimap = unimap.rename(columns={
+         'htgene': 'Gene',
+         'Entry': 'UniProtID',
+     })
+     log_update(f"\t\tLength of merge after grouping by gene name: {len(unimap)}. Merge preview:")
+     log_update(unimap.head())
+
+     # what are the proteins whose head or tail genes are in this list?
+     log_update(f"\tChecking which fusion proteins have unmappable heads and/or tails:")
+     temp = fuson_db.copy(deep=True)
+     temp['fusiongenes'] = temp['fusiongenes'].apply(lambda x: x.split(','))
+     temp = temp.explode('fusiongenes')
+     temp['hgene'] = temp['fusiongenes'].str.split('::',expand=True)[0]
+     temp['tgene'] = temp['fusiongenes'].str.split('::',expand=True)[1]
+
+     # See which gene IDs weren't covered
+     log_update(f"\tChecking which gene IDs were not mapped by either method")
+     all_geneids = temp['hgene'].tolist() + temp['tgene'].tolist()
+     all_geneids = list(set(all_geneids))
+     all_mapped_genes = unimap['Gene'].unique().tolist()
+     unmapped_geneids = set(all_geneids) - set(all_mapped_genes)
+     log_update(f"\t\t{len(all_mapped_genes)}/{len(all_geneids)} were mapped\n\t\t{len(unmapped_geneids)}/{len(all_geneids)} were unmapped")
+     log_update(f"\t\tUnmapped geneids: {','.join(unmapped_geneids)}")
+
+     # Find the ok ones and print
+     ok_seqs = temp.loc[
+         (temp['hgene'].isin(all_mapped_genes)) | # head gene was found, OR
+         (temp['tgene'].isin(all_mapped_genes)) # tail gene was found
+     ]['seq_id'].unique().tolist()
+     ok_seqsh = temp.loc[
+         (temp['hgene'].isin(all_mapped_genes)) # head gene was found
+     ]['seq_id'].unique().tolist()
+     ok_seqst = temp.loc[
+         (temp['tgene'].isin(all_mapped_genes)) # tail gene was found
+     ]['seq_id'].unique().tolist()
+     ok_seqsboth = temp.loc[
+         (temp['hgene'].isin(all_mapped_genes)) & # head gene was found, AND
+         (temp['tgene'].isin(all_mapped_genes)) # tail gene was found
+     ]['seq_id'].unique().tolist()
+
+     log_update(f"\tTotal fusion sequence ids: {len(temp['seq_id'].unique())}")
+     log_update(f"\tFusion sequences with at least 1 mapped constituent:\
+     \n\t\tMapped head: {len(ok_seqsh)}\
+     \n\t\tMapped tail: {len(ok_seqst)}\
+     \n\t\tMapped head or tail: {len(ok_seqs)}\
+     \n\t\tMapped head AND tail: {len(ok_seqsboth)}")
+
+     # Now look at the bad side
+     atleast_1_lost = temp.loc[
+         ((temp['hgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqsh))) | # head not found in row, AND head not found for seq_id - OR
+         ((temp['tgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqst))) # tail not found in row, AND tail not found for seq_id
+     ]['seq_id'].unique().tolist()
+     atleast_1_losth = temp.loc[
+         (temp['hgene'].isin(unmapped_geneids)) & # head not found in this row AND
+         ~(temp['seq_id'].isin(ok_seqsh)) # head not found for this seq id
+     ]['seq_id'].unique().tolist()
+     atleast_1_lostt = temp.loc[
+         (temp['tgene'].isin(unmapped_geneids)) & # tail not found in this row AND
+         ~(temp['seq_id'].isin(ok_seqst)) # tail not found for this seq id
+     ]['seq_id'].unique().tolist()
+     both_lost = temp.loc[
+         ((temp['hgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqsh))) & # there's no head, and this seq id has no head - AND
+         ((temp['tgene'].isin(unmapped_geneids)) & ~(temp['seq_id'].isin(ok_seqst))) # there's no tail, and this seq id has no tail
+     ]['seq_id'].unique().tolist()
+     log_update(f"\tFusion sequences with at least 1 unmapped constituent:")
+     log_update(f"\t\tUnmapped head: {len(atleast_1_losth)}\
+     \n\t\tUnmapped tail: {len(atleast_1_lostt)}\
+     \n\t\tUnmapped head or tail: {len(atleast_1_lost)}\
+     \n\t\tUnmapped head AND tail: {len(both_lost)}")
+     log_update(f"\tseq_ids with at least 1 unmapped part: {atleast_1_lost}")
+
+     assert len(ok_seqsboth) + len(atleast_1_lost) == len(temp['seq_id'].unique())
+     log_update(f"\tFusions with H&T covered plus Fusions with H|T lost = total = {len(ok_seqsboth)} + {len(atleast_1_lost)} = {len(ok_seqsboth)+len(atleast_1_lost)} = {len(temp['seq_id'].unique())}")
+
+     ### Save the unimap
+     unimap.to_csv('head_tail_data/htgenes_uniprotids.csv',index=False)
+
+ def assemble_uniprot_query(path_to_gene_ens_dict="head_tail_data/gene_to_ensembl_dict.pkl",path_to_fuson_db="fuson_db.csv"):
+     """
+     To analyze the BLAST results effectively, we must know which UniProt accessions we *expect* to see for each fusion oncoprotein.
+     We will try to map each FO to its head and tail accessions by searching UniProt ID map by gene name and Ensembl ID.
+
+     This method will create two input lists for UniProt:
+     - head_tail_genes.txt: list of all unique head and tail gene names
+     - head_tail_ens.txt: list of all unique head and tail Ensembl IDs
+     """
+     log_update("\nMaking inputs for UniProt ID map, to find accessions for head and tail genes")
+     if not(os.path.exists(path_to_gene_ens_dict)):
+         raise Exception(f"File {path_to_gene_ens_dict} does not exist")
+
+     with open(path_to_gene_ens_dict, "rb") as f:
+         gene_ens_dict = pickle.load(f)
+
+     all_htgenes_temp = list(gene_ens_dict.keys())
+     all_ens = list(gene_ens_dict.values())
+     all_ens = list(set(",".join(all_ens).split(",")))
+     log_update(f"\tTotal unique head and tail genes, only accounting for FusionPDB: {len(all_htgenes_temp)}")
+
+     # need to add other htgenes from FOdb (via fuson_db)
+     fuson_db = pd.read_csv(path_to_fuson_db)
+     fuson_db['fusiongenes'] = fuson_db['fusiongenes'].apply(lambda x: x.split(','))
+     fuson_db = fuson_db.explode('fusiongenes')
+     fuson_db['hgene'] = fuson_db['fusiongenes'].str.split('::',expand=True)[0]
+     fuson_db['tgene'] = fuson_db['fusiongenes'].str.split('::',expand=True)[1]
+     fuson_htgenes = fuson_db['hgene'].tolist() + fuson_db['tgene'].tolist()
+     fuson_htgenes = set(fuson_htgenes)
+     all_htgenes = set(all_htgenes_temp).union(set(fuson_htgenes))
+     all_htgenes = list(set(all_htgenes))
+
+     log_update(f"\tTotal unique head and tail genes after adding FOdb: {len(all_htgenes)}")
+     log_update(f"\tTotal unique ensembl IDs: {len(all_ens)}")
+     # go through each and write a file
+     input_dir = "head_tail_data/uniprot_idmap_inputs"
+     os.makedirs(input_dir,exist_ok=True)
+
+     if os.path.exists(f"{input_dir}/head_tail_genes.txt"):
+         log_update("\nAlready assembled UniProt ID mapping input for head and tail genes. Continuing")
+     else:
+         with open(f"{input_dir}/head_tail_genes.txt", "w") as f:
+             for i, gene in enumerate(all_htgenes):
+                 if i!=len(all_htgenes)-1:
+                     f.write(f"{gene}\n")
+                 else:
+                     f.write(f"{gene}")
+
+     if os.path.exists(f"{input_dir}/head_tail_ens.txt"):
+         log_update("\nAlready assembled UniProt ID mapping input for head and tail ensembl IDs. Continuing")
+     else:
+         with open(f"{input_dir}/head_tail_ens.txt", "w") as f:
+             for i, ens in enumerate(all_ens):
+                 if i!=len(all_ens)-1:
+                     f.write(f"{ens}\n")
+                 else:
+                     f.write(f"{ens}")
+
+ def main():
+     # Define global variables from config.CLEAN
+     FODB_PATH = CLEAN.FODB_PATH
+     FODB_PUNCTA_PATH = CLEAN.FODB_PUNCTA_PATH
+     FUSIONPDB_PATH = CLEAN.FUSIONPDB_PATH
+     LOG_PATH = "data_cleaning_log.txt"
+     SAVE_CLEANED_FODB = False
+
+     # Prepare the log file
+     with open_logfile(LOG_PATH):
+         log_update("Loaded data-cleaning configurations from config.py")
+         CLEAN.print_config(indent='\t')
+
+         log_update("Reading FusionPDB...")
+         fusionpdb = pd.read_csv(FUSIONPDB_PATH,sep='\t',header=None)
+         fusionpdb = clean_fusionpdb(fusionpdb, TCGA_CODES, DELIMITERS, VALID_AAS)
+
+         log_update("Saving FusionPDB to FusionPDB_cleaned.csv...")
+         fusionpdb.to_csv('raw_data/FusionPDB_cleaned.csv', index=False)
+
+         # Clean FOdb, optionally save
+         log_update("Reading FOdb...")
+         fodb = pd.read_csv(FODB_PATH)
+         fodb = clean_fodb(fodb, FODB_CODES, DELIMITERS, VALID_AAS)
+
+         if SAVE_CLEANED_FODB:
+             log_update("Saving FOdb to FOdb_cleaned.csv...")
+             fodb.to_csv('FOdb_cleaned.csv', index=False)
+
+         # Merge FusionPDB and FOdb to fuson_db
+         fuson_db = create_fuson_db(fusionpdb, fodb)
+
+         # Mark benchmarking sequences
+         # FOdb puncta benchmark
+         log_update("Adding benchmarking sequences to fuson_db...")
+         fodb_puncta = pd.read_csv(FODB_PUNCTA_PATH)
+
+         # handle the mistake sequence - take the "-" off the end
+         special_seq = "MKRAHPEYSSSDSELDETIEVEKESADENGNLSSALGSMSPTTSSQILARKRRRGIIEKRRRDRINNSLSELRRLVPSAFEKQGSAKLEKAEILQMTVDHLKMLHTAGGKAFNNPRPGQLGRLLPNQNLPLDITLQSPTGAGPFPPIRNSSPYSVIPQPGMMGNQGMIGNQGNLGNSSTGMIGNSASRPTMPSGEWAPQSSAVRVTCAATTSAMNRPVQGGMIRNPAASIPMRPSSQPGQRQTLQSQVMNIGPSELEMNMGGPQYSQQQAPPNQTAPWPESILPIDQASFASQNRQPFGSSPDDLLCPHPAAESPSDEGALLDQLYLALRNFDGLEEIDRALGIPELVSQSQAVDPEQFSSQDSNIMLEQKAPVFPQQYASQAQMAQGSYSPMQDPNFHTMGQRPSYATLRMQPRPGLRPTGLVQNQPNQLRLQLQHRLQAQQNRQPLMNQISNVSNVNLTLRPGVPTQAPINAQMLAQRQREILNQHLRQRQMHQQQQVQQRTLMMRGQGLNMTPSMVAPSGIPATMSNPRIPQANAQQFPFPPNYGISQQPDPGFTGATTPQSPLMSPRMAHTQSPMMQQSQANPAYQAPSDINGWAQGNMGGNSMFSQQSPPHFGQQANTSMYSNNMNINVSMATNTGGMSSMNQMTGQISMTSVTSVPTSGLSSMGPEQVNDPALRGGNLFPNQLPGMDMIKQEGDTTRKYC-"
+         special_seq_name = "HEY1_NCOA2"
+         fodb_puncta.loc[
+             (fodb_puncta['FO_Name']==special_seq_name) &
+             (fodb_puncta['AAseq']==special_seq), 'AAseq'
+         ] = special_seq.replace('-','')
+
+         fodb_puncta_sequences = fodb_puncta['AAseq'].unique().tolist()
+         benchmark_sequences = dict(zip(fodb_puncta_sequences, ['Puncta']*len(fodb_puncta_sequences)))
+         log_update(f"\tRead FOdb puncta data and isolated {len(benchmark_sequences)} sequences for puncta benchmark")
+         # Biological discovery benchmark
+         benchmark_sequences2 = fuson_db.loc[
+             (fuson_db['fusiongenes'].str.contains('EWSR1::FLI1')) |
+             (fuson_db['fusiongenes'].str.contains('PAX3::FOXO1')) |
+             (fuson_db['fusiongenes'].str.contains('BCR::ABL1')) |
+             (fuson_db['fusiongenes'].str.contains('EML4::ALK'))
+         ]['aa_seq'].unique().tolist()
+         benchmark_sequences2 = dict(zip(benchmark_sequences2, ['Biological Discovery']*len(benchmark_sequences2)))
+         log_update(f"\tIsolated all EWSR1::FLI1, PAX3::FOXO1, BCR::ABL1, and EML4::ALK sequences ({len(benchmark_sequences2)} total) for biological benchmarks...")
+
+         for k, v in benchmark_sequences2.items():
+             if k in benchmark_sequences:
+                 benchmark_sequences[k] = benchmark_sequences[k] + ',' + v
+             else:
+                 benchmark_sequences[k] = v
+
+         log_update(f"\tTotal unique benchmark sequences: {len(benchmark_sequences)}")
+         # Add benchmark column
+         log_update("\tAdding benchmark column...")
+         fuson_db['benchmark'] = fuson_db['aa_seq'].apply(lambda x: benchmark_sequences[x] if x in benchmark_sequences else np.nan)
+
+         # Save fuson_db
+         log_update("\nWriting final database to fuson_db.csv...")
+         fuson_db.to_csv('fuson_db.csv', index=False)
+         log_update("Cleaning complete.")
+
+         # Assemble head tail queries for UniProt
+         assemble_uniprot_query()
+
+         # Do the head tail mappings
+         head_tail_mappings(fuson_db)
+
+ if __name__ == '__main__':
+     main()
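The duplicate-handling strategy in `clean_fusionpdb`, `clean_fodb`, and `create_fuson_db` is the same groupby/aggregate pattern throughout: collapse rows that share an `aa_seq` and join their annotations into sorted, comma-separated lists. A toy illustration (made-up sequences and gene names):

```python
import pandas as pd

df = pd.DataFrame({
    "aa_seq":     ["MKV", "MKV", "MPL"],
    "fusiongene": ["EWSR1::FLI1", "FUS::ERG", "PAX3::FOXO1"],
})
# one row per unique sequence, with a sorted, comma-separated gene list
grouped = df.groupby("aa_seq").agg(
    {"fusiongene": lambda x: ",".join(sorted(set(x)))}
).reset_index()
# aa_seq 'MKV' now carries 'EWSR1::FLI1,FUS::ERG'
```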
fuson_plm/data/cluster.py ADDED
@@ -0,0 +1,50 @@
+ import pandas as pd
+ import numpy as np
+ import os
+ import subprocess
+ import sys
+ from Bio import SeqIO
+ import shutil
+ from fuson_plm.utils.logging import open_logfile, log_update
+ from fuson_plm.data.config import CLUSTER
+ # NOTE (assumption): make_fasta, run_mmseqs_clustering, analyze_clustering_result, and
+ # cluster_summary are helper functions defined elsewhere in the package; they are used
+ # below but not imported in this file as committed.
+
+ def main():
+     # Read all the input args
+     LOG_PATH = "clustering_log.txt"
+     INPUT_PATH = CLUSTER.INPUT_PATH
+     MIN_SEQ_ID = CLUSTER.MIN_SEQ_ID
+     C = CLUSTER.C
+     COV_MODE = CLUSTER.COV_MODE
+     PATH_TO_MMSEQS = CLUSTER.PATH_TO_MMSEQS
+     MAX_SEQ_LENGTH = CLUSTER.MAX_SEQ_LENGTH
+
+     with open_logfile(LOG_PATH):
+         log_update("Input params from config.py:")
+         CLUSTER.print_config(indent='\t')
+         # Make a subfolder for clustering results, and direct MMseqs2 outputs here
+         if not(os.path.exists("clustering")):
+             os.mkdir("clustering")
+         output_dir = "clustering/raw_output"
+
+         # Make fasta of input file
+         sequences = pd.read_csv(INPUT_PATH)
+         log_update(f"\nPreparing input data...\n\tInitial dataset size: {len(sequences)} sequences")
+
+         sequences = sequences.loc[sequences['aa_seq'].str.len() <= MAX_SEQ_LENGTH].reset_index(drop=True)
+         log_update(f"\tApplied length cutoff of {MAX_SEQ_LENGTH} AAs. New dataset size: {len(sequences)} sequences")
+
+         sequences = dict(zip(sequences['seq_id'],sequences['aa_seq']))
+         fasta_path = make_fasta(sequences, "clustering/input.fasta")
+         log_update(f"\tMade fasta of input sequences, saved at {fasta_path}")
+
+         run_mmseqs_clustering(fasta_path, output_dir, min_seq_id=MIN_SEQ_ID, c=C, cov_mode=COV_MODE, path_to_mmseqs=PATH_TO_MMSEQS)
+
+         # Brief read to preview results
+         clusters = analyze_clustering_result('clustering/input.fasta', 'clustering/raw_output/mmseqs_cluster.tsv')
+         # Save clusters
+         clusters.to_csv('clustering/mmseqs_full_results.csv',index=False)
+         log_update("Processed and combined mmseqs output. Wrote comprehensive results to clustering/mmseqs_full_results.csv")
+         cluster_summary(clusters)
+
+ if __name__ == "__main__":
+     main()
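For reference, with the `CLUSTER` parameters from `config.py`, the MMseqs2 invocation wrapped by `run_mmseqs_clustering` should be roughly equivalent to the `easy-cluster` workflow below. This is a sketch; the actual wrapper and its output paths may differ:

```python
import subprocess

# Assumed equivalent of run_mmseqs_clustering with config.CLUSTER values
subprocess.run([
    "../mmseqs", "easy-cluster",
    "clustering/input.fasta",        # input fasta written by make_fasta
    "clustering/raw_output/mmseqs",  # output prefix
    "clustering/tmp",                # scratch directory
    "--min-seq-id", "0.3",           # CLUSTER.MIN_SEQ_ID
    "-c", "0.8",                     # CLUSTER.C
    "--cov-mode", "0",               # CLUSTER.COV_MODE (bidirectional coverage)
], check=True)
```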
fuson_plm/data/clustering/input.fasta ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76ad3e500fd5e51220d210c2fdef65d761a9f9e8b7962c94bcc79b093408f7b7
+ size 27788610
fuson_plm/data/clustering/mmseqs_full_results.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8e503a6511f93265a964fd105200a05fa957d9fc2e0edee37dbb3f0b0f55486e
+ size 55967813
fuson_plm/data/clustering_log.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73371a8475ebcbef54f15b9c48caa32da3b2ebd6ffac224677c8208792fef41d
+ size 2931
fuson_plm/data/config.py ADDED
@@ -0,0 +1,34 @@
+ # config.py
+ from fuson_plm.utils.logging import CustomParams
+
+ CLEAN = CustomParams(
+     ### Changing these parameters is not recommended
+     FODB_PATH = '../data/raw_data/FOdb_all.csv', # path to raw FOdb database
+     FODB_PUNCTA_PATH = '../data/raw_data/FOdb_puncta.csv', # path to raw FOdb puncta experimental data
+     FUSIONPDB_PATH = '../data/raw_data/FusionPDB.txt', # path to raw FusionPDB Level 1 .txt download
+ )
+
+ # Clustering Parameters
+ CLUSTER = CustomParams(
+     MAX_SEQ_LENGTH = 2000, # INCLUSIVE max length (amino acids) of a sequence for training, validation, or testing
+
+     # MMseqs2 parameters: see GitHub or the MMseqs2 wiki for guidance
+     MIN_SEQ_ID = 0.3, # % identity
+     C = 0.8, # % sequence length overlap
+     COV_MODE = 0, # cov-mode: 0 = bidirectional, 1 = target coverage, 2 = query coverage, 3 = target-in-query length coverage.
+     # File paths
+     INPUT_PATH = '../data/fuson_db.csv',
+     PATH_TO_MMSEQS = '../mmseqs' # path to where you installed MMseqs2
+ )
+
+ # Splitting Parameters
+ # We randomly split clusters in two rounds to arrive at a Train, Validation, and Test set.
+ # Round 1) All clusters -> Train (final) and Other (temp). Round 2) Other (temp) clusters -> Val (final) and Test (final)
+ SPLIT = CustomParams(
+     FUSON_DB_PATH = '../data/fuson_db.csv',
+     CLUSTER_OUTPUT_PATH = '../data/clustering/mmseqs_full_results.csv',
+     RANDOM_STATE_1 = 2, # random_state_1 = state for splitting all data into train & other
+     TEST_SIZE_1 = 0.18, # test size for the data -> train/other split. e.g. 0.20 means 80% of clusters in train, 20% in other
+     RANDOM_STATE_2 = 6, # random_state_2 = state for splitting other from ^ into val and test
+     TEST_SIZE_2 = 0.44 # test size for the other -> val/test split. e.g. 0.44 means 56% of "other" clusters in val, 44% in test
+ )
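Taken together, the two-round `SPLIT` fractions imply the approximate overall proportions below (these are fractions of clusters, not of sequences, since clusters vary in size):

```python
TEST_SIZE_1, TEST_SIZE_2 = 0.18, 0.44

train = 1 - TEST_SIZE_1                  # round 1: 82% of clusters -> train
val   = TEST_SIZE_1 * (1 - TEST_SIZE_2)  # round 2: 56% of "other" -> val
test  = TEST_SIZE_1 * TEST_SIZE_2        # round 2: 44% of "other" -> test
print(f"train={train:.2%}, val={val:.2%}, test={test:.2%}")
# train=82.00%, val=10.08%, test=7.92%
```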
fuson_plm/data/data_cleaning_log.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3210df509a0c50e4ec07f9df396ff166abf85212342e203eb9d9dac1115eca71
+ size 10381
fuson_plm/data/fuson_db.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c5c841ff582a45ba2427ed504d965ba01ab49e6cb2f3dacf2e4e3cbedad255d3
+ size 37076062
fuson_plm/data/head_tail_data/ensembl_ht_idmap.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ae973ad3b86408e684efc0af63249af50cc8c6e1bce73465550dc0a9c2bc839
+ size 28978535
fuson_plm/data/head_tail_data/gene_to_ensembl_dict.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:46f49a7d49c00e80aa4426da6245558d7d36cf21525620d3f1d5339c1772df40
+ size 547183
fuson_plm/data/head_tail_data/genename_ht_idmap.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7de358687c013017ca30c67b613d861794430f8c8e040478710e56535801c92
+ size 54844814
fuson_plm/data/head_tail_data/htgenes_uniprotids.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f1ff59b6e41f58585f0ce588d2232bfe5605219b047f478c6ce74ecd2715a1d
+ size 889031
fuson_plm/data/head_tail_data/isoform_fasta_id_output_formatted.fasta ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3a43d06330e000c1f81d905fb3bba4e911f5530f6d34d1476f206bea4a420ddd
+ size 41148731
fuson_plm/data/head_tail_data/uniprot_idmap_inputs/head_tail_ens.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a85c21273e9cfdb0baa5e22f8a033413c68299d71ba7a8deecdf93d057d280ac
+ size 442383
fuson_plm/data/head_tail_data/uniprot_idmap_inputs/head_tail_genes.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b9ce1957875bdf05670d6c7412fa082896837714573197c6d92c5777cb24746
+ size 67736
fuson_plm/data/raw_data/FOdb_SD5.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b4fae3b8ab9661500ac34dc5e96069f76118f79ceac8e101d5361fe5e46d4b4
+ size 19345
fuson_plm/data/raw_data/FOdb_all.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:acd531078289d83d42cdd8c0031b682612d396d92eea2e1e8b1871044424fdb0
+ size 3876082
fuson_plm/data/raw_data/FOdb_puncta.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ece2df3daef17ef73676ec35074fe9534038160a82b8125411ab4f7fefed54b
+ size 237498
fuson_plm/data/raw_data/FusionPDB.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4040f181272d77ca5ab72fd85d00c6c36d7edd43613cebb44357f582eac7f3db
+ size 531417333
fuson_plm/data/raw_data/FusionPDB_cleaned.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:351ef6f4f93e40859b5cef19ab5ac0729c6eeda8ac732fbe4bed7b68e1c5c7d2
+ size 34245297
fuson_plm/data/split.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import pandas as pd
+ import os
+ import pickle
+ from fuson_plm.data.config import SPLIT
+ from fuson_plm.utils.logging import log_update, open_logfile
+ from fuson_plm.utils.splitting import split_clusters, check_split_validity
+ from fuson_plm.utils.visualizing import set_font, visualize_splits
+
+ def get_benchmark_data(fuson_db_path, clusters):
+     """
+     Identify the benchmark sequences and the clusters containing them, so these
+     clusters can be reserved for the test set.
+     """
+     # Read the fusion database
+     fuson_db = pd.read_csv(fuson_db_path)
+
+     # Get original benchmark sequences, and benchmark sequences that were clustered
+     original_benchmark_sequences = fuson_db.loc[fuson_db['benchmark'].notna()]
+     benchmark_sequences = fuson_db.loc[
+         (fuson_db['benchmark'].notna()) &                        # it's a benchmark sequence
+         (fuson_db['aa_seq'].isin(list(clusters['member seq'])))  # it was clustered (it's under the length limit specified for clustering)
+     ]['aa_seq'].to_list()
+
+     # Get the sequence IDs of all benchmark sequences
+     benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
+
+     # Use benchmark_seq_ids to find which clusters contain benchmark sequences
+     benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
+     log_update(f"\t{len(benchmark_sequences)}/{len(original_benchmark_sequences)} benchmarking sequences (only those shorter than config.CLUSTERING['max_seq_length']) were grouped into {len(benchmark_cluster_reps)} clusters. These will be reserved for the test set.")
+
+     return benchmark_cluster_reps, benchmark_sequences
+
+ def get_training_dfs(train, val, test):
+     log_update('\nMaking dataframes for ESM finetuning...')
+
+     # Drop the cluster-related columns we don't need
+     train = train.drop(columns=['representative seq_id', 'member seq_id', 'representative seq']).rename(columns={'member seq': 'sequence'})
+     val = val.drop(columns=['representative seq_id', 'member seq_id', 'representative seq']).rename(columns={'member seq': 'sequence'})
+     test = test.drop(columns=['representative seq_id', 'member seq_id', 'representative seq']).rename(columns={'member seq': 'sequence'})
+
+     return train, val, test
+
+ def main():
+     """
+     Split the clustered sequences into train, validation, and test sets, then
+     save both the cluster splits and the dataframes used for training.
+     """
+     # Read all the input files
+     LOG_PATH = "splitting_log.txt"
+     FUSON_DB_PATH = SPLIT.FUSON_DB_PATH
+     CLUSTER_OUTPUT_PATH = SPLIT.CLUSTER_OUTPUT_PATH
+     RANDOM_STATE_1 = SPLIT.RANDOM_STATE_1
+     TEST_SIZE_1 = SPLIT.TEST_SIZE_1
+     RANDOM_STATE_2 = SPLIT.RANDOM_STATE_2
+     TEST_SIZE_2 = SPLIT.TEST_SIZE_2
+
+     # Set font
+     set_font()
+
+     # Prepare the log file
+     with open_logfile(LOG_PATH):
+         log_update("Loaded data-splitting configurations from config.py")
+         SPLIT.print_config(indent='\t')
+
+         # Prepare directory to save results
+         os.makedirs("splits", exist_ok=True)
+
+         # Read the clusters and get a list of the representative IDs for splitting
+         clusters = pd.read_csv(CLUSTER_OUTPUT_PATH)
+         reps = clusters['representative seq_id'].unique().tolist()
+         log_update(f"\nPreparing clusters...\n\tCollected {len(reps)} clusters for splitting")
+
+         # Get the benchmark cluster representatives and sequences
+         benchmark_cluster_reps, benchmark_sequences = get_benchmark_data(FUSON_DB_PATH, clusters)
+
+         # Make the splits and extract the results
+         splits = split_clusters(reps, benchmark_cluster_reps=benchmark_cluster_reps,
+                                 random_state_1=RANDOM_STATE_1, random_state_2=RANDOM_STATE_2,
+                                 test_size_1=TEST_SIZE_1, test_size_2=TEST_SIZE_2)
+         X_train = splits['X_train']
+         X_val = splits['X_val']
+         X_test = splits['X_test']
+
+         # Make slices of the clusters dataframe for train, val, and test
+         train_clusters = clusters.loc[clusters['representative seq_id'].isin(X_train)].reset_index(drop=True)
+         val_clusters = clusters.loc[clusters['representative seq_id'].isin(X_val)].reset_index(drop=True)
+         test_clusters = clusters.loc[clusters['representative seq_id'].isin(X_test)].reset_index(drop=True)
+
+         # Check validity
+         check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=benchmark_sequences)
+
+         # Print min and max sequence lengths
+         min_train_seqlen = min(train_clusters['member seq'].str.len())
+         max_train_seqlen = max(train_clusters['member seq'].str.len())
+         min_val_seqlen = min(val_clusters['member seq'].str.len())
+         max_val_seqlen = max(val_clusters['member seq'].str.len())
+         min_test_seqlen = min(test_clusters['member seq'].str.len())
+         max_test_seqlen = max(test_clusters['member seq'].str.len())
+         log_update(f"\nLength breakdown summary...\n\tTrain: min seq length = {min_train_seqlen}, max seq length = {max_train_seqlen}")
+         log_update(f"\tVal: min seq length = {min_val_seqlen}, max seq length = {max_val_seqlen}")
+         log_update(f"\tTest: min seq length = {min_test_seqlen}, max seq length = {max_test_seqlen}")
+
+         # Make plots to visualize the splits
+         visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)
+
+         # cols = representative seq_id, member seq_id, representative seq, member seq
+         train_clusters.to_csv("../data/splits/train_cluster_split.csv", index=False)
+         val_clusters.to_csv("../data/splits/val_cluster_split.csv", index=False)
+         test_clusters.to_csv("../data/splits/test_cluster_split.csv", index=False)
+         log_update('\nSaved cluster splits to splits/train_cluster_split.csv, splits/val_cluster_split.csv, splits/test_cluster_split.csv')
+         cols = ','.join(list(train_clusters.columns))
+         log_update(f'\tColumns: {cols}')
+
+         # Make train_df, val_df, and test_df: the data that will be input to the training script
+         train_df, val_df, test_df = get_training_dfs(train_clusters, val_clusters, test_clusters)
+         train_df.to_csv("../data/splits/train_df.csv", index=False)
+         val_df.to_csv("../data/splits/val_df.csv", index=False)
+         test_df.to_csv("../data/splits/test_df.csv", index=False)
+         log_update('\nSaved training dataframes to splits/train_df.csv, splits/val_df.csv, splits/test_df.csv')
+         cols = ','.join(list(train_df.columns))
+         log_update(f'\tColumns: {cols}')
+
+ if __name__ == "__main__":
+     main()
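For intuition, here is a minimal sketch of the cluster-level splitting that `split.py` performs. The column names mirror those logged above; the sequences and the final assignment are hypothetical stand-ins for what `split_clusters` computes:

```python
import pandas as pd

# Toy stand-in for the mmseqs cluster table read from SPLIT.CLUSTER_OUTPUT_PATH
# (hypothetical sequences; the real table is clustering/mmseqs_full_results.csv).
clusters = pd.DataFrame({
    "representative seq_id": ["seq1", "seq1", "seq2", "seq3"],
    "member seq_id":         ["seq1", "seq4", "seq2", "seq3"],
    "representative seq":    ["MKVL", "MKVL", "MAAG", "MPRS"],
    "member seq":            ["MKVL", "MKVI", "MAAG", "MPRS"],
})

# Splits are made over cluster representatives, not individual sequences, so
# homologous members (seq1 and seq4 above) can never straddle a split boundary.
reps = clusters["representative seq_id"].unique().tolist()
train_reps, test_reps = reps[:2], reps[2:]  # stand-in for split_clusters' output
train_clusters = clusters.loc[clusters["representative seq_id"].isin(train_reps)]
test_clusters = clusters.loc[clusters["representative seq_id"].isin(test_reps)]
```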
fuson_plm/data/split_vis.py ADDED
@@ -0,0 +1,333 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from scipy.stats import entropy
+ from sklearn.manifold import TSNE
+ import pickle
+ import pandas as pd
+ import os
+ from fuson_plm.utils.logging import log_update
+ from fuson_plm.utils.visualizing import set_font
+
+ def calculate_aa_composition(sequences):
+     composition = {}
+     total_length = sum([len(seq) for seq in sequences])
+
+     for seq in sequences:
+         for aa in seq:
+             if aa in composition:
+                 composition[aa] += 1
+             else:
+                 composition[aa] = 1
+
+     # Convert counts to relative frequencies
+     for aa in composition:
+         composition[aa] /= total_length
+
+     return composition
+
+ def calculate_shannon_entropy(sequence):
+     """
+     Calculate the Shannon entropy for a given sequence.
+
+     Args:
+         sequence (str): A sequence of characters (e.g., amino acids or nucleotides).
+
+     Returns:
+         float: Shannon entropy value.
+     """
+     bases = set(sequence)
+     counts = [sequence.count(base) for base in bases]
+     return entropy(counts, base=2)
+
+ def visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap, savepath='../data/splits/length_distributions.png', axes=None):
+     log_update('\nMaking histogram of length distributions')
+     # Create a figure and axes with 1 row and 3 columns
+     if axes is None:
+         fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     # Unpack the axis labels
+     xlabel, ylabel = ['Sequence Length (AA)', 'Frequency']
+
+     # Plot the first histogram
+     axes[0].hist(train_lengths, bins=20, edgecolor='k', color=colormap['train'])
+     axes[0].set_xlabel(xlabel, fontsize=24)
+     axes[0].set_ylabel(ylabel, fontsize=24)
+     axes[0].set_title(f'Train Set Length Distribution (n={len(train_lengths)})', fontsize=24)
+     axes[0].grid(True)
+     axes[0].set_axisbelow(True)
+     axes[0].tick_params(axis='x', labelsize=24)  # Customize x-axis tick label size
+     axes[0].tick_params(axis='y', labelsize=24)  # Customize y-axis tick label size
+
+     # Plot the second histogram
+     axes[1].hist(val_lengths, bins=20, edgecolor='k', color=colormap['val'])
+     axes[1].set_xlabel(xlabel, fontsize=24)
+     axes[1].set_ylabel(ylabel, fontsize=24)
+     axes[1].set_title(f'Validation Set Length Distribution (n={len(val_lengths)})', fontsize=24)
+     axes[1].grid(True)
+     axes[1].set_axisbelow(True)
+     axes[1].tick_params(axis='x', labelsize=24)
+     axes[1].tick_params(axis='y', labelsize=24)
+
+     # Plot the third histogram
+     axes[2].hist(test_lengths, bins=20, edgecolor='k', color=colormap['test'])
+     axes[2].set_xlabel(xlabel, fontsize=24)
+     axes[2].set_ylabel(ylabel, fontsize=24)
+     axes[2].set_title(f'Test Set Length Distribution (n={len(test_lengths)})', fontsize=24)
+     axes[2].grid(True)
+     axes[2].set_axisbelow(True)
+     axes[2].tick_params(axis='x', labelsize=24)
+     axes[2].tick_params(axis='y', labelsize=24)
+
+     # Adjust layout and save the figure
+     if savepath is not None:
+         plt.tight_layout()
+         plt.savefig(savepath)
+
+ def visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap, savepath='../data/splits/scatterplot.png', ax=None):
+     log_update("\nMaking scatterplot with distribution of cluster sizes across train, test, and val")
+     # Make grouped versions of these DataFrames for size analysis
+     train_clustersgb = train_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+     val_clustersgb = val_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+     test_clustersgb = test_clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
+
+     # Isolate benchmark-containing clusters so their contribution can be plotted separately
+     total_test_proteins = sum(test_clustersgb['member count'])
+     test_clustersgb['benchmark cluster'] = test_clustersgb['representative seq_id'].isin(benchmark_cluster_reps)
+     benchmark_clustersgb = test_clustersgb.loc[test_clustersgb['benchmark cluster']].reset_index(drop=True)
+     test_clustersgb = test_clustersgb.loc[~test_clustersgb['benchmark cluster']].reset_index(drop=True)
+
+     # Convert them to value counts
+     train_clustersgb = train_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
+     val_clustersgb = val_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
+     test_clustersgb = test_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
+     benchmark_clustersgb = benchmark_clustersgb['member count'].value_counts().reset_index().rename(columns={'index': 'cluster size (n_members)', 'member count': 'n_clusters'})
+
+     # Get the percentage of each dataset that's made of each cluster size
+     train_clustersgb['n_proteins'] = train_clustersgb['cluster size (n_members)'] * train_clustersgb['n_clusters']  # proteins per cluster * n clusters = n proteins
+     train_clustersgb['percent_proteins'] = train_clustersgb['n_proteins'] / sum(train_clustersgb['n_proteins'])
+     val_clustersgb['n_proteins'] = val_clustersgb['cluster size (n_members)'] * val_clustersgb['n_clusters']
+     val_clustersgb['percent_proteins'] = val_clustersgb['n_proteins'] / sum(val_clustersgb['n_proteins'])
+     test_clustersgb['n_proteins'] = test_clustersgb['cluster size (n_members)'] * test_clustersgb['n_clusters']
+     test_clustersgb['percent_proteins'] = test_clustersgb['n_proteins'] / total_test_proteins
+     benchmark_clustersgb['n_proteins'] = benchmark_clustersgb['cluster size (n_members)'] * benchmark_clustersgb['n_clusters']
+     benchmark_clustersgb['percent_proteins'] = benchmark_clustersgb['n_proteins'] / total_test_proteins
+
+     # Specially mark the benchmark clusters because these can't be reallocated
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(18, 6))
+
+     ax.plot(train_clustersgb['cluster size (n_members)'], train_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['train'], label='train')
+     ax.plot(val_clustersgb['cluster size (n_members)'], val_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['val'], label='val')
+     ax.plot(test_clustersgb['cluster size (n_members)'], test_clustersgb['percent_proteins'], linestyle='None', marker='.', color=colormap['test'], label='test')
+     ax.plot(benchmark_clustersgb['cluster size (n_members)'], benchmark_clustersgb['percent_proteins'],
+             marker='o',
+             linestyle='None',
+             markerfacecolor=colormap['test'],  # fill same as test
+             markeredgecolor='black',           # outline black
+             markeredgewidth=1.5,
+             label='benchmark'
+             )
+     ax.set_ylabel('Percentage of Proteins in Dataset', fontsize=24)
+     ax.set_xlabel('Cluster Size', fontsize=24)
+     ax.tick_params(axis='x', labelsize=24)  # Customize x-axis tick label size
+     ax.tick_params(axis='y', labelsize=24)  # Customize y-axis tick label size
+
+     ax.legend(fontsize=24, markerscale=4)
+
+     # Save the figure
+     if savepath is not None:
+         plt.tight_layout()
+         plt.savefig(savepath)
+         log_update(f"\tSaved figure to {savepath}")
+
+ def get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences, embedding_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl'):
+     embeddings = {}
+
+     try:
+         with open(embedding_path, 'rb') as f:
+             embeddings = pickle.load(f)
+
+         train_embeddings = [v for k, v in embeddings.items() if k in train_sequences]
+         val_embeddings = [v for k, v in embeddings.items() if k in val_sequences]
+         test_embeddings = [v for k, v in embeddings.items() if k in test_sequences]
+
+         return train_embeddings, val_embeddings, test_embeddings
+     except Exception as e:
+         log_update(f"Could not open embeddings at {embedding_path}: {e}")
+         raise
+
+ def visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, esm_type="esm2_t33_650M_UR50D", embedding_path="fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl", savepath='../data/splits/tsne_plot.png', ax=None):
+     """
+     Generate a t-SNE plot of embeddings for the train, validation, and test sets.
+     """
+     log_update('\nMaking t-SNE plot of train, val, and test embeddings')
+     # Combine the embeddings into one array
+     train_embeddings, val_embeddings, test_embeddings = get_avg_embeddings_for_tsne(train_sequences, val_sequences, test_sequences, embedding_path=embedding_path)
+     embeddings = np.concatenate([train_embeddings, val_embeddings, test_embeddings])
+
+     # Labels for the embeddings
+     labels = ['train'] * len(train_embeddings) + ['val'] * len(val_embeddings) + ['test'] * len(test_embeddings)
+
+     # Perform t-SNE
+     tsne = TSNE(n_components=2, random_state=42)
+     tsne_results = tsne.fit_transform(embeddings)
+
+     # Convert t-SNE results into a DataFrame
+     tsne_df = pd.DataFrame(data=tsne_results, columns=['TSNE_1', 'TSNE_2'])
+     tsne_df['label'] = labels
+
+     # Plotting
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(10, 8))
+     else:
+         fig = ax.get_figure()
+
+     # Scatter plot for each set
+     for label, color in colormap.items():
+         subset = tsne_df[tsne_df['label'] == label].reset_index(drop=True)
+         ax.scatter(subset['TSNE_1'], subset['TSNE_2'], c=color, label=label.capitalize(), alpha=0.6)
+
+     ax.set_title(f't-SNE of {esm_type} Embeddings')
+     ax.set_xlabel('t-SNE Dimension 1')
+     ax.set_ylabel('t-SNE Dimension 2')
+     ax.legend(fontsize=24, markerscale=2)
+     ax.grid(True)
+
+     # Save the figure if savepath is provided
+     if savepath:
+         plt.tight_layout()
+         fig.savefig(savepath)
+
+ def visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap, savepath='../data/splits/shannon_entropy_plot.png', axes=None):
+     """
+     Generate Shannon entropy plots for the train, validation, and test sets.
+     """
+     log_update('\nMaking histogram of Shannon entropy distributions')
+     train_entropy = [calculate_shannon_entropy(seq) for seq in train_sequences]
+     val_entropy = [calculate_shannon_entropy(seq) for seq in val_sequences]
+     test_entropy = [calculate_shannon_entropy(seq) for seq in test_sequences]
+
+     if axes is None:
+         fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+
+     axes[0].hist(train_entropy, bins=20, edgecolor='k', color=colormap['train'])
+     axes[0].set_title(f'Train Set (n={len(train_entropy)})', fontsize=24)
+     axes[0].set_xlabel('Shannon Entropy', fontsize=24)
+     axes[0].set_ylabel('Frequency', fontsize=24)
+     axes[0].grid(True)
+     axes[0].set_axisbelow(True)
+     axes[0].tick_params(axis='x', labelsize=24)
+     axes[0].tick_params(axis='y', labelsize=24)
+
+     axes[1].hist(val_entropy, bins=20, edgecolor='k', color=colormap['val'])
+     axes[1].set_title(f'Validation Set (n={len(val_entropy)})', fontsize=24)
+     axes[1].set_xlabel('Shannon Entropy', fontsize=24)
+     axes[1].grid(True)
+     axes[1].set_axisbelow(True)
+     axes[1].tick_params(axis='x', labelsize=24)
+     axes[1].tick_params(axis='y', labelsize=24)
+
+     axes[2].hist(test_entropy, bins=20, edgecolor='k', color=colormap['test'])
+     axes[2].set_title(f'Test Set (n={len(test_entropy)})', fontsize=24)
+     axes[2].set_xlabel('Shannon Entropy', fontsize=24)
+     axes[2].grid(True)
+     axes[2].set_axisbelow(True)
+     axes[2].tick_params(axis='x', labelsize=24)
+     axes[2].tick_params(axis='y', labelsize=24)
+
+     if savepath is not None:
+         plt.tight_layout()
+         plt.savefig(savepath)
+
+ def visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap, savepath='../data/splits/aa_comp.png', ax=None):
+     log_update('\nMaking bar plot of AA composition across each set')
+     train_comp = calculate_aa_composition(train_sequences)
+     val_comp = calculate_aa_composition(val_sequences)
+     test_comp = calculate_aa_composition(test_sequences)
+
+     # Create DataFrame
+     comp_df = pd.DataFrame([train_comp, val_comp, test_comp], index=['train', 'val', 'test']).T
+     colors = [colormap[col] for col in comp_df.columns]
+
+     # Plotting
+     if ax is None:
+         fig, ax = plt.subplots(figsize=(12, 6))
+     else:
+         fig = ax.get_figure()
+
+     comp_df.plot(kind='bar', color=colors, ax=ax)
+     ax.set_title('Amino Acid Composition Across Datasets', fontsize=24)
+     ax.set_xlabel('Amino Acid', fontsize=24)
+     ax.set_ylabel('Relative Frequency', fontsize=24)
+     ax.tick_params(axis='x', labelsize=24)  # Customize x-axis tick label size
+     ax.tick_params(axis='y', labelsize=24)  # Customize y-axis tick label size
+     ax.legend(fontsize=16, markerscale=2)
+
+     if savepath is not None:
+         fig.savefig(savepath)
+
+ def visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, train_color='#0072B2', val_color='#009E73', test_color='#E69F00', esm_embeddings_path=None, onehot_embeddings_path=None):
+     colormap = {
+         'train': train_color,
+         'val': val_color,
+         'test': test_color
+     }
+     # Add columns for plotting
+     train_clusters['member length'] = train_clusters['member seq'].str.len()
+     val_clusters['member length'] = val_clusters['member seq'].str.len()
+     test_clusters['member length'] = test_clusters['member seq'].str.len()
+
+     # Prepare lengths and seqs for plotting
+     train_lengths = train_clusters['member length'].tolist()
+     val_lengths = val_clusters['member length'].tolist()
+     test_lengths = test_clusters['member length'].tolist()
+     train_sequences = train_clusters['member seq'].tolist()
+     val_sequences = val_clusters['member seq'].tolist()
+     test_sequences = test_clusters['member seq'].tolist()
+
+     # Create a combined figure with 3 rows and 3 columns
+     fig_combined, axs = plt.subplots(3, 3, figsize=(24, 18))
+
+     # Make the visualization plots for saving TOGETHER
+     visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap, savepath=None, axes=axs[0])
+     visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap, savepath=None, axes=axs[1])
+     visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap, savepath=None, ax=axs[2, 0])
+     visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap, savepath=None, ax=axs[2, 1])
+     if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
+         visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap, savepath=None, ax=axs[2, 2])
+     else:
+         # Leave the last subplot blank
+         axs[2, 2].axis('off')
+
+     plt.tight_layout()
+     fig_combined.savefig('../data/splits/combined_plot.png')
+
+     # Make the visualization plots again for saving separately
+     visualize_splits_hist(train_lengths, val_lengths, test_lengths, colormap)
+     visualize_splits_scatter(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps, colormap)
+     visualize_splits_aa_composition(train_sequences, val_sequences, test_sequences, colormap)
+     visualize_splits_shannon_entropy(train_sequences, val_sequences, test_sequences, colormap)
+     if esm_embeddings_path is not None and os.path.exists(esm_embeddings_path):
+         visualize_splits_tsne(train_sequences, val_sequences, test_sequences, colormap)
+
+ def main():
+     set_font()
+     train_clusters = pd.read_csv('splits/train_cluster_split.csv')
+     val_clusters = pd.read_csv('splits/val_cluster_split.csv')
+     test_clusters = pd.read_csv('splits/test_cluster_split.csv')
+
+     clusters = pd.concat([train_clusters, val_clusters, test_clusters])
+
+     fuson_db = pd.read_csv('fuson_db.csv')
+     # Get the sequence IDs of all benchmark sequences
+     benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
+     # Use benchmark_seq_ids to find which clusters contain benchmark sequences
+     benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
+
+     visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps,
+                      esm_embeddings_path='fuson_db_embeddings/fuson_db_esm2_t33_650M_UR50D_avg_embeddings.pkl', onehot_embeddings_path=None)
+
+ if __name__ == "__main__":
+     main()
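As a quick check on `calculate_shannon_entropy` above: `scipy.stats.entropy` normalizes the raw character counts to probabilities before computing the entropy, so the function matches the hand calculation:

```python
from scipy.stats import entropy

# For the toy sequence "AAB", counts over {A, B} are [2, 1].
# entropy normalizes to p = [2/3, 1/3], giving
# H = -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918 bits
print(entropy([2, 1], base=2))  # ~0.918
```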
fuson_plm/data/splits/combined_plot.png ADDED
fuson_plm/data/splits/test_cluster_split.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ed561338b86264b420b12e37e92c8c434b3119081a9a02a7688c1934343ee5fb
+ size 5628545
fuson_plm/data/splits/test_df.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9590919f2c09430dc6f8b617b8a738f7e174fb04fdae1f35ceaa0351ea05612f
+ size 32236663
fuson_plm/data/splits/train_cluster_split.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1be258917ebd87f7cf71223e6fa340e7b6228464c42ae0b630c24efea8d2bd14
+ size 44850849
fuson_plm/data/splits/train_df.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ddf563734014c7d4fec944c639431d3423d2ab79e1e6e9e800c955c24438c8eb
+ size 257270565
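Note that the split CSVs above are committed as Git LFS pointer files (the `version`/`oid`/`size` stanzas shown), not the data itself. A small sketch for guarding against accidentally parsing a pointer as CSV (path taken from the listing above):

```python
import pandas as pd

path = "fuson_plm/data/splits/train_df.csv"
with open(path) as f:
    first_line = f.readline()

# LFS pointers begin with this version line; the real CSV begins with a header row.
if first_line.startswith("version https://git-lfs.github.com/spec/v1"):
    raise RuntimeError(f"{path} is a Git LFS pointer; run `git lfs pull` to fetch the data")

train_df = pd.read_csv(path)
```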