amos1088 committed on
Commit
d5f11d4
·
1 Parent(s): d09f5de
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -58,15 +58,17 @@ pipe.init_ipadapter(
58
  nb_token=64,
59
  )
60
 
 
61
  # ----------------------------
62
- # Step 5: Image Preprocessing Function
63
  # ----------------------------
64
  def preprocess_image(image_path):
65
- """Preprocess the input image for the pipeline."""
66
  image = Image.open(image_path).convert("RGB")
67
- # Ensure image is resized into a square based on the max dimension
68
- size = max(image.size)
69
- image = image.resize((size, size))
 
70
 
71
  preprocess = transforms.Compose([
72
  transforms.Resize((384, 384)),
@@ -75,14 +77,18 @@ def preprocess_image(image_path):
75
  ])
76
  return preprocess(image).unsqueeze(0).to("cuda")
77
 
 
78
  # ----------------------------
79
  # Step 6: Gradio Function
80
  # ----------------------------
81
  @spaces.GPU
82
  def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
83
  """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter."""
84
- # Preprocess the reference image
85
- ref_img_tensor = preprocess_image(ref_img.name)
 
 
 
86
 
87
  # Run the pipeline
88
  with torch.no_grad():
@@ -100,6 +106,7 @@ def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
100
 
101
  return image
102
 
 
103
  # ----------------------------
104
  # Step 7: Gradio Interface
105
  # ----------------------------
 
58
  nb_token=64,
59
  )
60
 
61
+
62
  # ----------------------------
63
+ # Step 5: Preprocess Reference Image
64
  # ----------------------------
65
  def preprocess_image(image_path):
66
+ """Ensure the input image is a valid PIL Image and resize it."""
67
  image = Image.open(image_path).convert("RGB")
68
+
69
+ # Ensure the image is resized into a square
70
+ size = max(image.size) # Get the largest dimension
71
+ image = image.resize((size, size), Image.BILINEAR)
72
 
73
  preprocess = transforms.Compose([
74
  transforms.Resize((384, 384)),
 
77
  ])
78
  return preprocess(image).unsqueeze(0).to("cuda")
79
 
80
+
81
  # ----------------------------
82
  # Step 6: Gradio Function
83
  # ----------------------------
84
  @spaces.GPU
85
  def gui_generation(prompt, ref_img, guidance_scale, ipadapter_scale):
86
  """Generate an image using Stable Diffusion 3.5 Large with IP-Adapter."""
87
+ try:
88
+ # Load and preprocess the reference image
89
+ ref_img_tensor = preprocess_image(ref_img.name)
90
+ except Exception as e:
91
+ raise ValueError(f"Error loading reference image: {e}")
92
 
93
  # Run the pipeline
94
  with torch.no_grad():
 
106
 
107
  return image
108
 
109
+
110
  # ----------------------------
111
  # Step 7: Gradio Interface
112
  # ----------------------------