Fill-Mask
Transformers
Safetensors
esm
File size: 3,226 Bytes
1e6a1f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6efd653
1e6a1f0
 
 
6efd653
 
 
 
 
 
1e6a1f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from fuson_plm.utils.visualizing import set_font

# Maps the alignment-count column names to the short labels used on plot axes.
# Assigned directly: a `global` statement at module scope has no effect.
pos_id_label_dict = {
    'top_UniProt_nIdentities': 'Identities',
    'top_UniProt_nPositives': 'Positives',
}

def plot_pos_or_id_pcnt_hist(data, column_name, save_path=None, ax=None):
    """
    Plot a histogram of per-sequence percent coverage for one count column.

    Percent coverage is ``100 * data[column_name] / data['aa_seq_len']``,
    computed over the rows that have no missing value in 'seq_id',
    'aa_seq_len' or ``column_name``. A dashed vertical line marks the mean
    and a solid one marks the median; both are labelled on the plot.

    Args:
        data: DataFrame with the columns 'seq_id', 'aa_seq_len' and
            ``column_name``. It is not modified.
        column_name: Name of the count column to plot, e.g.
            'top_UniProt_nIdentities' or 'top_UniProt_nPositives'. The x-axis
            label is looked up in ``pos_id_label_dict``; a column without an
            entry there is labelled with its own name.
        save_path: Optional path for the figure. When given, the figure is
            saved there at 300 dpi and the plotted values are written next to
            it as ``<save_path without extension>_source_data.csv``. When
            None, nothing is written to disk.
        ax: Optional matplotlib Axes to draw on. When None, a new 10x7 figure
            is created.

    Returns:
        None.
    """
    import os  # function-scope import: only needed to derive the CSV path

    set_font()

    # Remember whether the figure is ours: `ax` is rebound below, so testing
    # `ax is None` later on would always be False.
    created_fig = ax is None
    if created_fig:
        fig, ax = plt.subplots(figsize=(10, 7))
    else:
        fig = ax.figure

    coverage_col = f"{column_name} Percent Coverage"

    # Only keep those with alignments. Copy so that adding a column does not
    # write into a view of the caller's DataFrame.
    data = data[['seq_id', 'aa_seq_len', column_name]].dropna().copy()
    data[coverage_col] = (data[column_name] * 100) / data['aa_seq_len']  # so it's %

    # Save the plotted values, but only when there is somewhere to save them.
    if save_path is not None:
        save_path = os.fspath(save_path)
        # Swap the extension (whatever it is) instead of replacing the ".png"
        # substring, so the CSV can never end up at the figure's own path.
        path_root, _ = os.path.splitext(save_path)
        source_data_save_path = f"{path_root}_source_data.csv"
        source_data = data[['seq_id', coverage_col]].sort_values(by=coverage_col, ascending=True)
        source_data[coverage_col] = source_data[coverage_col].round(3)
        source_data.to_csv(source_data_save_path, index=False)

    # Calculate the mean and median of the percent coverage
    mean_coverage = data[coverage_col].mean()
    median_coverage = data[coverage_col].median()

    # Plot histogram for percent coverage
    ax.hist(data[coverage_col], bins=50, edgecolor='grey', alpha=0.8, color='mediumpurple')

    # Dashed line for the mean, solid line for the median
    ax.axvline(mean_coverage, color='black', linestyle='--', linewidth=2)
    ax.axvline(median_coverage, color='black', linestyle='-', linewidth=2)

    # Text labels for both lines, placed at 90% and 80% of the current y-limit
    ax.text(mean_coverage, ax.get_ylim()[1] * 0.9, f'Mean: {mean_coverage:.1f}%', color='black',
            ha='center', va='top', fontsize=40, backgroundcolor='white')
    ax.text(median_coverage, ax.get_ylim()[1] * 0.8, f'Median: {median_coverage:.1f}%', color='black',
            ha='center', va='top', fontsize=40, backgroundcolor='white')

    # Labels and ticks are set on `ax` itself rather than through pyplot, so a
    # caller-supplied Axes is styled even when it is not the current one.
    axis_label = pos_id_label_dict.get(column_name, column_name)
    ax.tick_params(axis='both', labelsize=24)
    ax.set_xlabel(f"Max % {axis_label}", fontsize=40)
    ax.set_ylabel("Count", fontsize=40)
    #ax.set_title(f"{pos_id_label_dict[column_name]} Percent Coverage (n={len(data):,})", fontsize=40)

    fig.tight_layout()

    if save_path is not None:
        # Save the figure that owns `ax`, not whichever figure is current.
        fig.savefig(save_path, dpi=300)
    elif created_fig:
        # The figure was created here and not saved, so showing it is the
        # only way the caller gets to see it.
        plt.show()
        
def group_pos_id_plot(data):
    """
    Apply the project font, then draw the identities histogram for ``data``.

    The figure is saved to "figures/identities_hist.png". No Axes is passed
    on, so the histogram function creates its own figure.
    """
    set_font()

    identities_column = 'top_UniProt_nIdentities'
    figure_path = "figures/identities_hist.png"
    plot_pos_or_id_pcnt_hist(data, identities_column, save_path=figure_path, ax=None)
    
def main():
    """
    Read the SwissProt top-alignments CSV and plot its identities histogram.

    Input is "blast_outputs/swissprot_top_alignments.csv"; the figure is
    saved to "figures/identities_hist.png".
    """
    alignments_csv = "blast_outputs/swissprot_top_alignments.csv"
    top_alignments = pd.read_csv(alignments_csv)

    plot_pos_or_id_pcnt_hist(
        top_alignments,
        'top_UniProt_nIdentities',
        save_path="figures/identities_hist.png",
        ax=None,
    )


# Only run the plotting when this file is executed as a script.
if __name__ == "__main__":
    main()