# pilev2_pipeline / app.py
# Last change by ncoop57: "Make path a Path obj" (commit 898b5fd)
# (copied from the Hugging Face file viewer: raw | history blame | 3.23 kB)
import os
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from datasets import load_dataset
from pathlib import Path
# Hugging Face access token, supplied through the HF_TOKEN environment variable
# (e.g. a Space secret). Fail early with a clear message if it is missing or empty.
token = os.environ.get("HF_TOKEN")
if not token:
    raise RuntimeError("HF_TOKEN environment variable is not set")

# Persist the token to ./huggingface/token (relative to the working directory).
# NOTE(review): huggingface_hub normally looks for its token under the user's
# home directory / HF_HOME rather than the CWD -- confirm this file is actually
# consulted by anything before relying on it.
token_dir = Path("./huggingface")
token_dir.mkdir(parents=True, exist_ok=True)
token_file = token_dir / "token"
# Create the file with owner-only permissions so the secret is not left
# world-readable; chmod afterwards too, because os.open's mode only applies
# when the file is newly created.
fd = os.open(token_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as f:
    f.write(token)
token_file.chmod(0o600)
# Names of the pilev2 sub-corpora. Each name is used both as the directory of
# its ``data/<name>/data.json`` file in the metadata dataset repo and as a
# choice in the dataset selector.
# NOTE(review): "FreeLaw_Options" may be a typo for "FreeLaw_Opinions"; the
# string must match the directory name in the repo, so verify before changing.
dataset_names = [
    "AI4Code",
    "AMPS",
    "ASFPublicMail",
    "CPDataset",
    "DMMath",
    "Discourse",
    "Enwiki",
    "EuroParliamentProceedings",
    "FreeLaw_Options",
    "GithubDiff",
    "GithubIssues",
    "Gutenberg",
    "LeetCode",
    "PileOfLaw",
    "PubMed",
    "S2ORC",
    "StackExchange",
    "USENET",
    "USPTO",
    "UbuntuIRC",
    "arXiv",
]
# Per-dataset metadata, keyed by the names in ``dataset_names``. Each entry holds
# the loaded split under "ds" plus one numpy array per quality metric, with one
# value per example.
dataset_data = {}
for name in dataset_names:
    data_file = f"data/{name}/data.json"
    ds = load_dataset(
        "CarperAI/pilev2_smol_metadata",
        data_files=data_file,
        # Pass the token read from HF_TOKEN directly instead of ``True`` so that
        # authentication does not depend on where a token file was written.
        use_auth_token=token,
        split="train",
    )
    dataset_data[name] = {
        "ds": ds,
        # Placeholder: random standard-normal values, NOT a measured word
        # repetition ratio (and not confined to the UI slider's 0..1 range).
        # TODO: replace with the real metric column once available.
        "word_rep_ratios": np.random.randn(len(ds)),
        "char_rep_ratios": np.array(ds["check_char_repetition_criteria"]),
        "flagged_word_ratios": np.array(ds["check_flagged_words_criteria"]),
    }
def plt_plot(ratio, dataset, threshold):
    """Plot a histogram of one quality metric for one dataset.

    Args:
        ratio: Key into ``dataset_data[dataset]`` selecting the array to plot,
            e.g. ``"char_rep_ratios"``.
        dataset: Dataset name; a key of the module-level ``dataset_data``.
        threshold: Cut-off drawn as a red dashed vertical line. Values strictly
            greater than it are counted as removed.

    Returns:
        The matplotlib ``Figure`` holding a 50-bin histogram, titled with the
        fraction of examples above ``threshold`` (0% for an empty array).
    """
    # Close figures from earlier calls so pyplot's figure registry does not
    # keep growing with every button click.
    plt.close("all")
    values = np.asarray(dataset_data[dataset][ratio])

    # Fraction of examples the threshold would filter out. Guard the empty case,
    # where the plain division would be 0/0 and produce NaN.
    if values.size:
        removed = np.count_nonzero(values > threshold) / values.size
    else:
        removed = 0.0

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.hist(values, bins=50, color="black")
    ax.axvline(threshold, color="r", linestyle="dashed", linewidth=2)
    ax.set_title(f"{dataset} (removed {removed:.2%})")
    # Use the explicit Axes/Figure methods rather than pyplot's implicit
    # "current" axes/figure so the labels and layout apply to this figure.
    ax.set_xlabel("Value")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    return fig
def check_filtered():
    """Placeholder with no implementation yet.

    Presumably intended to back the "Check Filtered Data" button (which is not
    wired to any handler). Takes no arguments, does nothing, returns ``None``.
    """
    return None
def _add_ratio_tab(ratio, dataset):
    """Add the plot, threshold slider and Calculate button for one metric.

    The components are placed in whichever gradio layout context is open when
    this is called (here, a ``gr.Tab``). Clicking Calculate calls
    ``plt_plot(ratio, <selected dataset>, <slider value>)`` and shows the figure.

    Args:
        ratio: Key into each ``dataset_data`` entry, e.g. ``"char_rep_ratios"``.
        dataset: The shared dataset ``gr.Radio`` whose value feeds the plot.
    """
    plot = gr.Plot()
    threshold = gr.Slider(minimum=0, maximum=1, label="Threshold")
    calculate = gr.Button("Calculate")
    # partial binds ``ratio`` now; gradio supplies dataset and threshold.
    calculate.click(partial(plt_plot, ratio), [dataset, threshold], plot)


with gr.Blocks() as demo:
    # Dataset selector shared by every tab.
    dataset = gr.Radio(dataset_names, label="Dataset", value="arXiv")
    with gr.Tab("Character Repetition Ratio"):
        _add_ratio_tab("char_rep_ratios", dataset)
        # NOTE(review): this button is not wired to any handler yet.
        check = gr.Button("Check Filtered Data")
    with gr.Tab("Word Repetition Ratio"):
        _add_ratio_tab("word_rep_ratios", dataset)
    with gr.Tab("Flagged Word Ratio"):
        _add_ratio_tab("flagged_word_ratios", dataset)
def _main():
    """Start the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    _main()