test_gradio / app.py
amos1088's picture
test gradio
97c3973
raw
history blame
2.11 kB
import gradio as gr
from huggingface_hub import login
import os
import spaces
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
import torch
from diffusers import AutoPipelineForText2Image, DDIMScheduler
from transformers import CLIPVisionModelWithProjection
from diffusers.utils import load_image
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which cannot
# be answered in a headless deployment.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# CLIP vision encoder shipped in the IP-Adapter repo, loaded in half precision
# and handed to the pipeline below as its image encoder.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
)

# SDXL base text-to-image pipeline, also in half precision.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    image_encoder=image_encoder,
)

# Replace the default scheduler with DDIM, reusing the existing scheduler config.
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

# Two IP-Adapters are loaded at once; every later call that takes per-adapter
# values (scales, reference images) must supply one entry per weight file,
# in this order.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)

# Initial per-adapter scales, in the same order as the weight files above.
pipeline.set_ip_adapter_scale([0.7, 0.3])
pipeline.enable_model_cpu_offload()
@spaces.GPU
def generate_image(prompt, reference_image, controlnet_conditioning_scale):
    """Generate one image from a text prompt, guided by a reference image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        reference_image: Filesystem path of the uploaded reference image
            (the Gradio input is declared with ``type="filepath"``), or
            ``None`` when nothing was uploaded.
        controlnet_conditioning_scale: Scale taken from the UI slider. It is
            applied to every loaded IP-Adapter; the parameter name is kept
            for interface compatibility.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no reference image was provided.
    """
    if not reference_image:
        # Surface a readable message in the UI instead of an opaque traceback.
        raise gr.Error("Please upload a reference image.")

    # convert() forces the pixel data to be read, so the underlying file can
    # be closed as soon as the with-block exits.
    with Image.open(reference_image) as opened:
        reference = opened.convert("RGB")

    # The pipeline needs exactly one scale and one image per loaded
    # IP-Adapter; a shorter list raises ValueError. Derive the count from the
    # pipeline so this stays correct if the set of loaded adapters changes.
    adapter_count = len(pipeline.unet.encoder_hid_proj.image_projection_layers)
    pipeline.set_ip_adapter_scale([controlnet_conditioning_scale] * adapter_count)

    result = pipeline(
        prompt=prompt,
        ip_adapter_image=[reference] * adapter_count,
        negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
        num_inference_steps=50,
        num_images_per_prompt=1,
    )
    return result.images[0]
# Set up the Gradio interface. The user-facing text names what the pipeline
# in this file actually loads: Stable Diffusion XL base 1.0 with IP-Adapter.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        # type="filepath" means generate_image receives a path string.
        gr.Image(type="filepath", label="Reference Image (Style)"),
        gr.Slider(label="IP-Adapter Scale", minimum=0, maximum=1.0, step=0.1, value=0.6),
    ],
    outputs="image",
    title="Image Generation with Stable Diffusion XL and IP-Adapter",
    description=(
        "Generates an image based on a text prompt and a reference image "
        "using Stable Diffusion XL with IP-Adapter."
    ),
)
interface.launch()