import os

import gradio as gr
import torch
from huggingface_hub import login
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# SD3.5 weights are gated on the Hub, so authenticate with the HF_TOKEN
# secret before downloading the model.
token = os.getenv("HF_TOKEN")
login(token=token)

# Model and paths
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"

# Load the SD3.5 pipeline and components
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
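
# The local StableDiffusion3Pipeline variant extends the stock diffusers
# pipeline with IP-Adapter support (the clip_image and ipadapter_scale
# arguments used below).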
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
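
# Wire up the IP-Adapter weights and the SigLIP image encoder; nb_token sets
# the number of image tokens the adapter injects per reference image (64 here).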
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)

def gui_generation(image: Image.Image, style_image: Image.Image):
    """
    Generate a 1024x1024 image styled after the reference image.

    Note: the first input image is currently unused; only the style image
    conditions generation, via the IP-Adapter's `clip_image` argument.
    """
    generator = torch.Generator("cuda").manual_seed(42)  # fixed seed for reproducibility
    output = pipe(
        width=1024,
        height=1024,
        prompt="",
        negative_prompt="",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=generator,
        clip_image=style_image,
        ipadapter_scale=0.5,
    ).images[0]
    return output
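
# Hypothetical direct call (outside the UI), assuming a local style image;
# None is fine for the first argument since it is unused:
#   styled = gui_generation(None, Image.open("style.jpg"))
#   styled.save("styled.png")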
# Gradio UI elements
image_input = gr.Image(type="pil", label="Input Image")
style_image_input = gr.Image(type="pil", label="Style Image")
output_image = gr.Image(label="Generated Image")
interface = gr.Interface(
    gui_generation,
    inputs=[image_input, style_image_input],
    outputs=output_image,
    title="Image Generation with Style Image",
    description="Upload an input image and a style image to generate a new image based on the style.",
)

interface.launch()