import io
import re
import time
from itertools import islice
from functools import partial
from multiprocessing.pool import ThreadPool
from queue import Queue, Empty
from typing import Callable, Iterable, Iterator, Optional, TypeVar

import gradio as gr
import pandas as pd
import requests.exceptions
from huggingface_hub import InferenceClient


model_id = "microsoft/Phi-3-mini-4k-instruct"
client = InferenceClient(model_id)

MAX_TOTAL_NB_ITEMS = 100
MAX_NB_ITEMS_PER_GENERATION_CALL = 10
NUM_ROWS = 100
NUM_VARIANTS = 10
URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"

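# Prompt templates: the first asks the model for plausible dataset names and tags for a
# search query, the second asks for a short description plus a 5-row CSV preview, and the
# remaining ones request additional rows focused on a given label and/or rarity.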
GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
    "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
    "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
    "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n2. DatasetName2 (tag1, tag2, tag3)"
)

GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
    "An ML practitioner is looking for a dataset CSV for the query '{search_query}'. "
    "Generate the first 5 rows of a plausible, high-quality CSV for the dataset '{dataset_name}'. "
    "You can get inspiration from related keywords '{tags}' but most importantly the dataset should correspond to the query '{search_query}'. "
    "Focus on quality text content and use a 'label' or 'labels' column if it makes sense (invent labels, avoid reusing the keywords, be accurate while labelling texts). "
    "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
)

GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."

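# Rarity qualifiers used to diversify the generated rows: RARITIES is combined with each
# label when a label column with several values is detected, LONG_RARITIES is used otherwise.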
RARITIES = ["pretty obvious", "common/regular", "unexpected but useful", "uncommon but still plausible", "rare/niche but still plausible"]
LONG_RARITIES = [
    "obvious",
    "expected",
    "common",
    "regular",
    "unexpected but useful",
    "original but useful",
    "specific but not far-fetched",
    "uncommon but still plausible",
    "rare but still plausible",
    "very niche but still plausible",
]

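# Default content shown on the landing page before the user runs a search.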
landing_page_query = "various datasets on many different subjects and topics, from classification to language modeling, from science to sport to finance to news"

landing_page_datasets_generated_text = """
1. NewsEventsPredict (classification, media, trend)
2. FinancialForecast (economy, stocks, regression)
3. HealthMonitor (science, real-time, anomaly detection)
4. SportsAnalysis (classification, performance, player tracking)
5. SciLiteracyTools (language modeling, science literacy, text classification)
6. RetailSalesAnalyzer (consumer behavior, sales trend, segmentation)
7. SocialSentimentEcho (social media, emotion analysis, clustering)
8. NewsEventTracker (classification, public awareness, topical clustering)
9. HealthVitalSigns (anomaly detection, biometrics, prediction)
10. GameStockPredict (classification, finance, sports contingency)
"""
default_output = landing_page_datasets_generated_text.strip().split("\n")
assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL

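# Custom CSS: dataset "cards" are rendered as two stacked transparent buttons (name + tags),
# and a shimmer animation is used as a loading placeholder.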
css = """
a {
    color: var(--body-text-color);
}

.datasetButton {
    justify-content: start;
    justify-content: left;
}
.tags {
    font-size: var(--button-small-text-size);
    color: var(--body-text-color-subdued);
}
.topButton {
    justify-content: start;
    justify-content: left;
    text-align: left;
    background: transparent;
    box-shadow: none;
    padding-bottom: 0;
}
.topButton::before {
    content: url("data:image/svg+xml,%3Csvg style='color: rgb(209 213 219)' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' aria-hidden='true' focusable='false' role='img' width='1em' height='1em' preserveAspectRatio='xMidYMid meet' viewBox='0 0 25 25'%3E%3Cellipse cx='12.5' cy='5' fill='currentColor' fill-opacity='0.25' rx='7.5' ry='2'%3E%3C/ellipse%3E%3Cpath d='M12.5 15C16.6421 15 20 14.1046 20 13V20C20 21.1046 16.6421 22 12.5 22C8.35786 22 5 21.1046 5 20V13C5 14.1046 8.35786 15 12.5 15Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M12.5 7C16.6421 7 20 6.10457 20 5V11.5C20 12.6046 16.6421 13.5 12.5 13.5C8.35786 13.5 5 12.6046 5 11.5V5C5 6.10457 8.35786 7 12.5 7Z' fill='currentColor' opacity='0.5'%3E%3C/path%3E%3Cpath d='M5.23628 12C5.08204 12.1598 5 12.8273 5 13C5 14.1046 8.35786 15 12.5 15C16.6421 15 20 14.1046 20 13C20 12.8273 19.918 12.1598 19.7637 12C18.9311 12.8626 15.9947 13.5 12.5 13.5C9.0053 13.5 6.06886 12.8626 5.23628 12Z' fill='currentColor'%3E%3C/path%3E%3C/svg%3E");
    margin-right: .25rem;
    margin-left: -.125rem;
    margin-top: .25rem;
}
.bottomButton {
    justify-content: start;
    justify-content: left;
    text-align: left;
    background: transparent;
    box-shadow: none;
    font-size: var(--button-small-text-size);
    color: var(--body-text-color-subdued);
    padding-top: 0;
    align-items: baseline;
}
.bottomButton::before {
    content: 'tags:';
    margin-right: .25rem;
}
.buttonsGroup {
    background: transparent;
}
.buttonsGroup:hover {
    background: var(--input-background-fill);
}
.buttonsGroup div {
    background: transparent;
}
.insivibleButtonGroup {
    display: none;
}

@keyframes placeHolderShimmer {
    0% {
        background-position: -468px 0
    }
    100% {
        background-position: 468px 0
    }
}
.linear-background {
    animation-duration: 1s;
    animation-fill-mode: forwards;
    animation-iteration-count: infinite;
    animation-name: placeHolderShimmer;
    animation-timing-function: linear;
    background-image: linear-gradient(to right, var(--body-text-color-subdued) 8%, #dddddd11 18%, var(--body-text-color-subdued) 33%);
    background-size: 1000px 104px;
    color: transparent;
    background-clip: text;
}
"""

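# UI: a search page with a grid of MAX_TOTAL_NB_ITEMS dataset "cards" (only the generated
# ones are visible), and a hidden dataset page that shows the generated preview and full table.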
with gr.Blocks(css=css) as demo:
    generated_texts_state = gr.State((landing_page_datasets_generated_text,))
    with gr.Row():
        with gr.Column(scale=4, min_width=0):
            pass
        with gr.Column(scale=10):
            gr.Markdown(
                "# 🤗 Infinite Dataset Hub ♾️\n\n"
                "An endless catalog of datasets, created just for you.\n\n"
            )
        with gr.Column(scale=4, min_width=0):
            pass
    with gr.Column() as search_page:
        with gr.Row():
            with gr.Column(scale=4, min_width=0):
                pass
            with gr.Column(scale=10):
                with gr.Row():
                    search_bar = gr.Textbox(max_lines=1, placeholder="Search datasets, get infinite results", show_label=False, container=False, scale=9)
                    search_button = gr.Button("🔍", variant="primary", scale=1)
            with gr.Column(scale=4, min_width=0):
                pass
        with gr.Row():
            with gr.Column(scale=4, min_width=0):
                pass
            with gr.Column(scale=10):
                button_groups: list[gr.Group] = []
                buttons: list[gr.Button] = []
                for i in range(MAX_TOTAL_NB_ITEMS):
                    if i < len(default_output):
                        line = default_output[i]
                        dataset_name, tags = line.split(".", 1)[1].strip(" )").split(" (", 1)
                        group_classes = "buttonsGroup"
                        dataset_name_classes = "topButton"
                        tags_classes = "bottomButton"
                    else:
                        dataset_name, tags = "⬜⬜⬜⬜⬜⬜", "░░░░, ░░░░, ░░░░"
                        group_classes = "buttonsGroup insivibleButtonGroup"
                        dataset_name_classes = "topButton linear-background"
                        tags_classes = "bottomButton linear-background"
                    with gr.Group(elem_classes=group_classes) as button_group:
                        button_groups.append(button_group)
                        buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
                        buttons.append(gr.Button(tags, elem_classes=tags_classes))

                load_more_datasets = gr.Button("Load more datasets")
                gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
            with gr.Column(scale=4, min_width=0):
                pass
    with gr.Column(visible=False) as dataset_page:
        dataset_title = gr.Markdown()
        gr.Markdown("_Note: This is an AI-generated dataset so its content may be inaccurate or false_")
        dataset_content = gr.Markdown()
        generate_full_dataset_button = gr.Button("Generate Full Dataset", variant="primary")
        dataset_dataframe = gr.DataFrame(visible=False, interactive=False, wrap=True)
        save_dataset_button = gr.Button("💾 Save Dataset", variant="primary", visible=False)
        dataset_share_button = gr.Button("Share Dataset URL")
        dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
        back_button = gr.Button("< Back", size="sm")

    app_state = gr.State({})

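    # Generation and parsing helpers used by the event handlers below.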
    T = TypeVar("T")

    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
        it = iter(it)
        while batch := list(islice(it, n)):
            yield batch

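    # Stream completion tokens from the Inference API, replaying previously generated texts
    # as chat history so the model can be asked to generate more; retries on connection errors.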
    def stream_reponse(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
        messages = [
            {"role": "user", "content": msg}
        ] + [
            item
            for generated_text in generated_texts
            for item in [
                {"role": "assistant", "content": generated_text},
                {"role": "user", "content": "Can you generate more?"},
            ]
        ]
        for _ in range(3):
            try:
                for message in client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    stream=True,
                    top_p=0.8,
                    seed=42,
                ):
                    yield message.choices[0].delta.content
            except requests.exceptions.ConnectionError as e:
                print(f"{e}\n\nRetrying in 1sec")
                time.sleep(1)
                continue
            break

    def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
        search_query = search_query or ""
        search_query = search_query[:1000] if search_query.strip() else landing_page_query
        generated_text = ""
        current_line = ""
        for token in stream_reponse(
            GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
            generated_texts=generated_texts,
        ):
            current_line += token
            if current_line.endswith("\n"):
                yield current_line
                generated_text += current_line
                current_line = ""
        yield current_line
        generated_text += current_line
        print("-----\n\n" + generated_text)

    def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
        search_query = search_query or ""
        search_query = search_query[:1000] if search_query.strip() else landing_page_query
        generated_text = ""
        for token in stream_reponse(GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
            search_query=search_query,
            dataset_name=dataset_name,
            tags=tags,
        ), max_tokens=1500):
            generated_text += token
            yield generated_text
        print("-----\n\n" + generated_text)

    def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
        for result in func(**kwargs):
            queue.put(result)
        return None

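    # Run several generator functions in a thread pool and yield their results as they
    # become available, without preserving order.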
    def iflatmap_unordered(
        func: Callable[..., Iterable[T]],
        *,
        kwargs_iterable: Iterable[dict],
    ) -> Iterable[T]:
        queue = Queue()
        with ThreadPool() as pool:
            async_results = [
                pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) for kwargs in kwargs_iterable
            ]
            try:
                while True:
                    try:
                        yield queue.get(timeout=0.05)
                    except Empty:
                        if all(async_result.ready() for async_result in async_results) and queue.empty():
                            break
            finally:
                # Re-raise any exception that occurred in the worker threads
                [async_result.get(timeout=0.05) for async_result in async_results]

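    # Ask the model for more rows for one variant prompt, parse the streamed CSV lines, and
    # write each new row into the shared `output` list at the indices assigned to this variant.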
    def generate_partial_dataset(title: str, content: str, search_query: str, variant: str, csv_header: str, output: list[dict[str, str]], indices_to_generate: list[int], max_tokens=1500) -> Iterator[int]:
        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
        dataset_name, tags = dataset_name.strip(), tags.strip()
        messages = [
            {
                "role": "user",
                "content": GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
                    dataset_name=dataset_name,
                    tags=tags,
                    search_query=search_query,
                )
            },
            {"role": "assistant", "content": title + "\n\n" + content},
            {"role": "user", "content": GENERATE_MORE_ROWS.format(csv_header=csv_header) + " " + variant},
        ]
        for _ in range(3):
            generated_text = ""
            generated_csv = ""
            current_line = ""
            nb_samples = 0
            _in_csv = False
            try:
                for message in client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    stream=True,
                    top_p=0.8,
                    seed=42,
                ):
                    if nb_samples >= len(indices_to_generate):
                        break
                    current_line += message.choices[0].delta.content
                    generated_text += message.choices[0].delta.content
                    if current_line.endswith("\n"):
                        _in_csv = _in_csv ^ current_line.lstrip().startswith("```")
                        if current_line.strip() and _in_csv and not current_line.lstrip().startswith("```"):
                            generated_csv += current_line
                            try:
                                generated_df = parse_csv_df(generated_csv.strip(), csv_header=csv_header)
                                if len(generated_df) > nb_samples:
                                    output[indices_to_generate[nb_samples]] = generated_df.iloc[-1].to_dict()
                                    nb_samples += 1
                                    yield 1
                            except Exception:
                                pass
                        current_line = ""
            except requests.exceptions.ConnectionError as e:
                print(f"{e}\n\nRetrying in 1sec")
                time.sleep(1)
                continue
            break

    def generate_variants(preview_df: pd.DataFrame):
        label_candidate_columns = [column for column in preview_df.columns if "label" in column.lower()]
        if label_candidate_columns:
            labels = preview_df[label_candidate_columns[0]].unique()
            if len(labels) > 1:
                return [
                    GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(rarity=rarity, label=label)
                    for rarity in RARITIES
                    for label in labels
                ]
        return [
            GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity)
            for rarity in LONG_RARITIES
        ]

    def parse_preview_df(content: str) -> tuple[str, pd.DataFrame]:
        _in_csv = False
        csv = "\n".join(
            line for line in content.split("\n") if line.strip()
            and (_in_csv := (_in_csv ^ line.lstrip().startswith("```")))
            and not line.lstrip().startswith("```")
        )
        if not csv:
            raise gr.Error("Failed to parse CSV Preview")
        return csv.split("\n")[0], parse_csv_df(csv)

    def parse_csv_df(csv: str, csv_header: Optional[str] = None) -> pd.DataFrame:
        # Wrap bracketed lists (e.g. ['tag1', 'tag2']) in quotes so pandas parses them as a single field
        for match in re.finditer(r'''(?!")\[(["'][\w ]+["'][, ]*)+\](?!")''', csv):
            span = match.string[match.start() : match.end()]
            csv = csv.replace(span, '"' + span.replace('"', "'") + '"', 1)
        # Prepend the expected header if the generated CSV doesn't start with it
        if csv_header and csv.strip().split("\n")[0] != csv_header:
            csv = csv_header + "\n" + csv
        df = pd.read_csv(io.StringIO(csv), skipinitialspace=True)
        return df

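    # Reset the result grid for a new search, then stream freshly generated dataset names
    # and tags into the card buttons as the model produces them.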
    def _search_datasets(search_query):
        yield {generated_texts_state: [], app_state: {"search_query": search_query}}
        yield {
            button_group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup")
            for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
        }
        yield {
            k: v
            for dataset_name_button, tags_button in batched(buttons, 2)
            for k, v in {
                dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
                tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
            }.items()
        }
        current_item_idx = 0
        generated_text = ""
        for line in gen_datasets_line_by_line(search_query):
            if "I'm sorry" in line:
                raise gr.Error("Error: inappropriate content")
            if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
                return
            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
                try:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
                except ValueError:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)
                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
                generated_text += line
                yield {
                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
                    generated_texts_state: (generated_text,),
                }
                current_item_idx += 1

    @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
    def search_dataset_from_search_button(search_query):
        yield from _search_datasets(search_query)

    @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state, app_state])
    def search_dataset_from_search_bar(search_query):
        yield from _search_datasets(search_query)

    @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
    def search_more_datasets(search_query, generated_texts):
        current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
        yield {
            button_group: gr.Group(elem_classes="buttonsGroup")
            for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
        }
        generated_text = ""
        for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
                raise gr.Error("Error: inappropriate content")
            if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
                return
            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
                try:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
                except ValueError:
                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
                generated_text += line
                yield {
                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
                    generated_texts_state: (*generated_texts, generated_text),
                }
                current_item_idx += 1

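    # Switch from the search page to the dataset page and stream the generated description
    # and CSV preview for the selected dataset.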
    def _show_dataset(search_query, dataset_name, tags):
        yield {
            search_page: gr.Column(visible=False),
            dataset_page: gr.Column(visible=True),
            dataset_title: f"# {dataset_name}\n\n tags: {tags}",
            dataset_share_textbox: gr.Textbox(visible=False),
            dataset_dataframe: gr.DataFrame(visible=False),
            generate_full_dataset_button: gr.Button(visible=True),
            save_dataset_button: gr.Button(visible=False),
            app_state: {
                "search_query": search_query,
                "dataset_name": dataset_name,
                "tags": tags
            }
        }
        for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
            yield {dataset_content: generated_text}

    show_dataset_inputs = [search_bar, *buttons]
    show_dataset_outputs = [app_state, search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, dataset_share_textbox]
    scroll_to_top_js = """
    function (...args) {
        console.log(args);
        if ('parentIFrame' in window) {
            window.parentIFrame.scrollTo({top: 0, behavior: 'smooth'});
        } else {
            window.scrollTo({ top: 0 });
        }
        return args;
    }
    """

    def show_dataset_from_button(search_query, *buttons_values, i):
        dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
        yield from _show_dataset(search_query, dataset_name, tags)

    for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
        dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
        tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)

    @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
    def show_search_page():
        return gr.Column(visible=True), gr.Column(visible=False)

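    # Build the full dataset: parse the CSV preview, drop index-like columns, then fill up to
    # NUM_ROWS rows by running NUM_VARIANTS variant prompts in parallel and streaming the
    # growing table into the DataFrame as rows come in.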
    @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
    def generate_full_dataset(title, content, search_query):
        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
        dataset_name, tags = dataset_name.strip(), tags.strip()
        csv_header, preview_df = parse_preview_df(content)
        # Remove columns that merely enumerate the rows (0..n-1 or 1..n)
        for column_name, values in preview_df.to_dict(orient="series").items():
            try:
                if [int(v) for v in values] == list(range(len(preview_df))):
                    preview_df = preview_df.drop(columns=column_name)
                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
                    preview_df = preview_df.drop(columns=column_name)
            except Exception:
                pass
        columns = list(preview_df)
        output: list[Optional[dict]] = [None] * NUM_ROWS
        output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
        yield {
            dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
            generate_full_dataset_button: gr.Button(visible=False),
            save_dataset_button: gr.Button(visible=True, interactive=False)
        }
        kwargs_iterable = [
            {
                "title": title,
                "content": content,
                "search_query": search_query,
                "variant": variant,
                "csv_header": csv_header,
                "output": output,
                "indices_to_generate": list(range(len(preview_df) + i, NUM_ROWS, NUM_VARIANTS)),
            }
            for i, variant in enumerate(islice(generate_variants(preview_df), NUM_VARIANTS))
        ]
        for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
            yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
        yield {save_dataset_button: gr.Button(visible=True, interactive=True)}
        print(f"Generated {dataset_name}!")

    @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe])
    def save_dataset(title, content, search_query, df):
        raise gr.Error("Not implemented yet, sorry! Request your dataset to be saved in the Discussion tab (provide the dataset URL).")

    @dataset_share_button.click(inputs=[app_state], outputs=[dataset_share_textbox])
    def show_dataset_url(state):
        return gr.Textbox(
            f"{URL}?q={state['search_query'].replace(' ', '+')}&dataset={state['dataset_name'].replace(' ', '+')}&tags={state['tags'].replace(' ', '+')}",
            visible=True,
        )

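    # On page load, restore the app state from URL query parameters (?q=...&dataset=...&tags=...).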
    @demo.load(outputs=[app_state, search_page, search_bar, dataset_page, dataset_title, dataset_content, dataset_share_textbox] + button_groups + buttons + [generated_texts_state])
    def load_app(request: gr.Request):
        query_params = dict(request.query_params)
        if "dataset" in query_params:
            yield from _show_dataset(
                search_query=query_params.get("q", query_params["dataset"]),
                dataset_name=query_params["dataset"],
                tags=query_params.get("tags", "")
            )
        elif "q" in query_params:
            yield {search_bar: query_params["q"]}
            yield from _search_datasets(query_params["q"])
        else:
            yield {search_page: gr.Column(visible=True)}

demo.launch()