Kyan14 commited on
Commit
c9985bd
·
1 Parent(s): 3cc4e4b

Had an issue of black images being generated

Browse files

Idea is to keep looping it in clip to classify it as black image or not and if it is keep generating till it isnt

Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -44,6 +44,10 @@ def get_mood_from_image(image: Image.Image):
44
 
45
  return selected_mood
46
 
 
 
 
 
47
  def generate_art(mood):
48
  # Implement art generation logic using the Stable Diffusion API
49
  prompt = f"{mood} generative art with vibrant colors and intricate patterns ({str(np.random.randint(1, 10000))})"
@@ -58,22 +62,25 @@ def generate_art(mood):
58
  }
59
 
60
  while True:
61
- response = requests.post('https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5', headers=headers, json=json_data)
62
-
63
- if response.status_code == 503:
64
- print("Model is loading, waiting for 30 seconds before retrying...")
65
- time.sleep(30)
66
- continue
67
-
68
- if response.status_code != 200:
69
- print(f"Error: API response status code {response.status_code}")
70
- print("Response content:")
71
- print(response.content)
72
- return None
73
-
74
- break
75
-
76
- image = Image.open(BytesIO(response.content))
 
 
 
77
 
78
  return image
79
 
 
44
 
45
  return selected_mood
46
 
47
+ def is_black_image(image: Image.Image) -> bool:
48
+ img_array = np.array(image)
49
+ return np.all(img_array == 0)
50
+
51
  def generate_art(mood):
52
  # Implement art generation logic using the Stable Diffusion API
53
  prompt = f"{mood} generative art with vibrant colors and intricate patterns ({str(np.random.randint(1, 10000))})"
 
62
  }
63
 
64
  while True:
65
+ while True:
66
+ response = requests.post('https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5', headers=headers, json=json_data)
67
+
68
+ if response.status_code == 503:
69
+ print("Model is loading, waiting for 30 seconds before retrying...")
70
+ time.sleep(30)
71
+ continue
72
+
73
+ if response.status_code != 200:
74
+ print(f"Error: API response status code {response.status_code}")
75
+ print("Response content:")
76
+ print(response.content)
77
+ return None
78
+
79
+ image = Image.open(BytesIO(response.content))
80
+
81
+ # Check if the image is black
82
+ if not is_black_image(image):
83
+ break
84
 
85
  return image
86