File size: 6,999 Bytes
77ba698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_setting(name):\n",
    "    if \"terminal-punct\" in name:\n",
    "        return {\"x\": \"Fraction of lines ended with punctuation\", \"ylim\": (0, 0.1)}\n",
    "    \n",
    "    if \"line-dedup\" in name:\n",
    "        return {\"x\": \"Fraction of chars in duplicated lines\", \"xlim\": (0, 0.1), \"ylim\": (0,0.02)}\n",
    "    \n",
    "    if \"short-line\" in name:\n",
    "        return {\"x\": \"Fraction of lines shorter than 30 chars\", \"xlim\": (0.4, 1.0), \"ylim\": (0,0.05)}\n",
    "    \n",
    "    if \"avg_words_per_line\" in name:\n",
    "        return {\"x\": \"Avg. words per line\", \"x-log\": True, \"x-log\": True, \"round\": 0}\n",
    "    if \"avg_line_length\" in name:\n",
    "        return {\"x\": \"Avg. words per line\", \"x-log\": True, \"round\": 0}\n",
    "    \n",
    "    if \"global-length.json\" == name:\n",
    "        return {\"x\": \"Num. UTF-8 chars\", \"x-log\": True}\n",
    "    \n",
    "    if \"global-digit_ratio.json\" == name:\n",
    "        return {\"x\": \"Digit ratio\", \"xlim\": (0, 0.25)}\n",
    "    \n",
    "    if \"global-avg_word_length.json\" == name:\n",
    "        return {\"x\": \"Avg. word length\", \"xlim\": (2.5, 6.5)}\n",
    "\n",
    "    \n",
    "    raise ValueError(f\"Unknown dataset name: {name}\")\n",
    "\n",
    "\n",
    "def plot_scatter(data):\n",
    "    \"\"\"\n",
    "    Plot scatter plots with smoothing for each dataset in the data list on a single grid.\n",
    "    Each dataset is expected to be a dictionary with the first key as the dataset name,\n",
    "    and the value as another dictionary where keys are data points and values are their counts.\n",
    "    \"\"\"\n",
    "    import matplotlib.pyplot as plt\n",
    "    import numpy as np\n",
    "\n",
    "    # Determine the number of plots and create a subplot grid\n",
    "    num_datasets = len(data)\n",
    "    cols = 2  # Define number of columns in the grid\n",
    "    rows = (num_datasets) // cols  # Calculate the required number of rows\n",
    "    fig, axs = plt.subplots(rows, cols, figsize=(8 * cols, 3 * rows), dpi=350)\n",
    "    if rows * cols > 1:\n",
    "        axs = axs.flatten()  # Flatten the array of axes if more than one subplot\n",
    "    else:\n",
    "        axs = [axs]  # Encapsulate the single AxesSubplot object into a list for uniform handling\n",
    "\n",
    "    plot_index = 0\n",
    "    legend_handles = []  # List to store handles for the legend\n",
    "    legend_labels = []  # List to store labels for the legend\n",
    "    for name, dataset in data.items():\n",
    "        setting = get_setting(name)\n",
    "        ax = axs[plot_index]\n",
    "        if \"name\" in setting:\n",
    "            ax.set_title(setting[\"name\"])\n",
    "        if \"x\" in setting:\n",
    "            ax.set_xlabel(setting[\"x\"])\n",
    "        if \"xlim\" in setting:\n",
    "            ax.set_xlim(setting[\"xlim\"])\n",
    "        if \"ylim\" in setting:\n",
    "            ax.set_ylim(setting[\"ylim\"])\n",
    "        if \"x-log\" in setting:\n",
    "            ax.set_xscale('log')\n",
    "\n",
    "        # Use 2 decimal places for the y-axis labels\n",
    "        ax.yaxis.set_major_formatter('{x:.3f}')\n",
    "\n",
    "\n",
    "        plot_index += 1\n",
    "        # Each dataset may contain multiple lines\n",
    "        for i, (line_name, line_data) in enumerate(dataset.items()):\n",
    "            if \"round\" in setting:\n",
    "                tmp_line_data = defaultdict(list)\n",
    "                for p, p_v in line_data.items():\n",
    "                    rounded_key = str(round(float(p), setting[\"round\"]))\n",
    "                    tmp_line_data[rounded_key].append(p_v)\n",
    "\n",
    "                # If you want to sum the values that have the same rounded key\n",
    "                tmp_line_data = {k: sum(v) for k, v in tmp_line_data.items()}\n",
    "                line_data = tmp_line_data\n",
    "            \n",
    "            # Check that if you sum the values you get 1\n",
    "            assert sum(line_data.values()) == 1\n",
    "\n",
    "            # Add smoothing for 4-5 points\n",
    "            # Implementing smoothing using a rolling window\n",
    "            line_name = rename_dataset(line_name)\n",
    "            # Sorting the line data by keys\n",
    "            sorted_line_data = dict(sorted(line_data.items(), key=lambda item: float(item[0])))\n",
    "\n",
    "            window_size = setting.get(\"window_size\", 5)  # Define the window size for smoothing\n",
    "            x = np.array(list(sorted_line_data.keys()), dtype=float)\n",
    "            y = np.array(list(sorted_line_data.values()), dtype=float)\n",
    "            if len(y) >= window_size:  # Ensure there are enough points to apply smoothing\n",
    "                # Convert y to a pandas Series to use rolling function\n",
    "                y_series = pd.Series(y)\n",
    "                # Apply rolling window and mean to smooth the data\n",
    "                y_smoothed = y_series.rolling(window=window_size).mean()\n",
    "                # Drop NaN values that result from the rolling mean calculation\n",
    "                y_smoothed = y_smoothed.dropna()\n",
    "                # Update x to correspond to the length of the smoothed y\n",
    "                x = x[len(x) - len(y_smoothed):]\n",
    "                y = y_smoothed.to_numpy()  # Convert back to numpy array for plotting\n",
    "\n",
    "\n",
    "\n",
    "            # Use the line name as the label to unify same line names across different plots\n",
    "\n",
    "            line, = ax.plot(x, y, label=line_name)  # Use default colors\n",
    "            if line_name not in legend_labels:\n",
    "                legend_handles.append(line)\n",
    "                legend_labels.append(line_name)\n",
    "\n",
    "    # Place a single shared legend on the top of the figure\n",
    "    fig.legend(handles=legend_handles, labels=legend_labels, loc='lower center', ncol=1)\n",
    "    for ax in axs:\n",
    "        ax.set_ylabel('Document Frequency')\n",
    "\n",
    "    fig.suptitle(\"Histograms of selected statistics\")\n",
    "    plt.tight_layout(rect=[0, 0.15, 1, 1])  # Adjust the layout to make room for the legend\n",
    "    fig.set_size_inches(13, 6)  # Set the figure size to 18 inches by 12 inches\n",
    "    plt.show()\n",
    "\n",
    "plot_scatter(data)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}