# pilev2_pipeline/app.py
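"""Gradio app for exploring quality-filter thresholds on CarperAI/pile-v2-small.

For each subset, the app histograms a per-document statistic (character
repetition, word repetition, or flagged-word ratio) and reports what fraction
of documents a given threshold would remove.
"""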
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from datasets import load_dataset
# Each subset of CarperAI/pile-v2-small lives under data/<Subset>/data.json.
# Map the display name used in the UI to its subset directory.
SUBSETS = {
    "ai4code": "AI4Code",
    "amps": "AMPS",
    "apache": "ASFPublicMail",
    "books3": "Books3",
    "competitive_programming": "CPDataset",
    "dmmath": "DMMath",
    "discourse": "Discourse",
    "enwiki": "Enwiki",
    "euro": "EuroParliamentProceedings",
    "freelaw": "FreeLaw_Options",
    "ghdiffs": "GitHubDiff",
    "ghissues": "GitHubIssues",
    "gutenberg": "Gutenberg",
    "leetcode": "LeetCode",
    "pileoflaw": "PileOfLaw",
    "pubmed": "PubMed",
    "s2orc": "S2ORC",
    "se": "StackExchange",
    "usenet": "USENET",
    "uspto": "USPTO",
    "ubuntuirc": "UbuntuIRC",
    "arxiv": "arXiv",
}
# Keep only the train split of each subset; that is all the app plots.
dataset_data = {
    name: load_dataset("CarperAI/pile-v2-small", data_dir=f"data/{subset}/data.json")["train"]
    for name, subset in SUBSETS.items()
}
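# NOTE (assumption): each train split is expected to expose the per-document
# statistic columns that plt_plot indexes below ("word_rep_ratios",
# "char_rep_ratios", "flagged_word_ratios"); this is not verified here.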
# Earlier placeholder: before the real subsets were wired up, dataset_data was
# filled with synthetic columns of the same shape for every subset, e.g.:
# dataset_data = {
#     "AI4Code": {
#         "word_rep_ratios": np.random.randn(1000),
#         "char_rep_ratios": np.random.randn(1000),
#         "flagged_word_ratios": np.random.randn(1000),
#         "num_words": np.random.randint(0, 1000, 1000),
#     },
#     # ... repeated identically for each of the other subsets above
# }
def plt_plot(ratio, dataset, threshold):
    # Columns come back from `datasets` as Python lists; convert so the
    # threshold comparison below is vectorized.
    x = np.asarray(dataset_data[dataset][ratio])
    # Fraction of documents that would be removed at this threshold.
    perc = np.sum(x < threshold) / len(x)
    fig, ax = plt.subplots()
    ax.hist(x, bins=50, color="black")
    # Dashed red line marking the threshold.
    ax.axvline(threshold, color="r", linestyle="dashed", linewidth=2)
    ax.set_title(f"{dataset} (removed {perc:.2%})")
    ax.set_xlabel("Value")
    ax.set_ylabel("Frequency")
    fig.tight_layout()
    return fig
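# Each tab below pre-binds the column name with functools.partial, so Gradio
# only needs to pass the (dataset, threshold) inputs on each click, e.g.:
#   plot_fn = partial(plt_plot, "char_rep_ratios")
#   plot_fn("arxiv", 0.5)  # -> matplotlib Figure for the arXiv subset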
with gr.Blocks() as demo:
    dataset = gr.Radio(list(dataset_data.keys()), label="Dataset", value="arxiv")
with gr.Tab("Character Repetition Ratio"):
# plot some random data
plot = gr.Plot()
threshold = gr.Slider(minimum=0, maximum=100, label="Threshold")
calculate = gr.Button("Calculate")
plot_fn = partial(plt_plot, "word_rep_ratios")
calculate.click(plot_fn, [dataset, threshold], plot)
with gr.Tab("Word Repetition Ratio"):# plot some random data
plot = gr.Plot()
threshold = gr.Slider(minimum=0, maximum=1, label="Threshold")
calculate = gr.Button("Calculate")
plot_fn = partial(plt_plot, "char_rep_ratios")
calculate.click(plot_fn, [dataset, threshold], plot)
with gr.Tab("Flagged Word Ratio"):# plot some random data
plot = gr.Plot()
threshold = gr.Slider(minimum=0, maximum=1, label="Threshold")
calculate = gr.Button("Calculate")
plot_fn = partial(plt_plot, "flagged_word_ratios")
calculate.click(plot_fn, [dataset, threshold], plot)
if __name__ == "__main__":
    demo.launch(share=True)