Kyan14 commited on
Commit
abd7ad6
·
1 Parent(s): 85fd9d1

Trying to prevent black image

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -43,8 +43,11 @@ def get_mood_from_image(image: Image.Image):
43
  selected_mood = max(mood_scores, key=mood_scores.get)
44
 
45
  return selected_mood
 
 
 
46
 
47
- def generate_art(mood):
48
  # Implement art generation logic using the Stable Diffusion API
49
  prompt = f"{mood} generative art with vibrant colors and intricate patterns ({str(np.random.randint(1, 10000))})"
50
 
@@ -57,7 +60,8 @@ def generate_art(mood):
57
  "inputs": prompt
58
  }
59
 
60
- while True:
 
61
  response = requests.post('https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5', headers=headers, json=json_data)
62
 
63
  if response.status_code == 503:
@@ -71,13 +75,21 @@ def generate_art(mood):
71
  print(response.content)
72
  return None
73
 
74
- break
 
 
 
 
75
 
76
- image = Image.open(BytesIO(response.content))
 
 
 
77
 
78
  return image
79
 
80
 
 
81
  def mood_art_generator(image):
82
  mood = get_mood_from_image(image)
83
  print("Mood:", mood)
 
43
  selected_mood = max(mood_scores, key=mood_scores.get)
44
 
45
  return selected_mood
46
+ def is_black_image(image: Image.Image) -> bool:
47
+ img_array = np.array(image)
48
+ return np.all(img_array == 0)
49
 
50
+ def generate_art(mood, max_retries=10):
51
  # Implement art generation logic using the Stable Diffusion API
52
  prompt = f"{mood} generative art with vibrant colors and intricate patterns ({str(np.random.randint(1, 10000))})"
53
 
 
60
  "inputs": prompt
61
  }
62
 
63
+ retries = 0
64
+ while retries < max_retries:
65
  response = requests.post('https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5', headers=headers, json=json_data)
66
 
67
  if response.status_code == 503:
 
75
  print(response.content)
76
  return None
77
 
78
+ image = Image.open(BytesIO(response.content))
79
+
80
+ # Check if the image is black
81
+ if not is_black_image(image):
82
+ break
83
 
84
+ retries += 1
85
+
86
+ if retries == max_retries:
87
+ return None
88
 
89
  return image
90
 
91
 
92
+
93
  def mood_art_generator(image):
94
  mood = get_mood_from_image(image)
95
  print("Mood:", mood)