File size: 1,946 Bytes
f2142ae
 
 
 
 
 
 
6890e40
38381e8
6890e40
38381e8
6890e40
 
 
 
 
38381e8
 
 
 
 
 
 
 
 
 
 
 
6890e40
f2142ae
 
bc734c2
f2142ae
 
 
 
 
38381e8
f2142ae
38381e8
f2142ae
 
24a2032
6890e40
 
38381e8
f2142ae
 
 
 
 
 
 
 
38381e8
 
f2142ae
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from pathlib import Path
from huggingface_hub import list_repo_files, hf_hub_url
from collections import defaultdict
import requests

# Hub repository browsed by this app; every Hub call below passes
# repo_type='dataset' alongside it.
repo_id = "nateraw/stable-diffusion-gallery"

# Lower-case image extensions (with leading dot) shown in the gallery.
IMAGE_SUFFIXES = frozenset({'.png', '.jpg', '.jpeg'})


def group_images(files, suffixes=IMAGE_SUFFIXES):
    """Group image paths by the name of their immediate parent directory.

    Args:
        files: Iterable of repo-relative paths (e.g. the output of
            ``list_repo_files``).
        suffixes: Collection of lower-case extensions, with leading dot, to
            keep. Matching is case-insensitive on the file side.

    Returns:
        A ``defaultdict(list)`` mapping parent-directory name to the matching
        paths, in the order they appear in *files*. Files at the repo root are
        grouped under ``''``. Dotfiles such as ``.gitattributes`` have an empty
        ``Path.suffix`` and are therefore never included.
    """
    grouped = defaultdict(list)
    for file in files:
        path = Path(file)
        # Lower-case before comparing so 'x.PNG' / 'x.JPG' are not dropped.
        if path.suffix.lower() in suffixes:
            grouped[path.parent.name].append(file)
    return grouped


def get_data(revision='main'):
    """List the dataset repo at *revision* and group its images by run.

    Args:
        revision: Git revision (branch, tag or commit) of the dataset repo.

    Returns:
        A ``defaultdict(list)`` mapping run (parent directory) name to the
        repo-relative paths of that run's image files.
    """
    files = list_repo_files(repo_id, repo_type='dataset', revision=revision)
    data = group_images(files)
    print(data.keys())  # debug aid: which runs were discovered
    return data

def on_refresh(data_state):
    """Re-scan the dataset repo and refresh the session state and run dropdown.

    Args:
        data_state: The session's current ``{run_name: [image paths]}``
            mapping (value of the ``gr.State`` input). It is not consulted;
            the freshly scanned mapping replaces it entirely.

    Returns:
        ``[new_state, dropdown_update]``: the new mapping for the ``gr.State``
        output, and a ``gr.update`` setting the dropdown's choices to the
        current run names.
    """
    # Replace instead of dict.update(): update() never removes keys, so runs
    # deleted from the repo would otherwise linger in the state and dropdown.
    new_state = dict(get_data())
    return [
        new_state,
        gr.update(choices=list(new_state.keys())),
    ]

def on_submit(run, data):
    """Build the gallery items and fetch the prompt config for one run.

    Args:
        run: Run (directory) name selected in the dropdown. It may be ``None``
            when nothing is selected.
        data: Session mapping ``{run_name: [repo-relative image paths]}``.

    Returns:
        ``(prompt_config_json, images)``, where *images* is a list of
        ``(url, caption)`` tuples for ``gr.Gallery``. If *run* is not a key of
        *data*, returns ``(None, [])`` instead of raising ``KeyError``.

    Raises:
        requests.HTTPError: if fetching ``<run>/prompt_config.json`` returns
            an error status.
        requests.Timeout: if the Hub does not respond within the timeout.
    """
    if run not in data:
        # Nothing selected, or this run is not in the session state yet:
        # show empty outputs rather than crashing the event handler.
        return None, []
    images = [
        (
            hf_hub_url(repo_id=repo_id, filename=img, repo_type='dataset'),
            f"Seed:\n{Path(img).stem}",
        )
        for img in data[run]
    ]
    prompt_config_url = hf_hub_url(
        repo_id=repo_id,
        filename=f'{run}/prompt_config.json',
        repo_type='dataset',
    )
    # Bound the request so a stalled response can't hang the UI worker, and
    # report HTTP failures (e.g. a missing config) instead of a JSON decode error.
    response = requests.get(prompt_config_url, timeout=30)
    response.raise_for_status()
    return response.json(), images

# Scan the repo once at startup. The result seeds both the dropdown choices and
# the session state, so "View images" works before the first "Refresh" (the
# state previously started as {} and on_submit raised KeyError).
initial_data = dict(get_data())
runs = list(initial_data.keys())
with gr.Blocks() as demo:
    # {run_name: [image paths]} used by on_submit; replaced by on_refresh.
    data_state = gr.State(initial_data)
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            refresh_btn = gr.Button("Refresh").style(full_width=True)

        with gr.Row(variant="compact"):
            run = gr.Dropdown(runs)
            btn = gr.Button("View images").style(full_width=False)

        with gr.Row(variant="compact"):
            data_json = gr.Json()
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[5], height="auto")

    refresh_btn.click(on_refresh, data_state, [data_state, run])
    btn.click(on_submit, [run, data_state], [data_json, gallery])

demo.launch(debug=True)