# Hugging Face Space app (Space status when captured: Paused)
import os
import requests

# Fetch the IP-Adapter weights once; later starts reuse the local file.
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"
# Check if the file already exists
if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a temporary name and rename only after the whole body has
    # been written. An interrupted or failed download therefore never leaves
    # a truncated "ip-adapter.bin" that the existence check above would later
    # mistake for a complete file.
    tmp_path = file_path + ".part"
    # The context manager releases the streamed connection; the timeout
    # bounds the connect and each read so a stalled server cannot hang start-up.
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors (e.g. 401/404) instead of saving the
        # error page as the weights file.
        response.raise_for_status()
        with open(tmp_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    file.write(chunk)
    os.replace(tmp_path, file_path)
    print("Download completed!")
else:
    print("File already exists.")
from models.transformer_sd3 import SD3Transformer2DModel
import gradio as gr
import torch
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import os
import spaces
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the HF_TOKEN secret, if one is
# configured (presumably needed for the gated stabilityai weights below).
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)
else:
    # login(token=None) would fall back to an interactive prompt, which cannot
    # be answered in a non-interactive container; skip it instead and let any
    # later download report its own authentication error.
    print("HF_TOKEN is not set; skipping Hugging Face login.")
# ---- Model and pipeline setup ----
# Hub id of the base model, the local IP-Adapter weights file, and the vision
# encoder handed to the adapter.
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"

# The transformer comes from the local models.transformer_sd3 module and is
# loaded on its own (repo subfolder "transformer", bfloat16). It is then
# passed to the pipeline explicitly.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and its image encoder to the pipeline.
# NOTE(review): nb_token=64 presumably has to match the adapter checkpoint;
# confirm before changing it.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
# `spaces` is imported by this script but was never applied. @spaces.GPU
# requests a GPU for each call on ZeroGPU hardware and is documented to have
# no effect on other hardware.
@spaces.GPU
def gui_generation(text, num_imgs, width, height):
    """Generate images for a prompt with the Stable Diffusion 3.5 pipeline.

    Args:
        text: Prompt text.
        num_imgs: Number of images to generate for the prompt.
        width: Output image width in pixels.
        height: Output image height in pixels.

    Returns:
        The ``.images`` attribute of the pipeline output, which the Gradio
        UI displays in its gallery.

    NOTE(review): no reference image is passed to ``pipe``, so the IP-Adapter
    initialised at start-up receives no image here. Confirm the pipeline is
    meant to be called without one.
    """
    # The values come from Gradio sliders. Coerce them defensively so the
    # pipeline always receives integer counts and sizes, even if a float
    # such as 2.0 comes through.
    num_imgs = int(num_imgs)
    width = int(width)
    height = int(height)

    # A fresh CUDA generator seeded with 42 on every call, so identical
    # inputs start from the same initial noise.
    generator = torch.Generator("cuda").manual_seed(42)

    result = pipe(
        prompt=text,
        width=width,
        height=height,
        num_images_per_prompt=num_imgs,
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=generator,
    )
    return result.images
# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion 3.5 Image Generation")

    # First row: prompt text and how many images to generate.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Enter your image generation prompt",
        )
        number_slider = gr.Slider(
            minimum=1, maximum=30, value=2, step=1, label="Batch size"
        )

    # Second row: output resolution. Starting at 256 with a step of 64 keeps
    # both sides on multiples of 64.
    with gr.Row():
        width_slider = gr.Slider(
            minimum=256, maximum=1536, value=1024, step=64, label="Width"
        )
        height_slider = gr.Slider(
            minimum=256, maximum=1536, value=1024, step=64, label="Height"
        )

    gallery = gr.Gallery(
        columns=[3], rows=[1], object_fit="contain", height="auto"
    )
    generate_btn = gr.Button("Generate")

    # Order must match gui_generation(text, num_imgs, width, height).
    generation_inputs = [prompt_box, number_slider, width_slider, height_slider]
    generate_btn.click(
        fn=gui_generation,
        inputs=generation_inputs,
        outputs=gallery,
    )

demo.launch()