amos1088 committed on
Commit
6dcb6b3
·
1 Parent(s): 545ba28
Files changed (1) hide show
  1. app.py +40 -18
app.py CHANGED
@@ -5,13 +5,15 @@ import gradio as gr
5
  import spaces
6
  from PIL import Image
7
  from huggingface_hub import login
8
- from diffusers.utils import load_image
9
  from torchvision import transforms
 
10
 
11
  from models.transformer_sd3 import SD3Transformer2DModel
12
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
13
 
14
- # Download IP Adapter if not exists
 
 
15
  url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
16
  file_path = "ip-adapter.bin"
17
 
@@ -24,22 +26,30 @@ if not os.path.exists(file_path):
24
  file.write(chunk)
25
  print("Download completed!")
26
 
27
- # Hugging Face login
 
 
28
  token = os.getenv("HF_TOKEN")
 
 
29
  login(token=token)
30
 
31
- # Model paths
 
 
32
  model_path = 'stabilityai/stable-diffusion-3.5-large'
33
  ip_adapter_path = './ip-adapter.bin'
34
  image_encoder_path = "google/siglip-so400m-patch14-384"
35
 
36
- # Load transformer and pipeline
 
 
37
  transformer = SD3Transformer2DModel.from_pretrained(
38
- model_path, subfolder="transformer", torch_dtype=torch.bfloat16
39
  )
40
 
41
  pipe = StableDiffusion3Pipeline.from_pretrained(
42
- model_path, transformer=transformer, torch_dtype=torch.bfloat16
43
  ).to("cuda")
44
 
45
  pipe.init_ipadapter(
@@ -48,21 +58,29 @@ pipe.init_ipadapter(
48
  nb_token=64,
49
  )
50
 
51
-
52
- @spaces.GPU
53
- def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
54
-
55
- # Load and preprocess the reference image
56
  preprocess = transforms.Compose([
57
  transforms.Resize((384, 384)),
58
  transforms.ToTensor(),
59
  transforms.ConvertImageDtype(torch.float16)
60
  ])
 
 
61
 
62
- ref_img = Image.open(ref_img.name).convert('RGB')
63
- ref_img_tensor = preprocess(ref_img).unsqueeze(0).to("cuda")
 
 
 
 
 
 
64
 
65
- # Generate the image
66
  with torch.no_grad():
67
  image = pipe(
68
  width=1024,
@@ -78,8 +96,9 @@ def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
78
 
79
  return image
80
 
81
-
82
- # Set up Gradio interface
 
83
  prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
84
  ref_img = gr.File(label="Upload Reference Image")
85
  guidance_slider = gr.Slider(
@@ -108,4 +127,7 @@ interface = gr.Interface(
108
  description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
109
  )
110
 
111
- interface.launch(share=True)
 
 
 
 
5
  import spaces
6
  from PIL import Image
7
  from huggingface_hub import login
 
8
  from torchvision import transforms
9
+ from diffusers.utils import load_image
10
 
11
  from models.transformer_sd3 import SD3Transformer2DModel
12
  from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
13
 
14
+ # ----------------------------
15
+ # Step 1: Download IP Adapter if not exists
16
+ # ----------------------------
17
  url = "https://huggingface.co/InstantX/SD3.5-Large-IP-Adapter/resolve/main/ip-adapter.bin"
18
  file_path = "ip-adapter.bin"
19
 
 
26
  file.write(chunk)
27
  print("Download completed!")
28
 
29
+ # ----------------------------
30
+ # Step 2: Hugging Face Login
31
+ # ----------------------------
32
  token = os.getenv("HF_TOKEN")
33
+ if not token:
34
+ raise ValueError("Hugging Face token not found. Set the 'HF_TOKEN' environment variable.")
35
  login(token=token)
36
 
37
+ # ----------------------------
38
+ # Step 3: Model Paths
39
+ # ----------------------------
40
  model_path = 'stabilityai/stable-diffusion-3.5-large'
41
  ip_adapter_path = './ip-adapter.bin'
42
  image_encoder_path = "google/siglip-so400m-patch14-384"
43
 
44
+ # ----------------------------
45
+ # Step 4: Load Transformer and Pipeline
46
+ # ----------------------------
47
  transformer = SD3Transformer2DModel.from_pretrained(
48
+ model_path, subfolder="transformer", torch_dtype=torch.float16
49
  )
50
 
51
  pipe = StableDiffusion3Pipeline.from_pretrained(
52
+ model_path, transformer=transformer, torch_dtype=torch.float16
53
  ).to("cuda")
54
 
55
  pipe.init_ipadapter(
 
58
  nb_token=64,
59
  )
60
 
61
+ # ----------------------------
62
+ # Step 5: Image Preprocessing Function
63
+ # ----------------------------
64
+ def preprocess_image(image_path):
65
+ """Preprocess the input image for the pipeline."""
66
  preprocess = transforms.Compose([
67
  transforms.Resize((384, 384)),
68
  transforms.ToTensor(),
69
  transforms.ConvertImageDtype(torch.float16)
70
  ])
71
+ image = Image.open(image_path).convert('RGB')
72
+ return preprocess(image).unsqueeze(0).to("cuda")
73
 
74
+ # ----------------------------
75
+ # Step 6: Gradio Function
76
+ # ----------------------------
77
+ @spaces.GPU
78
+ def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
79
+ """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter."""
80
+ # Preprocess the reference image
81
+ ref_img_tensor = preprocess_image(ref_img.name)
82
 
83
+ # Run the pipeline
84
  with torch.no_grad():
85
  image = pipe(
86
  width=1024,
 
96
 
97
  return image
98
 
99
+ # ----------------------------
100
+ # Step 7: Gradio Interface
101
+ # ----------------------------
102
  prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your image generation prompt")
103
  ref_img = gr.File(label="Upload Reference Image")
104
  guidance_slider = gr.Slider(
 
127
  description="Generates an image based on a text prompt and a reference image using Stable Diffusion 3.5 Large with IP-Adapter."
128
  )
129
 
130
+ # ----------------------------
131
+ # Step 8: Launch Gradio App
132
+ # ----------------------------
133
+ interface.launch(share=True)