# text2image / app.py
# Commit 958a329 by amitkayal: "Duplicate from imseldrith/text2image"
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline, PNDMScheduler
import requests
import PIL
from PIL import Image
import numpy as np
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
def preprocess(image):
    """Convert a PIL image into the input tensor for the img2img pipeline.

    The image is forced to RGB, its width and height are rounded down to a
    multiple of 32, it is resized with Lanczos filtering, and the pixel
    values are rescaled from [0, 255] to [-1, 1].

    Args:
        image: A PIL image of any mode.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, H, W)`` with values in
        ``[-1, 1]``, where ``H`` and ``W`` are the rounded-down dimensions.
    """
    # Force three channels: a greyscale ("L") image would give a 2-D array
    # that the NCHW transpose below cannot handle, and RGBA would give four.
    image = image.convert("RGB")
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first.
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * torch.from_numpy(pixels) - 1.0
# PNDM sampler for the text-to-image pipeline. The beta schedule has to match
# the one Stable Diffusion v1-4 was trained with (scaled_linear, 0.00085 to
# 0.012). A different range makes the sampler's noise levels disagree with
# the timestep conditioning the UNet learned.
scheduler_ = PNDMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)
# Text-to-image pipeline on the GPU, sampling with scheduler_.
# torch_dtype=torch.float16 keeps the half-precision "fp16" revision weights
# in half precision, matching pipeimg. Without it, from_pretrained loads the
# model in the default float32 dtype.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=scheduler_,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to("cuda")
# Image-to-image pipeline: a second, independently loaded copy of
# Stable Diffusion v1-4 in half precision. No scheduler override is passed,
# so it uses whatever scheduler the checkpoint ships with (not scheduler_).
pipeimg = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=True,
)
pipeimg = pipeimg.to("cuda")
# Number of images produced per request; infer() repeats the prompt this
# many times in both the text-to-image and image-to-image paths.
num_samples = 2

# Top-level Blocks container; the CSS centres the layout and caps its width.
block = gr.Blocks(
    css=".container { max-width: 800px; margin: auto; }",
)
def infer(prompt, init_image, strength):
    """Generate ``num_samples`` images for *prompt*.

    If *init_image* is given, run image-to-image: the picture is converted to
    RGB, resized to 512x512, preprocessed, and passed to ``pipeimg`` with the
    requested *strength*. Otherwise run plain text-to-image with ``pipe``.

    Args:
        prompt: Text prompt. It is repeated ``num_samples`` times so one call
            produces a batch of images.
        init_image: Optional PIL image from the UI, or ``None`` when no image
            was supplied.
        strength: Img2img strength forwarded to ``pipeimg``. Ignored in
            text-to-image mode.

    Returns:
        The ``"sample"`` entry of the pipeline output (the generated images
        displayed in the gallery).
    """
    prompts = [prompt] * num_samples
    if init_image is not None:
        # Force three channels before preprocess(): a greyscale image would
        # break its HWC -> NCHW transpose, and RGBA would give four channels.
        init_image = preprocess(init_image.convert("RGB").resize((512, 512)))
        with autocast("cuda"):
            images = pipeimg(
                prompts,
                init_image=init_image,
                strength=strength,
                guidance_scale=7.5,
            )["sample"]
    else:
        with autocast("cuda"):
            images = pipe(prompts, guidance_scale=7.5)["sample"]
    return images
with block as demo:
    gr.Markdown("<h1><center>Stable Diffusion</center></h1>")
    gr.Markdown(
        "Stable Diffusion is an AI model that generates images from any prompt you give!"
    )
    with gr.Group():
        with gr.Box():
            # Prompt box and Run button share one row. The border/rounded
            # flags presumably make them render as a single joined control.
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt", show_label=False, max_lines=1
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Run").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        # Only used by infer() when an initial image is supplied.
        strength_slider = gr.Slider(
            label="Strength",
            maximum=1,
            value=0.75,
        )
        # Optional input. When it is left empty, Gradio passes None and
        # infer() falls back to text-to-image.
        image = gr.Image(
            label="Initial Image",
            type="pil",
        )
        gallery = gr.Gallery(label="Generated images", show_label=False).style(
            grid=[2], height="auto"
        )
        # Pressing Enter in the prompt box and clicking Run both trigger the
        # same generation with the same inputs.
        infer_inputs = [text, image, strength_slider]
        text.submit(infer, inputs=infer_inputs, outputs=gallery)
        btn.click(infer, inputs=infer_inputs, outputs=gallery)
    gr.Markdown(
        """___
<p style='text-align: center'>
Created by CompVis and Stability AI
<br/>
</p>"""
    )
demo.launch(share=True)