AlGe committed on
Commit
8793006
·
verified ·
1 Parent(s): 9d8d3cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -105,15 +105,23 @@ def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str], file_pat
105
  image_path = os.path.join(base_path, file_path)
106
  if not os.path.exists(image_path):
107
  raise FileNotFoundError(f"Mask image file not found: {image_path}")
 
108
  mask_image = np.array(Image.open(image_path))
109
  mask_height, mask_width = mask_image.shape[:2]
 
110
  entity_texts = []
111
  entity_scores = []
112
  entity_types = []
 
113
  for entity in entities:
114
  print(f"E: {entity}")
 
115
  segment_text = ' '.join(entity['tokens'])
116
- #segment_text = re.sub(r'^\W+', '', segment_text)
 
 
 
 
117
  entity_texts.append(segment_text)
118
  if 'average_score' in entity:
119
  entity_scores.append(entity['average_score'])
@@ -121,20 +129,26 @@ def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str], file_pat
121
  entity_scores.append(0.5) # Example: Assigning a default score of 0.5 if 'average_score' is missing
122
  entity_types.append(entity['entity'])
123
  print(f"{segment_text} ({entity['entity']}): {entity.get('average_score', 0.5)}") # Print or log the score, using a default if missing
 
124
  word_freq = {text: score for text, score in zip(entity_texts, entity_scores)}
 
125
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
126
  entity_type = next((t for t, w in zip(entity_types, entity_texts) if w == word), None)
127
  return color_map.get(entity_type, "#FFFFFF")
 
128
  wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(word_freq)
 
129
  plt.figure(figsize=(mask_width/100, mask_height/100))
130
  plt.imshow(wordcloud, interpolation='bilinear')
131
  plt.axis('off')
132
  plt.tight_layout(pad=0)
 
133
  plt_image = plt.gcf()
134
  plt_image.canvas.draw()
135
  image_array = np.frombuffer(plt_image.canvas.tostring_rgb(), dtype=np.uint8)
136
  image_array = image_array.reshape(plt_image.canvas.get_width_height()[::-1] + (3,))
137
  plt.close()
 
138
  return image_array
139
 
140
  @spaces.GPU
 
105
  image_path = os.path.join(base_path, file_path)
106
  if not os.path.exists(image_path):
107
  raise FileNotFoundError(f"Mask image file not found: {image_path}")
108
+
109
  mask_image = np.array(Image.open(image_path))
110
  mask_height, mask_width = mask_image.shape[:2]
111
+
112
  entity_texts = []
113
  entity_scores = []
114
  entity_types = []
115
+
116
  for entity in entities:
117
  print(f"E: {entity}")
118
+ # Join tokens and process the text
119
  segment_text = ' '.join(entity['tokens'])
120
+ # Replace "▁" with normal spaces
121
+ segment_text = segment_text.replace("▁", " ")
122
+ # Remove all normal spaces
123
+ segment_text = segment_text.replace(" ", "")
124
+
125
  entity_texts.append(segment_text)
126
  if 'average_score' in entity:
127
  entity_scores.append(entity['average_score'])
 
129
  entity_scores.append(0.5) # Example: Assigning a default score of 0.5 if 'average_score' is missing
130
  entity_types.append(entity['entity'])
131
  print(f"{segment_text} ({entity['entity']}): {entity.get('average_score', 0.5)}") # Print or log the score, using a default if missing
132
+
133
  word_freq = {text: score for text, score in zip(entity_texts, entity_scores)}
134
+
135
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
136
  entity_type = next((t for t, w in zip(entity_types, entity_texts) if w == word), None)
137
  return color_map.get(entity_type, "#FFFFFF")
138
+
139
  wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(word_freq)
140
+
141
  plt.figure(figsize=(mask_width/100, mask_height/100))
142
  plt.imshow(wordcloud, interpolation='bilinear')
143
  plt.axis('off')
144
  plt.tight_layout(pad=0)
145
+
146
  plt_image = plt.gcf()
147
  plt_image.canvas.draw()
148
  image_array = np.frombuffer(plt_image.canvas.tostring_rgb(), dtype=np.uint8)
149
  image_array = image_array.reshape(plt_image.canvas.get_width_height()[::-1] + (3,))
150
  plt.close()
151
+
152
  return image_array
153
 
154
  @spaces.GPU