m / app.py
LobsterQQQ's picture
Create app.py
f1c7b01
raw
history blame
798 Bytes
from contextlib import nullcontext
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
# Select the execution device and the matching settings in one place:
# on CUDA, weights are loaded as float16 and `context` is torch.autocast;
# on CPU, weights stay float32 and `context` is a no-op nullcontext.
if torch.cuda.is_available():
    device, dtype, context = "cuda", torch.float16, autocast
else:
    device, dtype, context = "cpu", torch.float32, nullcontext

# Load the pretrained checkpoint with the chosen dtype and move it to the
# selected device (`.to` returns the pipeline itself, so it can be chained).
pipe = StableDiffusionPipeline.from_pretrained(
    "ringhyacinth/nail-set-diffuser",
    torch_dtype=dtype,
).to(device)
# Disable the NSFW safety checker by swapping in a pass-through replacement.
disable_safety = True
if disable_safety:
    def null_safety(images, **kwargs):
        """Pass-through stand-in for the pipeline's safety checker.

        Args:
            images: The batch of decoded images handed over by the pipeline.
            **kwargs: Any other arguments the pipeline supplies (e.g.
                ``clip_input``); ignored.

        Returns:
            A tuple ``(images, flags)``. ``images`` is returned unchanged.
            ``flags`` is a list with one ``False`` per image (``len(images)``,
            i.e. the size of the first dimension).
        """
        # Recent diffusers releases iterate over the second value
        # (one "NSFW detected" flag per image). A bare `False` therefore
        # raised "TypeError: 'bool' object is not iterable".
        return images, [False] * len(images)

    pipe.safety_checker = null_safety
def infer(prompt, n_samples, steps, scale):
    """Generate ``n_samples`` images for ``prompt`` with the loaded pipeline.

    Args:
        prompt: Text prompt; it is repeated once per requested image.
        n_samples: Number of images to generate. Coerced to ``int`` because
            UI widgets (e.g. Gradio sliders) may deliver numbers as floats,
            and ``float * list`` raises ``TypeError``.
        steps: Number of denoising steps, passed as ``num_inference_steps``;
            also coerced to ``int``.
        scale: Guidance scale, passed through as ``guidance_scale``.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    prompts = [prompt] * int(n_samples)
    # `context` is torch.autocast on CUDA and a no-op nullcontext on CPU;
    # passing `device` is equivalent to the former hard-coded "cuda" because
    # autocast is only selected when device == "cuda".
    with context(device):
        output = pipe(
            prompts,
            guidance_scale=scale,
            num_inference_steps=int(steps),
        )
    return output.images