# Hugging Face Space "test_gradio" by amos1088 — app.py
# Last commit: 4f91ffe ("uuu"); file size 2.66 kB.
import os
import requests
# Fetch the InstantX SD3.5-Large IP-Adapter weights once; later runs reuse the
# local copy.
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"


def _download_file(src_url, dest_path, chunk_size=1 << 20, timeout=60):
    """Stream ``src_url`` into ``dest_path`` without ever leaving a partial file.

    The body is written to a ``<dest_path>.part`` sibling first and only renamed
    onto ``dest_path`` once the whole response has been read. An interrupted or
    failed download therefore cannot leave a truncated file that the existence
    check below would later mistake for complete weights.

    Args:
        src_url: URL to download.
        dest_path: Final location of the downloaded file.
        chunk_size: Bytes read per iteration of the streamed response.
        timeout: Seconds ``requests`` waits on connect/read before giving up.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Without this check an HTML/JSON error page would be saved as the
            weights file.
        requests.RequestException: On connection problems or timeouts.
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(src_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip empty chunks
                        file.write(chunk)
        # Atomic rename: dest_path only ever appears fully written.
        os.replace(tmp_path, dest_path)
    finally:
        # On success tmp_path has already been renamed away; on failure this
        # drops the partial download so the next start retries cleanly.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# Check if the file already exists
if not os.path.exists(file_path):
    print("File not found, downloading...")
    _download_file(url, file_path)
    print("Download completed!")
else:
    print("File already exists.")
from models.transformer_sd3 import SD3Transformer2DModel
import gradio as gr
import torch
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import os
import spaces
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
# (e.g. a Space secret). The model repository loaded below presumably requires
# authentication (gated access) — TODO confirm.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
else:
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered in a headless deployment. Continue unauthenticated instead and
    # let any access-restricted download fail with the Hub's own error.
    print("HF_TOKEN is not set; skipping Hugging Face Hub login.")
# Model and Pipeline Setup
model_path = "stabilityai/stable-diffusion-3.5-large"
# Reuse the download target defined above so the downloader and the loader
# cannot drift apart (previously a second hard-coded './ip-adapter.bin').
ip_adapter_path = file_path
image_encoder_path = "google/siglip-so400m-patch14-384"

# Load the project-local SD3 transformer in bfloat16 and hand it to the
# IP-Adapter-aware pipeline. Move the whole pipeline to CUDA once, at start-up.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Initialize IP Adapter: attach the adapter weights and the SigLIP image encoder.
# NOTE(review): nb_token=64 is presumably the number of image tokens the adapter
# was trained with — confirm against the InstantX model card.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
@spaces.GPU
def gui_generation(text, num_imgs, width, height, seed=42):
    """Generate a batch of images from a text prompt with Stable Diffusion 3.5.

    Args:
        text: The prompt.
        num_imgs: Number of images to generate (the "Batch size" slider).
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Seed for the CUDA generator. The default of 42 keeps the previous
            behaviour: repeated calls with the same arguments yield the same
            images.

    Returns:
        The pipeline output's ``.images``, displayed by the Gradio gallery.

    NOTE(review): no reference image is passed here, so the IP-Adapter
    initialized at start-up is not exercised by this call — confirm whether
    that is intended.
    """
    # Gradio sliders hand values to the callback as floats. The pipeline uses
    # them as a batch count and pixel dimensions, which must be integers.
    images = pipe(
        prompt=text,
        width=int(width),
        height=int(height),
        num_images_per_prompt=int(num_imgs),
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=torch.Generator("cuda").manual_seed(int(seed)),
    ).images
    return images
# Create Gradio interface.
# Layout, top to bottom:
#   1. title
#   2. prompt + batch size
#   3. output width + height
#   4. result gallery
#   5. "Generate" button, which runs gui_generation.
with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion 3.5 Image Generation")

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your image generation prompt",
        )
        # NOTE(review): up to 30 images per click may exceed GPU memory or the
        # GPU time budget at large sizes — confirm the intended upper bound.
        number_slider = gr.Slider(
            minimum=1, maximum=30, value=2, step=1, label="Batch size"
        )

    with gr.Row():
        # Width and height share one range and a 64-pixel step.
        size_slider_kwargs = dict(minimum=256, maximum=1536, value=1024, step=64)
        width_slider = gr.Slider(label="Width", **size_slider_kwargs)
        height_slider = gr.Slider(label="Height", **size_slider_kwargs)

    gallery = gr.Gallery(
        columns=[3],
        rows=[1],
        object_fit="contain",
        height="auto",
    )
    generate_btn = gr.Button("Generate")

    # Order must match gui_generation(text, num_imgs, width, height).
    generation_inputs = [prompt_box, number_slider, width_slider, height_slider]
    generate_btn.click(
        fn=gui_generation,
        inputs=generation_inputs,
        outputs=gallery,
    )

demo.launch()