import os
import requests
import torch
import gradio as gr
import spaces
from huggingface_hub import login
from diffusers.utils import load_image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# ----------------------------
# Step 1: Download the IP-Adapter weights if not already present
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"

if not os.path.exists(file_path):
    print("File not found, downloading...")
    response = requests.get(url, stream=True)
    response.raise_for_status()  # abort early if the download request failed
    with open(file_path, "wb") as file:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                file.write(chunk)
    print("Download completed!")

# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
login(token=token)

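# ----------------------------
# Step 3: Model and adapter paths
# ----------------------------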
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

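# ----------------------------
# Step 4: Load the SD3.5 transformer and build the pipeline
# ----------------------------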
transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

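# ----------------------------
# Step 5: Attach the IP-Adapter to the pipeline
# ----------------------------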
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)


# ----------------------------
# Step 6: Gradio Function
# ----------------------------
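# @spaces.GPU allocates a GPU for this call when the app runs on Hugging Face ZeroGPU Spaces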
@spaces.GPU
def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
    # Load the uploaded reference image from its local file path
    ref_img = load_image(ref_img.name).convert('RGB')

    # Note: SD3.5 Large is sensitive to high-resolution generation (e.g. 1536x1536), so 1024x1024 is used here
    image = pipe(
        width=1024,
        height=1024,
        prompt=prompt,
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=guidance_scale,
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=ref_img,
        ipadapter_scale=ipadapter_scale,
    ).images[0]

    return image


# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
ref_img = gr.File(label="Upload Reference Image")
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt"
)

ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt"
)

interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
)

# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
interface.launch()