File size: 1,841 Bytes
683afc3
88d1237
3aadc38
 
 
0737dc8
3aadc38
38e6a4b
4fbc46c
c1497a6
3aadc38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91a655a
3aadc38
91a655a
3aadc38
 
 
91a655a
 
 
 
 
 
 
 
 
 
 
 
3aadc38
 
 
91a655a
 
 
3aadc38
 
 
91a655a
 
 
 
3aadc38
a4cc7b2
91a655a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import torch
from PIL import Image
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when this script runs unattended, so only log in when a token is
# actually set; otherwise any previously stored credentials are left untouched.
token = os.getenv("HF_TOKEN")
if token:
    login(token=token)

# Hub repository ids and the local IP-Adapter checkpoint used below.
model_path = "stabilityai/stable-diffusion-3.5-large"
ip_adapter_path = "./ip-adapter.bin"
image_encoder_path = "google/siglip-so400m-patch14-384"

# Load the transformer from the "transformer" subfolder of the model repo in
# bfloat16, then build the full pipeline around it.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Attach the IP-Adapter weights and the image encoder to the pipeline.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)


# NOTE: this function is intentionally left undecorated. gr.Interface needs
# fn/inputs/outputs arguments, so it cannot be applied as a bare decorator; the
# function is passed to gr.Interface(...) where the UI is built.
def gui_generation(image: Image.Image, style_image: Image.Image):
    """
    Generate an image conditioned on a style reference image.

    Args:
        image: Input image from the UI. It is accepted so the signature matches
            the two UI inputs, but it is not forwarded to the pipeline.
        style_image: Reference image passed to the pipeline as ``clip_image``.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no style image was provided.
    """
    if style_image is None:
        # Surface a readable message in the UI instead of failing inside the
        # pipeline when the style image slot was left empty.
        raise gr.Error("Please upload a style image.")

    # Fixed seed so repeated runs with the same style image are reproducible.
    generator = torch.Generator("cuda").manual_seed(42)

    result = pipe(
        width=1024,
        height=1024,
        prompt="",
        negative_prompt="",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=generator,
        clip_image=style_image,
        ipadapter_scale=0.5,
    )
    return result.images[0]


# UI components: two image inputs delivered as PIL images, and one output image.
image_input = gr.Image(label="Input Image", type="pil")
style_image_input = gr.Image(label="Style Image", type="pil")
output_image = gr.Image(label="Generated Image")

# Bind the generation function to the components.
interface = gr.Interface(
    fn=gui_generation,
    inputs=[image_input, style_image_input],
    outputs=output_image,
    title="Image Generation with Style Image",
    description="Upload an input image and a style image to generate a new image based on the style.",
)

# Start the web app.
interface.launch()