File size: 3,477 Bytes
4f91ffe
 
51f8f41
 
 
 
6dcb6b3
f40f178
 
 
51f8f41
6dcb6b3
 
 
4f91ffe
 
 
 
 
 
 
 
 
 
 
1ae6c5e
6dcb6b3
 
 
4fbc46c
6dcb6b3
 
c1497a6
3aadc38
f40f178
3ace7e2
 
5a35e98
f40f178
 
 
5a35e98
f40f178
 
 
3ace7e2
f40f178
 
3ace7e2
 
 
d5f11d4
5a35e98
6dcb6b3
 
 
 
7957b28
a2919a7
b2d0aef
f40f178
9f78050
f40f178
b2d0aef
 
 
 
7957b28
6f3c8de
b2d0aef
f40f178
b2d0aef
f40f178
b2d0aef
51f8f41
 
 
d5f11d4
6dcb6b3
 
 
e129330
7957b28
 
e129330
 
 
 
 
 
 
51f8f41
 
e129330
 
 
 
 
 
 
 
 
 
 
 
7957b28
bfa42ad
51f8f41
 
e129330
 
6dcb6b3
 
 
d09f5de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import requests
import torch
import gradio as gr
import spaces
from huggingface_hub import login
from diffusers.utils import load_image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

# ----------------------------
# Step 1: Download IP Adapter if not exists
# ----------------------------
url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
file_path = "ip-adapter.bin"

if not os.path.exists(file_path):
    print("File not found, downloading...")
    # Stream into a temporary sibling file and rename it only after the whole
    # body has been written. The existence check above would otherwise accept
    # a truncated file left behind by an interrupted download.
    partial_path = file_path + ".part"
    try:
        # (connect timeout, read timeout) so a stalled connection cannot hang
        # start-up forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as weights.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                # 1 MiB chunks: far fewer Python-level iterations than 1 KiB.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        file.write(chunk)
        # Atomic on the same filesystem: the final name appears only when complete.
        os.replace(partial_path, file_path)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download completed!")

# ----------------------------
# Step 2: Hugging Face Login
# ----------------------------
# The access token is taken from the environment; without a non-empty value
# the script stops here with a ValueError.
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")

# Where the base model, the IP-Adapter weights and the image encoder come from.
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"

# Both from_pretrained calls below load their weights in bfloat16.
_dtype_kwargs = {"torch_dtype": torch.bfloat16}

# The transformer is loaded through the project-local SD3Transformer2DModel
# class and then handed to the pipeline explicitly.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    **_dtype_kwargs,
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    **_dtype_kwargs,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)


# ----------------------------
# Step 6: Gradio Function
# ----------------------------
@spaces.GPU
def gui_generation(prompt, negative_prompt, ref_img, guidance_scale, ipadapter_scale):
    """Generate one 1024x1024 image from a text prompt and a reference image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative text prompt forwarded to the pipeline.
        ref_img: Uploaded reference image as delivered by the ``gr.File``
            input: either an object exposing the file path as ``.name`` or a
            plain path string.
        guidance_scale: Value passed as ``guidance_scale`` to the pipeline.
        ipadapter_scale: Value passed as ``ipadapter_scale`` to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.

    Raises:
        gr.Error: If no reference image was uploaded.
    """
    # Without an upload the input arrives as None; report that to the user
    # instead of failing with an AttributeError on ``None.name``.
    if ref_img is None:
        raise gr.Error("Please upload a reference image.")

    # Use ``.name`` when the upload object provides it, otherwise treat the
    # value itself as the path.
    ref_path = getattr(ref_img, "name", ref_img)
    reference = load_image(ref_path).convert("RGB")

    # Fixed seed: identical inputs always use the same generator state.
    generator = torch.Generator("cuda").manual_seed(42)

    # please note that SD3.5 Large is sensitive to highres generation like 1536x1536
    result = pipe(
        width=1024,
        height=1024,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=35,
        guidance_scale=guidance_scale,
        generator=generator,
        clip_image=reference,
        ipadapter_scale=ipadapter_scale,
    )

    return result.images[0]


# ----------------------------
# Step 7: Gradio Interface
# ----------------------------
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
# The placeholder describes the negative prompt specifically (it used to repeat
# the positive-prompt hint); the default value is unchanged.
negative_prompt_box = gr.Textbox(
    label="Negative Prompt",
    placeholder="Enter what you do not want to see in the image",
    value="lowres, low quality, worst quality",
)

# Only image files can be used as a reference, so restrict the upload dialog.
ref_img = gr.File(label="Upload Reference Image", file_types=["image"])
guidance_slider = gr.Slider(
    label="Guidance Scale",
    minimum=2,
    maximum=16,
    value=7,
    step=0.5,
    info="Controls adherence to the text prompt",
)

ipadapter_slider = gr.Slider(
    label="IP-Adapter Scale",
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.1,
    info="Controls influence of the image prompt",
)

# The order of `inputs` must match the positional parameters of gui_generation.
interface = gr.Interface(
    fn=gui_generation,
    inputs=[prompt_box, negative_prompt_box, ref_img, guidance_slider, ipadapter_slider],
    outputs="image",
    title="Image Generation with Stable Diffusion 3.5 Large and IP-Adapter",
    description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter.",
)

# ----------------------------
# Step 8: Launch Gradio App
# ----------------------------
# Called at module top level (there is no ``__main__`` guard), so the app is
# launched whenever this file is executed or imported.
interface.launch()