uruguayai commited on
Commit
591cbd9
·
verified ·
1 Parent(s): 5849d75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -41
app.py CHANGED
@@ -5,7 +5,6 @@ from PIL import Image
5
  import numpy as np
6
  from torchvision import transforms
7
 
8
- # Set up device
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # Initialize the inpainting model
@@ -18,65 +17,38 @@ except Exception as e:
18
  # Load a test image
19
  def load_test_image():
20
  try:
21
- return Image.open("path/to/your/test_image.png") # Replace with a valid path
 
22
  except Exception as e:
23
  print(f"Error loading test image: {e}")
24
  return None
25
 
26
  def process_image(prompt, image, style, upscale_factor, inpaint):
27
- try:
28
- # Use a test image if no image is received
29
  if image is None:
30
- image = load_test_image()
31
- if image is None:
32
- return None, "No image received and failed to load test image."
33
 
 
34
  print(f"Received image type: {type(image)}")
35
-
36
  if isinstance(image, np.ndarray):
37
- print(f"Image shape: {image.shape}")
38
  image = Image.fromarray(image)
39
  elif isinstance(image, torch.Tensor):
40
- print(f"Image tensor shape: {image.shape}")
41
  image = transforms.ToPILImage()(image)
42
- elif isinstance(image, Image.Image):
43
- print("Image is already in PIL format.")
44
- else:
45
  return None, f"Unsupported image format: {type(image)}."
46
-
47
- # Check if the image is valid
48
- if not isinstance(image, Image.Image):
49
- return None, "Error: Image format conversion failed."
50
-
51
- # Log the input parameters
52
- print(f"Prompt: {prompt}")
53
- print(f"Style: {style}")
54
- print(f"Upscale Factor: {upscale_factor}")
55
- print(f"Inpaint: {inpaint}")
56
-
57
- # Example placeholder logic for using the pipeline
58
- if inpaint and inpaint_model:
59
- result = inpaint_model(prompt=prompt, image=image, guidance_scale=7.5)
60
- else:
61
- result = inpaint_model(prompt=prompt, guidance_scale=7.5) if inpaint_model else None
62
-
63
- # Check if the result is valid
64
- if result and hasattr(result, 'images') and len(result.images) > 0:
65
- return result.images[0], None # Return image and no error
66
- else:
67
- return None, "Error: No image returned from model." # Return no image and an error message
68
-
69
  except Exception as e:
70
  error_message = f"Error in process_image function: {e}"
71
  print(error_message)
72
- return None, error_message # Return no image and the error message
73
 
74
- # Define the Gradio interface
75
  with gr.Blocks() as demo:
76
  with gr.Row():
77
  with gr.Column():
78
  prompt_input = gr.Textbox(label="Enter your prompt")
79
- image_input = gr.Image(label="Image (for inpainting)", type="pil") # Ensure type is PIL
80
  style_input = gr.Dropdown(choices=["Fooocus Style", "SAI Anime"], label="Select Style")
81
  upscale_input = gr.Slider(minimum=1, maximum=4, step=1, label="Upscale Factor")
82
  inpaint_input = gr.Checkbox(label="Enable Inpainting")
@@ -88,8 +60,7 @@ with gr.Blocks() as demo:
88
  generate_button.click(
89
  process_image,
90
  inputs=[prompt_input, image_input, style_input, upscale_input, inpaint_input],
91
- outputs=[output_image, error_output] # Handle both image and error output
92
  )
93
 
94
- # Launch the interface
95
  demo.launch()
 
5
  import numpy as np
6
  from torchvision import transforms
7
 
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
  # Initialize the inpainting model
 
17
  # Load a test image
18
  def load_test_image():
19
  try:
20
+ # Provide the absolute path to a test image
21
+ return Image.open("/absolute/path/to/your/test_image.png")
22
  except Exception as e:
23
  print(f"Error loading test image: {e}")
24
  return None
25
 
26
  def process_image(prompt, image, style, upscale_factor, inpaint):
27
+ if image is None:
28
+ image = load_test_image()
29
  if image is None:
30
+ return None, "No image received and failed to load test image."
 
 
31
 
32
+ try:
33
  print(f"Received image type: {type(image)}")
 
34
  if isinstance(image, np.ndarray):
 
35
  image = Image.fromarray(image)
36
  elif isinstance(image, torch.Tensor):
 
37
  image = transforms.ToPILImage()(image)
38
+ elif not isinstance(image, Image.Image):
 
 
39
  return None, f"Unsupported image format: {type(image)}."
40
+
41
+ return image, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  except Exception as e:
43
  error_message = f"Error in process_image function: {e}"
44
  print(error_message)
45
+ return None, error_message
46
 
 
47
  with gr.Blocks() as demo:
48
  with gr.Row():
49
  with gr.Column():
50
  prompt_input = gr.Textbox(label="Enter your prompt")
51
+ image_input = gr.Image(label="Image (for inpainting)", type="pil")
52
  style_input = gr.Dropdown(choices=["Fooocus Style", "SAI Anime"], label="Select Style")
53
  upscale_input = gr.Slider(minimum=1, maximum=4, step=1, label="Upscale Factor")
54
  inpaint_input = gr.Checkbox(label="Enable Inpainting")
 
60
  generate_button.click(
61
  process_image,
62
  inputs=[prompt_input, image_input, style_input, upscale_input, inpaint_input],
63
+ outputs=[output_image, error_output]
64
  )
65
 
 
66
  demo.launch()