pwilczewski commited on
Commit
97e5c73
·
1 Parent(s): b9d1dbd

initial gradio app

Browse files
Files changed (1) hide show
  1. app.py +82 -9
app.py CHANGED
@@ -1,10 +1,83 @@
1
-
2
  import gradio as gr
3
- description = "BigGAN text-to-image demo."
4
- title = "BigGAN ImageNet"
5
- interface = gr.Interface.load("huggingface/osanseviero/BigGAN-deep-128",
6
- description=description,
7
- title = title,
8
- examples=[["american robin"]]
9
- )
10
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from torch import autocast
4
+ from diffusers import StableDiffusionPipeline
5
+ from datasets import load_dataset
6
+ from PIL import Image
7
+ import re
8
+ import os
9
+
10
+ auth_token = os.getenv("auth_token")
11
+ model_id = "CompVis/stable-diffusion-v1-4"
12
+ device = "cpu"
13
+ #pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
14
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=auth_token)
15
+ pipe = pipe.to(device)
16
+
17
+ def infer(prompt, samples, steps, scale, seed):
18
+ generator = torch.Generator(device=device).manual_seed(seed)
19
+ images_list = pipe(
20
+ [prompt] * samples,
21
+ num_inference_steps=steps,
22
+ guidance_scale=scale,
23
+ generator=generator,
24
+ )
25
+ images = []
26
+ safe_image = Image.open(r"unsafe.png")
27
+ for i, image in enumerate(images_list["sample"]):
28
+ if(images_list["nsfw_content_detected"][i]):
29
+ images.append(safe_image)
30
+ else:
31
+ images.append(image)
32
+ return images
33
+
34
+
35
+
36
+ block = gr.Blocks()
37
+
38
+ with block:
39
+ with gr.Group():
40
+ with gr.Box():
41
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
42
+ text = gr.Textbox(
43
+ label="Enter your prompt",
44
+ show_label=False,
45
+ max_lines=1,
46
+ placeholder="Enter your prompt",
47
+ ).style(
48
+ border=(True, False, True, True),
49
+ rounded=(True, False, False, True),
50
+ container=False,
51
+ )
52
+ btn = gr.Button("Generate image").style(
53
+ margin=False,
54
+ rounded=(False, True, True, False),
55
+ )
56
+ gallery = gr.Gallery(
57
+ label="Generated images", show_label=False, elem_id="gallery"
58
+ ).style(grid=[2], height="auto")
59
+
60
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
61
+
62
+ with gr.Row(elem_id="advanced-options"):
63
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
64
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
65
+ scale = gr.Slider(
66
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
67
+ )
68
+ seed = gr.Slider(
69
+ label="Seed",
70
+ minimum=0,
71
+ maximum=2147483647,
72
+ step=1,
73
+ randomize=True,
74
+ )
75
+ text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
76
+ btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
77
+ advanced_button.click(
78
+ None,
79
+ [],
80
+ text,
81
+ )
82
+
83
+ block.launch()