"""Gradio app for browsing the VisIT-Bench test split one page of rows at a time.

The test split is loaded once at import time, shuffled, and shown in a
``gr.Dataframe``. "Next Rows" walks through the shuffled frame page by page
and reshuffles (starting over at row 0) once the next page would run past the
end.
"""

import os

import gradio as gr
from datasets import load_dataset

# Hub token read from the ``auth_token`` environment variable (None if unset).
auth_token = os.environ.get("auth_token")
visit_bench_all = load_dataset("mlfoundations/visit-bench", use_auth_token=auth_token)
print('dataset keys:')
print(visit_bench_all.keys())

visit_bench = visit_bench_all['test']
df = visit_bench.to_pandas()
print(f"Got {len(df)} items in dataframe")
# Shuffle once so the first page is a random sample of the split.
df = df.sample(frac=1)

# Number of rows shown per page.
LINES_NUMBER = 20


def display_df():
    """Return the first ``LINES_NUMBER`` rows of the (shuffled) global ``df``."""
    df_images = df.head(LINES_NUMBER)
    return df_images


def display_next(dataframe, end):
    """Return the next page of rows and the index where that page stops.

    Args:
        dataframe: The frame currently displayed. Only its length is used, as
            the start offset when ``end`` is falsy (e.g. the hidden counter
            has not been set yet).
        end: Exclusive end index of the page currently displayed.

    Returns:
        A ``(rows, end)`` tuple: up to ``LINES_NUMBER`` rows of the global
        ``df`` and the exclusive end index of that page, which is fed back in
        on the next call.
    """
    global df  # reshuffling rebinds the module-level frame
    start = int(end or len(dataframe))
    end = start + LINES_NUMBER
    # Wrap around with a fresh shuffle only when the page would run past the
    # end; a page ending exactly at len(df) is still a full, valid page.
    if end > len(df):
        start = 0
        end = LINES_NUMBER
        df = df.sample(frac=1)
        print("Shuffle")
    # A split with fewer than LINES_NUMBER rows simply yields a short page
    # rather than failing the callback.
    df_images = df.iloc[start:end]
    return df_images, end


initial_dataframe = display_df()

# Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("# VisIT-Bench Dataset Viewer")
    with gr.Row():
        # Hidden counter holding the exclusive end index of the shown page.
        num_end = gr.Number(visible=False)
        b1 = gr.Button("Get Initial dataframe")
        b2 = gr.Button("Next Rows")
    with gr.Row():
        out_dataframe = gr.Dataframe(
            initial_dataframe,
            wrap=True,
            max_rows=LINES_NUMBER,
            overflow_row_behaviour="paginate",
            interactive=False,
        )
    b1.click(fn=display_df, outputs=out_dataframe, api_name="initial_dataframe")
    # Also reset the hidden counter, so "Next Rows" continues from the first
    # page instead of wherever paging had reached before the reset.
    b1.click(fn=lambda: LINES_NUMBER, outputs=num_end)
    b2.click(
        fn=display_next,
        inputs=[out_dataframe, num_end],
        outputs=[out_dataframe, num_end],
        api_name="next_rows",
    )

demo.launch(debug=True, show_error=True)