File size: 4,387 Bytes
dcc5cd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import argparse
import os
from scipy.stats import gaussian_kde
import numpy as np

def get_model_for(doc_type: str, override_model: str) -> str:
    """Pick the perplexity-model key to read for a document type.

    A truthy ``override_model`` always wins and is returned unchanged.
    Otherwise only the part of ``doc_type`` before the first underscore is
    considered, and it is mapped onto one of the known model keys. Prefixes
    that are not recognised fall back to ``"wikipedia_pp"``.
    """
    if override_model:
        return override_model

    prefix = doc_type.split("_", 1)[0]
    # (accepted prefixes, model key) pairs, checked in order.
    prefix_groups = (
        (("book", "books", "pg19"), "books_pp"),
        (("culturax", "slimpajama", "wikipedia", "digimanus"), "wikipedia_pp"),
        (("newspaper", "newspapers"), "newspapers_pp"),
        (("evalueringsrapport", "lovdata", "maalfrid", "parlamint"), "maalfrid_pp"),
    )
    for accepted, model_key in prefix_groups:
        if prefix in accepted:
            return model_key
    return "wikipedia_pp"

def load_data(files):
    """Read one or more JSON Lines files into a single DataFrame.

    Every non-blank line of every file is parsed with ``json.loads`` and
    becomes one row. Rows from all files are concatenated in the order given.

    Args:
        files: Iterable of paths to JSON Lines files.

    Returns:
        A ``pandas.DataFrame`` built from the parsed records (empty when no
        records were found).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    for file_path in files:
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as handle:
            # Stream line by line instead of readlines(). Skip blank or
            # whitespace-only lines, which json.loads would reject.
            records.extend(json.loads(line) for line in handle if line.strip())
    return pd.DataFrame(records)

def _draw_quartile_markers(ax, values, color, bw_adjust):
    """Draw dashed Q1/Q2/Q3 lines up to the KDE curve, each with a value label.

    The KDE is fitted the same way seaborn's ``kdeplot`` fits it: scipy's
    default (Scott) bandwidth factor scaled by ``bw_adjust``. Each line
    therefore ends on the plotted curve. Non-finite values are ignored. When
    no KDE can be fitted (fewer than two points, or zero variance), no markers
    are drawn; seaborn likewise skips the curve in that case.
    """
    values = values[np.isfinite(values)]
    if values.size == 0:
        return
    quartiles = np.quantile(values, [0.25, 0.5, 0.75])
    try:
        kde = gaussian_kde(values)
    except (ValueError, np.linalg.LinAlgError):
        return
    kde.set_bandwidth(kde.factor * bw_adjust)
    # Evaluate the density exactly at each quartile. This is correct even when
    # a quartile lies beyond the visible x-range.
    heights = kde.evaluate(quartiles)
    for label, quartile, height in zip(("Q1", "Q2", "Q3"), quartiles, heights):
        ax.plot([quartile, quartile], [0, height], color=color, linestyle='--', linewidth=1)
        ax.text(quartile, height, f'{label}: {quartile:.2f}', verticalalignment='bottom', horizontalalignment='right', color=color, fontsize=6)


def plot_histograms(files, output_folder, xlim, override_model):
    """Plot per-document-type perplexity distributions into a single PNG.

    Records are loaded with ``load_data``. One subplot is drawn per distinct
    ``doctype``. Within a subplot, each distinct ``lang`` gets its own colour,
    drawn as a step histogram, a KDE curve and dashed Q1/Q2/Q3 markers. The
    plotted value is ``record['perplexities'][model]``, where ``model`` comes
    from ``get_model_for(doctype, override_model)``.

    Args:
        files: Paths to JSON Lines files. Each record needs ``doctype``,
            ``lang`` and a ``perplexities`` mapping.
        output_folder: Existing directory that receives
            ``all_doc_types_plots.png``.
        xlim: Upper x-axis limit; the lower limit is 0.
        override_model: If truthy, the perplexity key used for every doc type.

    Raises:
        ValueError: If the input files contain no records.
    """
    df = load_data(files)
    if df.empty:
        raise ValueError("No records found in the input files; nothing to plot.")

    # Must match between sns.kdeplot and the quartile-marker KDE.
    bw_adjust = 2
    doc_types = df['doctype'].unique()
    fig, axes = plt.subplots(len(doc_types), 1, figsize=(12, 4 * len(doc_types)), squeeze=False)

    for ax, doc_type in zip(axes[:, 0], doc_types):
        group = df[df['doctype'] == doc_type]
        # The model depends only on the doc type, so resolve it once per subplot.
        perplexity_model = get_model_for(doc_type, override_model)
        languages = group['lang'].unique()
        # One distinct colour per language within this document type.
        colors = sns.color_palette("husl", len(languages))

        for lang, series_color in zip(languages, colors):
            lang_group = group[group['lang'] == lang]
            perplexity_values = np.asarray(
                [scores[perplexity_model] for scores in lang_group['perplexities']],
                dtype=float,
            )

            # Light, filled step histogram.
            sns.histplot(perplexity_values, ax=ax, color=series_color, alpha=0.3, element="step", fill=True, stat="density", binwidth=30)
            # Unfilled KDE curve; this is what carries the legend label.
            sns.kdeplot(perplexity_values, ax=ax, bw_adjust=bw_adjust, color=series_color, label=f"{lang} - {doc_type} ({perplexity_model})", linewidth=1.5)
            _draw_quartile_markers(ax, perplexity_values, series_color, bw_adjust)

        ax.set_title(f'Document Type: {doc_type} ({perplexity_model})')
        ax.set_xlabel('Perplexity Value')
        ax.set_ylabel('Density')
        ax.legend()
        ax.set_xlim(left=0, right=xlim)

    fig.tight_layout()
    output_filename = os.path.join(output_folder, "all_doc_types_plots.png")
    fig.savefig(output_filename, dpi=300)
    plt.close(fig)
    print(f"All document type plots saved to {output_filename}")

def main():
    """Parse command-line arguments and write the combined histogram figure.

    Positional arguments are JSON Lines files. ``-o/--output_folder`` (default
    ``"."``) is created if missing. ``--xlim`` (default 2500) sets the upper
    x-axis limit, and ``--model`` overrides the per-doc-type perplexity model.
    """
    parser = argparse.ArgumentParser(description="Plot histograms from JSON lines files.")
    parser.add_argument('files', nargs='+', help="Path to the JSON lines files")
    parser.add_argument('-o', '--output_folder', default=".", help="Output folder for the plots")
    parser.add_argument('--xlim', type=int, default=2500, help="Maximum x-axis limit for the plots")
    parser.add_argument('--model', default="", help="Override the perplexity model for all plots")

    args = parser.parse_args()

    # exist_ok=True already tolerates an existing directory. A separate
    # os.path.exists() check would be redundant and racy.
    os.makedirs(args.output_folder, exist_ok=True)

    plot_histograms(args.files, args.output_folder, args.xlim, args.model)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()