"""Gradio dashboard for exploring filtering thresholds on pile-v2-small subsets.

For every subset of CarperAI/pile-v2-small we load the "train" split and let
the user histogram one of the per-document quality ratios (word repetition,
character repetition, flagged words).  A red dashed line marks a chosen
threshold, and the title reports what fraction of documents fall below it
(i.e. would be removed by filtering at that threshold).
"""
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import datasets  # noqa: F401  (kept from original file)
from datasets import load_dataset

# Display key -> sub-directory of the CarperAI/pile-v2-small repo.
_SUBSETS = {
    "ai4code": "AI4Code",
    "amps": "AMPS",
    "apache": "ASFPublicMail",
    "books3": "Books3",
    "competitive_programming": "CPDataset",
    "dmmath": "DMMath",
    "discourse": "Discourse",
    "enwiki": "Enwiki",
    "euro": "EuroParliamentProceedings",
    "freelaw": "FreeLaw_Options",
    "ghdiffs": "GitHubDiff",
    "ghissues": "GitHubIssues",
    "gutenberg": "Gutenberg",
    "leetcode": "LeetCode",
    "pileoflaw": "PileOfLaw",
    "pubmed": "PubMed",
    "s2orc": "S2ORC",
    "se": "StackExchange",
    "usenet": "USENET",
    "uspto": "USPTO",
    "ubuntuirc": "UbuntuIRC",
    "arxiv": "arXiv",
}

# Load the "train" split of every subset once at startup.
# NOTE: the original code omitted use_auth_token=True for Enwiki only; that
# looked accidental, so authentication is now requested for every subset.
dataset_data = {
    key: load_dataset(
        "CarperAI/pile-v2-small",
        data_files=f"data/{subdir}/data.json",
        use_auth_token=True,
    )["train"]
    for key, subdir in _SUBSETS.items()
}


def plt_plot(ratio, dataset, threshold):
    """Histogram the `ratio` column of `dataset` and mark `threshold`.

    Parameters
    ----------
    ratio : str
        Column name in the dataset (e.g. ``"char_rep_ratios"``).
    dataset : str
        Key into ``dataset_data``.
    threshold : float
        Cut-off value; documents whose ratio is below it count as removed.

    Returns
    -------
    matplotlib.figure.Figure
        The rendered histogram, for display in a ``gr.Plot``.
    """
    x = dataset_data[dataset][ratio]
    # Fraction of documents below the threshold, i.e. removed by filtering.
    perc = np.sum(x < threshold) / len(x)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.hist(x, bins=50, color="black")
    # Red dashed line marking the chosen threshold.
    ax.axvline(threshold, color="r", linestyle="dashed", linewidth=2)
    ax.set_title(f"{dataset} (removed {perc:.2%})")
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.tight_layout()
    return fig


with gr.Blocks() as demo:
    # BUG FIX: the original default was "arXiv", which is not one of the
    # radio's choices (the key is lowercase "arxiv").
    dataset = gr.Radio(list(dataset_data.keys()), label="Dataset", value="arxiv")

    # (tab title, dataset column, slider maximum).
    # BUG FIX: the original wired "Character Repetition Ratio" to
    # word_rep_ratios and "Word Repetition Ratio" to char_rep_ratios;
    # the columns now match the tab labels.
    # NOTE(review): the character-repetition slider ran 0-100 in the original
    # while the others ran 0-1 — kept as-is, but confirm the intended scale.
    _TABS = [
        ("Character Repetition Ratio", "char_rep_ratios", 100),
        ("Word Repetition Ratio", "word_rep_ratios", 1),
        ("Flagged Word Ratio", "flagged_word_ratios", 1),
    ]
    for title, column, slider_max in _TABS:
        with gr.Tab(title):
            plot = gr.Plot()
            threshold = gr.Slider(minimum=0, maximum=slider_max, label="Threshold")
            calculate = gr.Button("Calculate")
            # Bind the column name now; dataset and threshold come from the UI.
            calculate.click(partial(plt_plot, column), [dataset, threshold], plot)

if __name__ == "__main__":
    demo.launch(share=True)