File size: 6,999 Bytes
77ba698 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def get_setting(name):\n",
" if \"terminal-punct\" in name:\n",
" return {\"x\": \"Fraction of lines ended with punctuation\", \"ylim\": (0, 0.1)}\n",
" \n",
" if \"line-dedup\" in name:\n",
" return {\"x\": \"Fraction of chars in duplicated lines\", \"xlim\": (0, 0.1), \"ylim\": (0,0.02)}\n",
" \n",
" if \"short-line\" in name:\n",
" return {\"x\": \"Fraction of lines shorter than 30 chars\", \"xlim\": (0.4, 1.0), \"ylim\": (0,0.05)}\n",
" \n",
" if \"avg_words_per_line\" in name:\n",
" return {\"x\": \"Avg. words per line\", \"x-log\": True, \"x-log\": True, \"round\": 0}\n",
" if \"avg_line_length\" in name:\n",
" return {\"x\": \"Avg. words per line\", \"x-log\": True, \"round\": 0}\n",
" \n",
" if \"global-length.json\" == name:\n",
" return {\"x\": \"Num. UTF-8 chars\", \"x-log\": True}\n",
" \n",
" if \"global-digit_ratio.json\" == name:\n",
" return {\"x\": \"Digit ratio\", \"xlim\": (0, 0.25)}\n",
" \n",
" if \"global-avg_word_length.json\" == name:\n",
" return {\"x\": \"Avg. word length\", \"xlim\": (2.5, 6.5)}\n",
"\n",
" \n",
" raise ValueError(f\"Unknown dataset name: {name}\")\n",
"\n",
"\n",
"def plot_scatter(data):\n",
" \"\"\"\n",
" Plot scatter plots with smoothing for each dataset in the data list on a single grid.\n",
" Each dataset is expected to be a dictionary with the first key as the dataset name,\n",
" and the value as another dictionary where keys are data points and values are their counts.\n",
" \"\"\"\n",
" import matplotlib.pyplot as plt\n",
" import numpy as np\n",
"\n",
" # Determine the number of plots and create a subplot grid\n",
" num_datasets = len(data)\n",
" cols = 2 # Define number of columns in the grid\n",
" rows = (num_datasets) // cols # Calculate the required number of rows\n",
" fig, axs = plt.subplots(rows, cols, figsize=(8 * cols, 3 * rows), dpi=350)\n",
" if rows * cols > 1:\n",
" axs = axs.flatten() # Flatten the array of axes if more than one subplot\n",
" else:\n",
" axs = [axs] # Encapsulate the single AxesSubplot object into a list for uniform handling\n",
"\n",
" plot_index = 0\n",
" legend_handles = [] # List to store handles for the legend\n",
" legend_labels = [] # List to store labels for the legend\n",
" for name, dataset in data.items():\n",
" setting = get_setting(name)\n",
" ax = axs[plot_index]\n",
" if \"name\" in setting:\n",
" ax.set_title(setting[\"name\"])\n",
" if \"x\" in setting:\n",
" ax.set_xlabel(setting[\"x\"])\n",
" if \"xlim\" in setting:\n",
" ax.set_xlim(setting[\"xlim\"])\n",
" if \"ylim\" in setting:\n",
" ax.set_ylim(setting[\"ylim\"])\n",
" if \"x-log\" in setting:\n",
" ax.set_xscale('log')\n",
"\n",
" # Use 2 decimal places for the y-axis labels\n",
" ax.yaxis.set_major_formatter('{x:.3f}')\n",
"\n",
"\n",
" plot_index += 1\n",
" # Each dataset may contain multiple lines\n",
" for i, (line_name, line_data) in enumerate(dataset.items()):\n",
" if \"round\" in setting:\n",
" tmp_line_data = defaultdict(list)\n",
" for p, p_v in line_data.items():\n",
" rounded_key = str(round(float(p), setting[\"round\"]))\n",
" tmp_line_data[rounded_key].append(p_v)\n",
"\n",
" # If you want to sum the values that have the same rounded key\n",
" tmp_line_data = {k: sum(v) for k, v in tmp_line_data.items()}\n",
" line_data = tmp_line_data\n",
" \n",
" # Check that if you sum the values you get 1\n",
" assert sum(line_data.values()) == 1\n",
"\n",
" # Add smoothing for 4-5 points\n",
" # Implementing smoothing using a rolling window\n",
" line_name = rename_dataset(line_name)\n",
" # Sorting the line data by keys\n",
" sorted_line_data = dict(sorted(line_data.items(), key=lambda item: float(item[0])))\n",
"\n",
" window_size = setting.get(\"window_size\", 5) # Define the window size for smoothing\n",
" x = np.array(list(sorted_line_data.keys()), dtype=float)\n",
" y = np.array(list(sorted_line_data.values()), dtype=float)\n",
" if len(y) >= window_size: # Ensure there are enough points to apply smoothing\n",
" # Convert y to a pandas Series to use rolling function\n",
" y_series = pd.Series(y)\n",
" # Apply rolling window and mean to smooth the data\n",
" y_smoothed = y_series.rolling(window=window_size).mean()\n",
" # Drop NaN values that result from the rolling mean calculation\n",
" y_smoothed = y_smoothed.dropna()\n",
" # Update x to correspond to the length of the smoothed y\n",
" x = x[len(x) - len(y_smoothed):]\n",
" y = y_smoothed.to_numpy() # Convert back to numpy array for plotting\n",
"\n",
"\n",
"\n",
" # Use the line name as the label to unify same line names across different plots\n",
"\n",
" line, = ax.plot(x, y, label=line_name) # Use default colors\n",
" if line_name not in legend_labels:\n",
" legend_handles.append(line)\n",
" legend_labels.append(line_name)\n",
"\n",
" # Place a single shared legend on the top of the figure\n",
" fig.legend(handles=legend_handles, labels=legend_labels, loc='lower center', ncol=1)\n",
" for ax in axs:\n",
" ax.set_ylabel('Document Frequency')\n",
"\n",
" fig.suptitle(\"Histograms of selected statistics\")\n",
" plt.tight_layout(rect=[0, 0.15, 1, 1]) # Adjust the layout to make room for the legend\n",
" fig.set_size_inches(13, 6) # Set the figure size to 18 inches by 12 inches\n",
" plt.show()\n",
"\n",
"plot_scatter(data)\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|