AlGe committed on
Commit
d8a3707
·
verified ·
1 Parent(s): efec366

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -76
app.py CHANGED
@@ -45,16 +45,6 @@ model_bin = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-lar
45
  tokenizer_bin.model_max_length = 512
46
  pipe_bin = pipeline("ner", model=model_bin, tokenizer=tokenizer_bin)
47
 
48
- tokenizer_ext = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
49
- model_ext = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
50
- tokenizer_ext.model_max_length = 512
51
- pipe_ext = pipeline("ner", model=model_ext, tokenizer=tokenizer_ext)
52
-
53
- model1 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_Int_segment", num_labels=1, token=auth_token)
54
- tokenizer1 = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_Int_segment", token=auth_token)
55
-
56
- model2 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_seq_ext", num_labels=1, token=auth_token)
57
-
58
  def process_ner(text: str, pipeline) -> dict:
59
  output = pipeline(text)
60
  entities = []
@@ -84,60 +74,22 @@ def process_ner(text: str, pipeline) -> dict:
84
 
85
  return {"text": text, "entities": entities}
86
 
87
- def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str, str, str]:
88
- inputs1 = tokenizer1(text, max_length=512, return_tensors='pt', truncation=True, padding=True)
89
-
90
- with torch.no_grad():
91
- outputs1 = model1(**inputs1)
92
- outputs2 = model2(**inputs1)
93
-
94
- prediction1 = outputs1[0].item()
95
- prediction2 = outputs2[0].item()
96
- score = prediction1 / (prediction2 + prediction1)
97
-
98
- return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
99
-
100
-
101
- def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure, np.ndarray]:
102
  entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
103
- entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
104
-
105
  # Counting entities for binary classification
106
  entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
107
  bin_labels = list(entity_counts_bin.keys())
108
  bin_sizes = list(entity_counts_bin.values())
109
-
110
- # Counting entities for extended classification
111
- entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
112
- ext_labels = list(entity_counts_ext.keys())
113
- ext_sizes = list(entity_counts_ext.values())
114
 
115
  bin_color_map = {
116
  "External": "#6ad5bc",
117
  "Internal": "#ee8bac"
118
  }
119
 
120
- ext_color_map = {
121
- "INTemothou": "#FF7F50", # Coral
122
- "INTpercept": "#FF4500", # OrangeRed
123
- "INTtime": "#FF6347", # Tomato
124
- "INTplace": "#FFD700", # Gold
125
- "INTevent": "#FFA500", # Orange
126
- "EXTsemantic": "#4682B4", # SteelBlue
127
- "EXTrepetition": "#5F9EA0", # CadetBlue
128
- "EXTother": "#00CED1", # DarkTurquoise
129
- }
130
-
131
  bin_colors = [bin_color_map.get(label, "#FFFFFF") for label in bin_labels]
132
- ext_colors = [ext_color_map.get(label, "#FFFFFF") for label in ext_labels]
133
 
134
- # Create pie chart for extended classification
135
- fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
136
- fig1.update_layout(
137
- template='plotly_dark',
138
- plot_bgcolor='rgba(0,0,0,0)',
139
- paper_bgcolor='rgba(0,0,0,0)'
140
- )
141
 
142
  # Create bar chart for binary classification
143
  fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
@@ -150,9 +102,9 @@ def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figu
150
  )
151
 
152
  # Generate word cloud
153
- wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map)
154
 
155
- return fig1, fig2, wordcloud_image
156
 
157
 
158
  def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
@@ -193,14 +145,10 @@ def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.nd
193
  @spaces.GPU
194
  def all(text: str):
195
  ner_output_bin = process_ner(text, pipe_bin)
196
- ner_output_ext = process_ner(text, pipe_ext)
197
- classification_output = process_classification(text, model1, model2, tokenizer1)
198
-
199
- pie_chart, bar_chart, wordcloud_image = generate_charts(ner_output_bin, ner_output_ext)
200
 
201
- return (ner_output_bin, ner_output_ext,
202
- classification_output[0], classification_output[1], classification_output[2],
203
- pie_chart, bar_chart, wordcloud_image)
204
 
205
  examples = [
206
  ['Bevor ich meinen Hund kaufte bin ich immer alleine durch den Park gelaufen. Gestern war ich aber mit dem Hund losgelaufen. Das Wetter war sehr schön, nicht wie sonst im Winter. Ich weiß nicht genau. Mir fällt sonst nichts dazu ein. Wir trafen auf mehrere Spaziergänger. Ein Mann mit seinem Kind. Das Kind hat ein Eis gegessen.'],
@@ -215,22 +163,6 @@ iface = gr.Interface(
215
  "External": "#6ad5bcff",
216
  "Internal": "#ee8bacff"}
217
  ),
218
- gr.HighlightedText(label="Extended Sequence Classification",
219
- color_map={
220
- "INTemothou": "#FF7F50", # Coral
221
- "INTpercept": "#FF4500", # OrangeRed
222
- "INTtime": "#FF6347", # Tomato
223
- "INTplace": "#FFD700", # Gold
224
- "INTevent": "#FFA500", # Orange
225
- "EXTsemantic": "#4682B4", # SteelBlue
226
- "EXTrepetition": "#5F9EA0", # CadetBlue
227
- "EXTother": "#00CED1", # DarkTurquoise
228
- }
229
- ),
230
- gr.Label(label="Internal Detail Count"),
231
- gr.Label(label="External Detail Count"),
232
- gr.Label(label="Approximated Internal Detail Ratio"),
233
- gr.Plot(label="Extended SeqClass Entity Distribution Pie Chart"),
234
  gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
235
  gr.Image(label="Entity Word Cloud")
236
  ],
 
45
  tokenizer_bin.model_max_length = 512
46
  pipe_bin = pipeline("ner", model=model_bin, tokenizer=tokenizer_bin)
47
 
 
 
 
 
 
 
 
 
 
 
48
  def process_ner(text: str, pipeline) -> dict:
49
  output = pipeline(text)
50
  entities = []
 
74
 
75
  return {"text": text, "entities": entities}
76
 
77
+ def generate_charts(ner_output_bin: dict) -> Tuple[go.Figure, np.ndarray]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
79
+
 
80
  # Counting entities for binary classification
81
  entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
82
  bin_labels = list(entity_counts_bin.keys())
83
  bin_sizes = list(entity_counts_bin.values())
84
+
 
 
 
 
85
 
86
  bin_color_map = {
87
  "External": "#6ad5bc",
88
  "Internal": "#ee8bac"
89
  }
90
 
 
 
 
 
 
 
 
 
 
 
 
91
  bin_colors = [bin_color_map.get(label, "#FFFFFF") for label in bin_labels]
 
92
 
 
 
 
 
 
 
 
93
 
94
  # Create bar chart for binary classification
95
  fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
 
102
  )
103
 
104
  # Generate word cloud
105
+ wordcloud_image = generate_wordcloud(ner_output_bin['entities'], bin_color_map)
106
 
107
+ return fig2, wordcloud_image
108
 
109
 
110
  def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
 
145
  @spaces.GPU
146
  def all(text: str):
147
  ner_output_bin = process_ner(text, pipe_bin)
148
+
149
+ bar_chart, wordcloud_image = generate_charts(ner_output_bin)
 
 
150
 
151
+ return (ner_output_bin, bar_chart, wordcloud_image)
 
 
152
 
153
  examples = [
154
  ['Bevor ich meinen Hund kaufte bin ich immer alleine durch den Park gelaufen. Gestern war ich aber mit dem Hund losgelaufen. Das Wetter war sehr schön, nicht wie sonst im Winter. Ich weiß nicht genau. Mir fällt sonst nichts dazu ein. Wir trafen auf mehrere Spaziergänger. Ein Mann mit seinem Kind. Das Kind hat ein Eis gegessen.'],
 
163
  "External": "#6ad5bcff",
164
  "Internal": "#ee8bacff"}
165
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
167
  gr.Image(label="Entity Word Cloud")
168
  ],