AlGe commited on
Commit
9d8d3cc
·
verified ·
1 Parent(s): ccbf9e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -47
app.py CHANGED
@@ -51,11 +51,9 @@ def process_ner(text: str, pipeline) -> dict:
51
  output = pipeline(text)
52
  entities = []
53
  current_entity = None
54
-
55
  for token in output:
56
  entity_type = token['entity'][2:]
57
  entity_prefix = token['entity'][:1]
58
-
59
  if current_entity is None or entity_type != current_entity['entity'] or (entity_prefix == 'B' and entity_type == current_entity['entity']):
60
  if current_entity is not None:
61
  entities.append(current_entity)
@@ -63,98 +61,80 @@ def process_ner(text: str, pipeline) -> dict:
63
  "entity": entity_type,
64
  "start": token['start'],
65
  "end": token['end'],
66
- "score": token['score'],
67
  "tokens": [token['word']]
68
  }
69
  else:
70
  current_entity['end'] = token['end']
71
- current_entity['score'] = max(current_entity['score'], token['score'])
72
  current_entity['tokens'].append(token['word'])
73
-
74
  if current_entity is not None:
75
  entities.append(current_entity)
76
-
 
77
  return {"text": text, "entities": entities}
78
 
79
  def generate_charts(ner_output_ext: dict) -> Tuple[go.Figure, np.ndarray]:
80
  entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
81
-
82
- # Counting entities for extended classification
83
  entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
84
  ext_labels = list(entity_counts_ext.keys())
85
  ext_sizes = list(entity_counts_ext.values())
86
-
87
  ext_color_map = {
88
- "INTemothou": "#FF7F50", # Coral
89
- "INTpercept": "#FF4500", # OrangeRed
90
- "INTtime": "#FF6347", # Tomato
91
- "INTplace": "#FFD700", # Gold
92
- "INTevent": "#FFA500", # Orange
93
- "EXTsemantic": "#4682B4", # SteelBlue
94
- "EXTrepetition": "#5F9EA0", # CadetBlue
95
- "EXTother": "#00CED1", # DarkTurquoise
96
  }
97
-
98
  ext_colors = [ext_color_map.get(label, "#FFFFFF") for label in ext_labels]
99
-
100
- # Create pie chart for extended classification
101
  fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
102
  fig1.update_layout(
103
  template='plotly_dark',
104
  plot_bgcolor='rgba(0,0,0,0)',
105
  paper_bgcolor='rgba(0,0,0,0)'
106
  )
107
- # Generate word cloud
108
  wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map, "dh3.png")
109
-
110
  return fig1, wordcloud_image
111
 
112
  def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str], file_path: str) -> np.ndarray:
113
  # Construct the absolute path
114
  base_path = os.path.dirname(os.path.abspath(__file__))
115
  image_path = os.path.join(base_path, file_path)
116
-
117
- # Debugging statement to print the image path
118
- print(f"Image path: {image_path}")
119
-
120
- # Check if the file exists
121
  if not os.path.exists(image_path):
122
  raise FileNotFoundError(f"Mask image file not found: {image_path}")
123
-
124
  mask_image = np.array(Image.open(image_path))
125
  mask_height, mask_width = mask_image.shape[:2]
126
-
127
- token_texts = []
128
- token_scores = []
129
- token_types = []
130
-
131
  for entity in entities:
132
- for token in entity['tokens']:
133
- cleaned_token = re.sub(r'^\W+', '', token)
134
- token_texts.append(cleaned_token)
135
- token_scores.append(entity['score'])
136
- token_types.append(entity['entity'])
137
- print(f"{cleaned_token} ({entity['entity']}): {entity['score']}")
138
-
139
- word_freq = {text: score for text, score in zip(token_texts, token_scores)}
140
-
 
 
141
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
142
- entity_type = next((t for t, w in zip(token_types, token_texts) if w == word), None)
143
  return color_map.get(entity_type, "#FFFFFF")
144
-
145
  wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(word_freq)
146
-
147
  plt.figure(figsize=(mask_width/100, mask_height/100))
148
  plt.imshow(wordcloud, interpolation='bilinear')
149
  plt.axis('off')
150
  plt.tight_layout(pad=0)
151
-
152
  plt_image = plt.gcf()
153
  plt_image.canvas.draw()
154
  image_array = np.frombuffer(plt_image.canvas.tostring_rgb(), dtype=np.uint8)
155
  image_array = image_array.reshape(plt_image.canvas.get_width_height()[::-1] + (3,))
156
  plt.close()
157
-
158
  return image_array
159
 
160
  @spaces.GPU
 
51
  output = pipeline(text)
52
  entities = []
53
  current_entity = None
 
54
  for token in output:
55
  entity_type = token['entity'][2:]
56
  entity_prefix = token['entity'][:1]
 
57
  if current_entity is None or entity_type != current_entity['entity'] or (entity_prefix == 'B' and entity_type == current_entity['entity']):
58
  if current_entity is not None:
59
  entities.append(current_entity)
 
61
  "entity": entity_type,
62
  "start": token['start'],
63
  "end": token['end'],
64
+ "scores": [token['score']],
65
  "tokens": [token['word']]
66
  }
67
  else:
68
  current_entity['end'] = token['end']
69
+ current_entity['scores'].append(token['score'])
70
  current_entity['tokens'].append(token['word'])
 
71
  if current_entity is not None:
72
  entities.append(current_entity)
73
+ for entity in entities:
74
+ entity['average_score'] = sum(entity['scores']) / len(entity['scores'])
75
  return {"text": text, "entities": entities}
76
 
77
  def generate_charts(ner_output_ext: dict) -> Tuple[go.Figure, np.ndarray]:
78
  entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
 
 
79
  entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
80
  ext_labels = list(entity_counts_ext.keys())
81
  ext_sizes = list(entity_counts_ext.values())
 
82
  ext_color_map = {
83
+ "INTemothou": "#FF7F50",
84
+ "INTpercept": "#FF4500",
85
+ "INTtime": "#FF6347",
86
+ "INTplace": "#FFD700",
87
+ "INTevent": "#FFA500",
88
+ "EXTsemantic": "#4682B4",
89
+ "EXTrepetition": "#5F9EA0",
90
+ "EXTother": "#00CED1",
91
  }
 
92
  ext_colors = [ext_color_map.get(label, "#FFFFFF") for label in ext_labels]
 
 
93
  fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
94
  fig1.update_layout(
95
  template='plotly_dark',
96
  plot_bgcolor='rgba(0,0,0,0)',
97
  paper_bgcolor='rgba(0,0,0,0)'
98
  )
 
99
  wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map, "dh3.png")
 
100
  return fig1, wordcloud_image
101
 
102
  def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str], file_path: str) -> np.ndarray:
103
  # Construct the absolute path
104
  base_path = os.path.dirname(os.path.abspath(__file__))
105
  image_path = os.path.join(base_path, file_path)
 
 
 
 
 
106
  if not os.path.exists(image_path):
107
  raise FileNotFoundError(f"Mask image file not found: {image_path}")
 
108
  mask_image = np.array(Image.open(image_path))
109
  mask_height, mask_width = mask_image.shape[:2]
110
+ entity_texts = []
111
+ entity_scores = []
112
+ entity_types = []
 
 
113
  for entity in entities:
114
+ print(f"E: {entity}")
115
+ segment_text = ' '.join(entity['tokens'])
116
+ #segment_text = re.sub(r'^\W+', '', segment_text)
117
+ entity_texts.append(segment_text)
118
+ if 'average_score' in entity:
119
+ entity_scores.append(entity['average_score'])
120
+ else:
121
+ entity_scores.append(0.5) # Example: Assigning a default score of 0.5 if 'average_score' is missing
122
+ entity_types.append(entity['entity'])
123
+ print(f"{segment_text} ({entity['entity']}): {entity.get('average_score', 0.5)}") # Print or log the score, using a default if missing
124
+ word_freq = {text: score for text, score in zip(entity_texts, entity_scores)}
125
  def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
126
+ entity_type = next((t for t, w in zip(entity_types, entity_texts) if w == word), None)
127
  return color_map.get(entity_type, "#FFFFFF")
 
128
  wordcloud = WordCloud(width=mask_width, height=mask_height, background_color='#121212', mask=mask_image, color_func=color_func).generate_from_frequencies(word_freq)
 
129
  plt.figure(figsize=(mask_width/100, mask_height/100))
130
  plt.imshow(wordcloud, interpolation='bilinear')
131
  plt.axis('off')
132
  plt.tight_layout(pad=0)
 
133
  plt_image = plt.gcf()
134
  plt_image.canvas.draw()
135
  image_array = np.frombuffer(plt_image.canvas.tostring_rgb(), dtype=np.uint8)
136
  image_array = image_array.reshape(plt_image.canvas.get_width_height()[::-1] + (3,))
137
  plt.close()
 
138
  return image_array
139
 
140
  @spaces.GPU