File size: 2,709 Bytes
4e7c6ea
0bb1fd1
 
 
4e7c6ea
 
 
 
 
6356404
 
0bb1fd1
4e7c6ea
0bb1fd1
6356404
0bb1fd1
 
6356404
0bb1fd1
 
 
 
 
 
 
4e7c6ea
0bb1fd1
 
 
 
 
 
4e7c6ea
 
 
 
0bb1fd1
4e7c6ea
 
 
 
0bb1fd1
4e7c6ea
0bb1fd1
4e7c6ea
 
0bb1fd1
4e7c6ea
 
 
 
 
0bb1fd1
4e7c6ea
 
 
 
0bb1fd1
4e7c6ea
 
0bb1fd1
4e7c6ea
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb1fd1
 
4e7c6ea
 
 
 
 
03c95d2
f7072bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image  
import re
import os



# Hugging Face access token read from the environment; None when unset.
auth_token = os.environ.get("auth_token")
model_id = "CompVis/stable-diffusion-v1-4"
device = "cpu"
# Load the pretrained pipeline (authenticated with the token above) and
# move it onto the chosen device in one expression.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=auth_token,
).to(device)

_SAFE_IMAGE_PATH = "unsafe.png"
_safe_image_cache = None


def _load_safe_image():
    """Return the placeholder image substituted for flagged generations.

    The file is read only on first use, fully copied into memory inside a
    ``with`` block (so the underlying file handle is closed), and cached for
    later calls.
    """
    global _safe_image_cache
    if _safe_image_cache is None:
        with Image.open(_SAFE_IMAGE_PATH) as placeholder:
            _safe_image_cache = placeholder.copy()
    return _safe_image_cache


def infer(prompt, samples, steps, scale, seed):
    """Generate ``samples`` images for ``prompt`` with the global pipeline.

    Args:
        prompt: Text prompt; repeated once per requested image.
        samples: Number of images to generate (coerced to int).
        steps: Number of inference steps (coerced to int).
        scale: Classifier-free guidance scale.
        seed: Seed for the torch generator (coerced to int).

    Returns:
        A list of images in generation order. Any image the pipeline flags
        as NSFW is replaced by the placeholder loaded from ``unsafe.png``.
    """
    # Slider values may arrive as floats; list repetition and manual_seed
    # both require ints.
    samples, steps, seed = int(samples), int(steps), int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    output = pipe(
        [prompt] * samples,
        num_inference_steps=steps,
        guidance_scale=scale,
        generator=generator,
    )
    # This code was written against an output keyed "sample"; newer diffusers
    # releases expose the images under "images". Accept either.
    images = output["images"] if "images" in output else output["sample"]
    # The flag list can be None when no safety checker ran; treat that as
    # "nothing flagged" instead of crashing on indexing.
    flags = output.get("nsfw_content_detected") or [False] * len(images)
    return [
        _load_safe_image() if flagged else image
        for image, flagged in zip(images, flags)
    ]
    


# Client-side show/hide toggle for the advanced-options row. The original
# click handler had neither a Python function nor JavaScript, so the button
# was a no-op. The lookup falls back from the gradio-app shadow root to the
# element itself to the whole document, so it works with or without shadow DOM.
_ADVANCED_TOGGLE_JS = """
() => {
    const app = document.querySelector("gradio-app");
    const root = (app && app.shadowRoot) || app || document;
    const options = root.querySelector("#advanced-options");
    if (options) {
        options.style.display = options.style.display === "none" ? "" : "none";
    }
}
"""

block = gr.Blocks()

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")

        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

        # Pressing Enter in the prompt box and clicking the button both run
        # the same generation with the same inputs.
        generate_inputs = [text, samples, steps, scale, seed]
        text.submit(infer, inputs=generate_inputs, outputs=gallery)
        btn.click(infer, inputs=generate_inputs, outputs=gallery)
        advanced_button.click(
            None,
            [],
            text,
            _js=_ADVANCED_TOGGLE_JS,
        )

block.launch()