Update app.py
Browse files
app.py
CHANGED
@@ -45,16 +45,6 @@ model_bin = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-lar
|
|
45 |
tokenizer_bin.model_max_length = 512
|
46 |
pipe_bin = pipeline("ner", model=model_bin, tokenizer=tokenizer_bin)
|
47 |
|
48 |
-
tokenizer_ext = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
|
49 |
-
model_ext = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
|
50 |
-
tokenizer_ext.model_max_length = 512
|
51 |
-
pipe_ext = pipeline("ner", model=model_ext, tokenizer=tokenizer_ext)
|
52 |
-
|
53 |
-
model1 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_Int_segment", num_labels=1, token=auth_token)
|
54 |
-
tokenizer1 = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_Int_segment", token=auth_token)
|
55 |
-
|
56 |
-
model2 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_seq_ext", num_labels=1, token=auth_token)
|
57 |
-
|
58 |
def process_ner(text: str, pipeline) -> dict:
|
59 |
output = pipeline(text)
|
60 |
entities = []
|
@@ -84,60 +74,22 @@ def process_ner(text: str, pipeline) -> dict:
|
|
84 |
|
85 |
return {"text": text, "entities": entities}
|
86 |
|
87 |
-
def
|
88 |
-
inputs1 = tokenizer1(text, max_length=512, return_tensors='pt', truncation=True, padding=True)
|
89 |
-
|
90 |
-
with torch.no_grad():
|
91 |
-
outputs1 = model1(**inputs1)
|
92 |
-
outputs2 = model2(**inputs1)
|
93 |
-
|
94 |
-
prediction1 = outputs1[0].item()
|
95 |
-
prediction2 = outputs2[0].item()
|
96 |
-
score = prediction1 / (prediction2 + prediction1)
|
97 |
-
|
98 |
-
return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
99 |
-
|
100 |
-
|
101 |
-
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure, np.ndarray]:
|
102 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
103 |
-
|
104 |
-
|
105 |
# Counting entities for binary classification
|
106 |
entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
|
107 |
bin_labels = list(entity_counts_bin.keys())
|
108 |
bin_sizes = list(entity_counts_bin.values())
|
109 |
-
|
110 |
-
# Counting entities for extended classification
|
111 |
-
entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
|
112 |
-
ext_labels = list(entity_counts_ext.keys())
|
113 |
-
ext_sizes = list(entity_counts_ext.values())
|
114 |
|
115 |
bin_color_map = {
|
116 |
"External": "#6ad5bc",
|
117 |
"Internal": "#ee8bac"
|
118 |
}
|
119 |
|
120 |
-
ext_color_map = {
|
121 |
-
"INTemothou": "#FF7F50", # Coral
|
122 |
-
"INTpercept": "#FF4500", # OrangeRed
|
123 |
-
"INTtime": "#FF6347", # Tomato
|
124 |
-
"INTplace": "#FFD700", # Gold
|
125 |
-
"INTevent": "#FFA500", # Orange
|
126 |
-
"EXTsemantic": "#4682B4", # SteelBlue
|
127 |
-
"EXTrepetition": "#5F9EA0", # CadetBlue
|
128 |
-
"EXTother": "#00CED1", # DarkTurquoise
|
129 |
-
}
|
130 |
-
|
131 |
bin_colors = [bin_color_map.get(label, "#FFFFFF") for label in bin_labels]
|
132 |
-
ext_colors = [ext_color_map.get(label, "#FFFFFF") for label in ext_labels]
|
133 |
|
134 |
-
# Create pie chart for extended classification
|
135 |
-
fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
|
136 |
-
fig1.update_layout(
|
137 |
-
template='plotly_dark',
|
138 |
-
plot_bgcolor='rgba(0,0,0,0)',
|
139 |
-
paper_bgcolor='rgba(0,0,0,0)'
|
140 |
-
)
|
141 |
|
142 |
# Create bar chart for binary classification
|
143 |
fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
|
@@ -150,9 +102,9 @@ def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figu
|
|
150 |
)
|
151 |
|
152 |
# Generate word cloud
|
153 |
-
wordcloud_image = generate_wordcloud(
|
154 |
|
155 |
-
return
|
156 |
|
157 |
|
158 |
def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
|
@@ -193,14 +145,10 @@ def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.nd
|
|
193 |
@spaces.GPU
|
194 |
def all(text: str):
|
195 |
ner_output_bin = process_ner(text, pipe_bin)
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
pie_chart, bar_chart, wordcloud_image = generate_charts(ner_output_bin, ner_output_ext)
|
200 |
|
201 |
-
return (ner_output_bin,
|
202 |
-
classification_output[0], classification_output[1], classification_output[2],
|
203 |
-
pie_chart, bar_chart, wordcloud_image)
|
204 |
|
205 |
examples = [
|
206 |
['Bevor ich meinen Hund kaufte bin ich immer alleine durch den Park gelaufen. Gestern war ich aber mit dem Hund losgelaufen. Das Wetter war sehr schön, nicht wie sonst im Winter. Ich weiß nicht genau. Mir fällt sonst nichts dazu ein. Wir trafen auf mehrere Spaziergänger. Ein Mann mit seinem Kind. Das Kind hat ein Eis gegessen.'],
|
@@ -215,22 +163,6 @@ iface = gr.Interface(
|
|
215 |
"External": "#6ad5bcff",
|
216 |
"Internal": "#ee8bacff"}
|
217 |
),
|
218 |
-
gr.HighlightedText(label="Extended Sequence Classification",
|
219 |
-
color_map={
|
220 |
-
"INTemothou": "#FF7F50", # Coral
|
221 |
-
"INTpercept": "#FF4500", # OrangeRed
|
222 |
-
"INTtime": "#FF6347", # Tomato
|
223 |
-
"INTplace": "#FFD700", # Gold
|
224 |
-
"INTevent": "#FFA500", # Orange
|
225 |
-
"EXTsemantic": "#4682B4", # SteelBlue
|
226 |
-
"EXTrepetition": "#5F9EA0", # CadetBlue
|
227 |
-
"EXTother": "#00CED1", # DarkTurquoise
|
228 |
-
}
|
229 |
-
),
|
230 |
-
gr.Label(label="Internal Detail Count"),
|
231 |
-
gr.Label(label="External Detail Count"),
|
232 |
-
gr.Label(label="Approximated Internal Detail Ratio"),
|
233 |
-
gr.Plot(label="Extended SeqClass Entity Distribution Pie Chart"),
|
234 |
gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
|
235 |
gr.Image(label="Entity Word Cloud")
|
236 |
],
|
|
|
45 |
tokenizer_bin.model_max_length = 512
|
46 |
pipe_bin = pipeline("ner", model=model_bin, tokenizer=tokenizer_bin)
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def process_ner(text: str, pipeline) -> dict:
|
49 |
output = pipeline(text)
|
50 |
entities = []
|
|
|
74 |
|
75 |
return {"text": text, "entities": entities}
|
76 |
|
77 |
+
def generate_charts(ner_output_bin: dict) -> Tuple[go.Figure, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
79 |
+
|
|
|
80 |
# Counting entities for binary classification
|
81 |
entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
|
82 |
bin_labels = list(entity_counts_bin.keys())
|
83 |
bin_sizes = list(entity_counts_bin.values())
|
84 |
+
|
|
|
|
|
|
|
|
|
85 |
|
86 |
bin_color_map = {
|
87 |
"External": "#6ad5bc",
|
88 |
"Internal": "#ee8bac"
|
89 |
}
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
bin_colors = [bin_color_map.get(label, "#FFFFFF") for label in bin_labels]
|
|
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
# Create bar chart for binary classification
|
95 |
fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
|
|
|
102 |
)
|
103 |
|
104 |
# Generate word cloud
|
105 |
+
wordcloud_image = generate_wordcloud(ner_output_bin['entities'], bin_color_map)
|
106 |
|
107 |
+
return fig2, wordcloud_image
|
108 |
|
109 |
|
110 |
def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
|
|
|
145 |
@spaces.GPU
|
146 |
def all(text: str):
|
147 |
ner_output_bin = process_ner(text, pipe_bin)
|
148 |
+
|
149 |
+
bar_chart, wordcloud_image = generate_charts(ner_output_bin)
|
|
|
|
|
150 |
|
151 |
+
return (ner_output_bin, bar_chart, wordcloud_image)
|
|
|
|
|
152 |
|
153 |
examples = [
|
154 |
['Bevor ich meinen Hund kaufte bin ich immer alleine durch den Park gelaufen. Gestern war ich aber mit dem Hund losgelaufen. Das Wetter war sehr schön, nicht wie sonst im Winter. Ich weiß nicht genau. Mir fällt sonst nichts dazu ein. Wir trafen auf mehrere Spaziergänger. Ein Mann mit seinem Kind. Das Kind hat ein Eis gegessen.'],
|
|
|
163 |
"External": "#6ad5bcff",
|
164 |
"Internal": "#ee8bacff"}
|
165 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
|
167 |
gr.Image(label="Entity Word Cloud")
|
168 |
],
|