AlGe commited on
Commit
929679c
·
verified ·
1 Parent(s): 8793006

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -20
app.py CHANGED
@@ -109,34 +109,39 @@ def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str], file_pat
109
  mask_image = np.array(Image.open(image_path))
110
  mask_height, mask_width = mask_image.shape[:2]
111
 
112
- entity_texts = []
113
- entity_scores = []
114
- entity_types = []
115
 
116
  for entity in entities:
117
- print(f"E: {entity}")
118
- # Join tokens and process the text
119
- segment_text = ' '.join(entity['tokens'])
120
- # Replace "▁" with normal spaces
121
- segment_text = segment_text.replace("▁", " ")
122
- # Remove all normal spaces
123
- segment_text = segment_text.replace(" ", "")
124
-
125
- entity_texts.append(segment_text)
126
- if 'average_score' in entity:
127
- entity_scores.append(entity['average_score'])
 
 
 
 
 
128
  else:
129
- entity_scores.append(0.5) # Example: Assigning a default score of 0.5 if 'average_score' is missing
130
- entity_types.append(entity['entity'])
131
- print(f"{segment_text} ({entity['entity']}): {entity.get('average_score', 0.5)}") # Print or log the score, using a default if missing
132
 
133
- word_freq = {text: score for text, score in zip(entity_texts, entity_scores)}
 
 
 
 
134
 
135
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
136
- entity_type = next((t for t, w in zip(entity_types, entity_texts) if w == word), None)
137
  return color_map.get(entity_type, "#FFFFFF")
138
 
139
- wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(word_freq)
140
 
141
  plt.figure(figsize=(mask_width/100, mask_height/100))
142
  plt.imshow(wordcloud, interpolation='bilinear')
 
109
  mask_image = np.array(Image.open(image_path))
110
  mask_height, mask_width = mask_image.shape[:2]
111
 
112
+ word_details = []
 
 
113
 
114
  for entity in entities:
115
+ for token in entity['tokens']:
116
+ # Process each token
117
+ token_text = token.replace("▁", " ").strip()
118
+ if token_text: # Ensure token is not empty
119
+ word_details.append({
120
+ 'text': token_text,
121
+ 'score': entity.get('average_score', 0.5),
122
+ 'entity': entity['entity']
123
+ })
124
+
125
+ # Calculate word frequency weighted by score
126
+ word_freq = {}
127
+ for detail in word_details:
128
+ if detail['text'] in word_freq:
129
+ word_freq[detail['text']]['score'] += detail['score']
130
+ word_freq[detail['text']]['count'] += 1
131
  else:
132
+ word_freq[detail['text']] = {'score': detail['score'], 'count': 1, 'entity': detail['entity']}
 
 
133
 
134
+ # Average the scores and prepare final frequency dictionary
135
+ final_word_freq = {word: details['score'] / details['count'] for word, details in word_freq.items()}
136
+
137
+ # Prepare entity type mapping for color function
138
+ word_to_entity = {word: details['entity'] for word, details in word_freq.items()}
139
 
140
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
141
+ entity_type = word_to_entity.get(word, None)
142
  return color_map.get(entity_type, "#FFFFFF")
143
 
144
+ wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(final_word_freq)
145
 
146
  plt.figure(figsize=(mask_width/100, mask_height/100))
147
  plt.imshow(wordcloud, interpolation='bilinear')