|
from __future__ import annotations

import colorsys
import io
import os
from collections import Counter, defaultdict
from typing import Iterable, List, Dict, Tuple
from typing import Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
# NOTE(review): `spaces` was originally imported ahead of torch/transformers; keep that
# order (ZeroGPU expects it to load before CUDA-related packages — verify in the `spaces` docs).
import spaces
import torch
from gradio.themes.base import Base
from gradio.themes.default import Default
from gradio.themes.monochrome import Monochrome
from gradio.themes.soft import Soft
from gradio.themes.utils import colors, fonts, sizes
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, pipeline
from wordcloud import WordCloud
|
|
|
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    """Parse a ``#RRGGBB`` (or bare ``RRGGBB``) string into an ``(r, g, b)`` tuple of ints.

    Leading ``#`` characters are stripped; each channel is read from a
    two-character hexadecimal slice.
    """
    digits = hex_color.lstrip('#')
    red = int(digits[0:2], 16)
    green = int(digits[2:4], 16)
    blue = int(digits[4:6], 16)
    return (red, green, blue)
|
|
|
def rgb_to_hex(rgb_color: tuple[int, int, int]) -> str:
    """Format the first three integer channels of ``rgb_color`` as a lowercase ``#rrggbb`` string."""
    channels = tuple(rgb_color)
    return "#" + "".join(format(channels[index], "02x") for index in range(3))
|
|
|
def adjust_brightness(rgb_color: tuple[int, int, int], factor: float) -> tuple[int, int, int]:
    """Scale the HSV value (brightness) of an 8-bit RGB colour by ``factor``.

    The colour is converted to HSV, its V component is multiplied by ``factor``
    and clamped to ``[0, 1]``, and the result is converted back; hue and
    saturation are left unchanged.

    Args:
        rgb_color: ``(r, g, b)`` channels in ``0..255``.
        factor: Multiplier for the brightness; ``1.0`` leaves the colour unchanged.

    Returns:
        The adjusted ``(r, g, b)`` tuple of ints in ``0..255``.
    """
    hue, saturation, value = colorsys.rgb_to_hsv(*(channel / 255.0 for channel in rgb_color))
    scaled_value = min(max(value * factor, 0.0), 1.0)
    red, green, blue = colorsys.hsv_to_rgb(hue, saturation, scaled_value)
    # round() rather than int(): int() truncates, so float noise such as
    # 49.99999999999999 would become 49 and factor=1.0 would not be an identity.
    return tuple(round(channel * 255) for channel in (red, green, blue))
|
|
|
monochrome = Monochrome()

# Hugging Face Hub token; a missing HF_TOKEN variable raises KeyError at import time.
auth_token = os.environ['HF_TOKEN']


def _load_ner_pipeline(repo_id: str):
    """Load the tokenizer and token-classification model at ``repo_id`` and wrap them in an "ner" pipeline.

    The tokenizer's ``model_max_length`` is set to 512 before the pipeline is built.

    Returns:
        ``(tokenizer, model, pipeline)`` for the repository.
    """
    ner_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=auth_token)
    ner_model = AutoModelForTokenClassification.from_pretrained(repo_id, token=auth_token)
    ner_tokenizer.model_max_length = 512
    ner_pipe = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)
    return ner_tokenizer, ner_model, ner_pipe


# Token-classification pipelines for the binary (`_bin`) and extended (`_ext`) label sets.
tokenizer_bin, model_bin, pipe_bin = _load_ner_pipeline("AlGe/deberta-v3-large_token")
tokenizer_ext, model_ext, pipe_ext = _load_ner_pipeline("AlGe/deberta-v3-large_AIS-token")

# Single-output (num_labels=1) sequence models. model2 has no tokenizer of its own:
# it is fed tokenizer1's encoding — presumably both share the deberta-v3-large vocabulary (verify).
model1 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_Int_segment", num_labels=1, token=auth_token)
tokenizer1 = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_Int_segment", token=auth_token)

model2 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_seq_ext", num_labels=1, token=auth_token)
|
|
|
def process_ner(text: str, pipeline) -> dict:
    """Run a token-classification pipeline on ``text`` and merge tagged tokens into spans.

    A new span starts when the entity type changes or the token carries a
    ``B`` prefix; otherwise the token extends the current span (its ``end`` is
    moved forward and the span keeps the highest token score seen).

    Args:
        text: Raw input text, passed unchanged to ``pipeline``.
        pipeline: Callable returning a list of token dicts with ``entity``,
            ``start``, ``end`` and ``score`` keys (e.g. a transformers "ner"
            pipeline). Named ``pipeline`` for backward compatibility; it shadows
            the transformers factory only inside this function.

    Returns:
        ``{"text": text, "entities": [{"entity", "start", "end", "score"}, ...]}``.
        The ``"text"`` key is required by ``gr.HighlightedText`` when its value
        is a dict; without it the component cannot render the spans.
    """
    entities = []
    current_entity = None

    for token in pipeline(text):
        label = token['entity']
        # Assumes labels of the form "B-<type>" / "I-<type>" (BIO prefix + dash) — TODO confirm
        # against the model configs; a label without that prefix would be mangled by [2:].
        entity_prefix = label[:1]
        entity_type = label[2:]

        # NOTE(review): "I" tokens of the same type are merged even when they are not
        # adjacent (tokens filtered out by the pipeline between them leave a gap).
        starts_new_span = (
            current_entity is None
            or entity_type != current_entity['entity']
            or entity_prefix == 'B'
        )

        if starts_new_span:
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {
                "entity": entity_type,
                "start": token['start'],
                "end": token['end'],
                "score": token['score'],
            }
        else:
            current_entity['end'] = token['end']
            current_entity['score'] = max(current_entity['score'], token['score'])

    if current_entity is not None:
        entities.append(current_entity)

    return {"text": text, "entities": entities}
|
|
|
def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str, str, str]:
    """Score ``text`` with two single-output sequence models and report their ratio.

    Both models receive the same encoding produced by ``tokenizer1``
    (truncated to 512 tokens). Each model's first output element must hold
    exactly one value, which is read with ``.item()``.

    Args:
        text: Input narrative.
        model1: Model whose prediction is the ratio's numerator.
        model2: Second model; its prediction is added to model1's in the denominator.
        tokenizer1: Tokenizer used to encode ``text`` for both models.

    Returns:
        ``(prediction1, prediction2, ratio)`` as strings. The predictions are
        rounded to 1 decimal and ``ratio = prediction1 / (prediction1 + prediction2)``
        is rounded to 2 decimals. ``ratio`` is ``"nan"`` when the sum is exactly 0.
    """
    inputs = tokenizer1(text, max_length=512, return_tensors='pt', truncation=True, padding=True)

    with torch.no_grad():
        outputs1 = model1(**inputs)
        outputs2 = model2(**inputs)

    prediction1 = outputs1[0].item()
    prediction2 = outputs2[0].item()

    total = prediction2 + prediction1
    # Guard the degenerate case instead of raising ZeroDivisionError mid-request.
    score = prediction1 / total if total != 0 else float('nan')

    return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
|
|
import plotly.graph_objects as go |
|
from typing import Tuple |
|
|
|
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure, np.ndarray]:
    """Build the summary visuals for the two NER outputs.

    Args:
        ner_output_bin: Dict with an ``"entities"`` list (binary label set);
            each entity dict must have an ``"entity"`` key.
        ner_output_ext: Same shape, for the extended label set; its entities
            also need ``"score"`` for the word cloud.

    Returns:
        ``(pie, bar, wordcloud_image)``:
        - ``pie`` is a donut chart of extended-label span counts;
        - ``bar`` is a bar chart of binary-label span counts;
        - ``wordcloud_image`` is the array from ``generate_wordcloud``.
    """
    # Counter keeps first-seen order, so the label order is deterministic. The previous
    # set-based version varied between runs because str hashing is randomized; it was
    # also O(n^2) via list.count.
    bin_counts = Counter(entity['entity'] for entity in ner_output_bin['entities'])
    ext_counts = Counter(entity['entity'] for entity in ner_output_ext['entities'])

    bin_labels = list(bin_counts)
    bin_sizes = list(bin_counts.values())
    ext_labels = list(ext_counts)
    ext_sizes = list(ext_counts.values())

    bin_color_map = {
        "External": "#6ad5bc",
        "Internal": "#ee8bac"
    }

    ext_color_map = {
        "INTemothou": "#FF7F50",
        "INTpercept": "#FF4500",
        "INTtime": "#FF6347",
        "INTplace": "#FFD700",
        "INTevent": "#FFA500",
        "EXTsemantic": "#4682B4",
        "EXTrepetition": "#5F9EA0",
        "EXTother": "#00CED1",
    }

    # Labels missing from the maps get a neutral colour instead of raising KeyError.
    fallback_color = "#FFFFFF"
    bin_colors = [bin_color_map.get(label, fallback_color) for label in bin_labels]
    ext_colors = [ext_color_map.get(label, fallback_color) for label in ext_labels]

    fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
    fig1.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
    fig2.update_layout(
        xaxis_title='Entity Type',
        yaxis_title='Count',
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map)

    return fig1, fig2, wordcloud_image
|
|
|
def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
    """Render a word cloud of entity labels as an RGB ``uint8`` image array.

    Each distinct ``entity["entity"]`` label becomes one word. Its size is the
    sum of the ``score`` values of all spans with that label, i.e. a
    confidence-weighted count. The previous version kept only the last span's score.

    Args:
        entities: Span dicts with ``"entity"`` and ``"score"`` keys.
        color_map: Label -> colour string; unknown labels are drawn in ``#FFFFFF``.

    Returns:
        A ``(height, width, 3)`` uint8 array of the rendered figure. If
        ``entities`` is empty, an all-black ``(400, 800, 3)`` array is returned,
        because WordCloud raises ValueError when it has no words to draw.
    """
    width, height = 800, 400
    if not entities:
        return np.zeros((height, width, 3), dtype=np.uint8)

    word_freq = defaultdict(float)
    for entity in entities:
        word_freq[entity['entity']] += entity['score']

    def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
        # The words are the labels themselves, so look the colour up directly.
        return color_map.get(word, "#FFFFFF")

    wordcloud = WordCloud(width=width, height=height, background_color='black', color_func=color_func).generate_from_frequencies(dict(word_freq))

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis('off')
        plt.tight_layout(pad=0)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which is deprecated (Matplotlib 3.8) and
        # removed in later releases. It also reports the real pixel size, unlike get_width_height().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        image_array = rgba[..., :3].copy()
    finally:
        # Always release the figure so failed renders don't leak pyplot state.
        plt.close(fig)

    return image_array
|
|
|
@spaces.GPU
def all(text: str):
    """Run every model on ``text`` and return the Gradio outputs in interface order.

    Returns an 8-tuple:
    1. binary NER result;
    2. extended NER result;
    3-5. the three strings from ``process_classification``;
    6. pie chart;
    7. bar chart;
    8. word-cloud image.

    NOTE(review): the name shadows the builtin ``all``; it is kept because the
    Interface references it as ``fn=all``.
    """
    binary_result = process_ner(text, pipe_bin)
    extended_result = process_ner(text, pipe_ext)
    first_count, second_count, count_ratio = process_classification(text, model1, model2, tokenizer1)

    pie_chart, bar_chart, wordcloud_image = generate_charts(binary_result, extended_result)

    return (
        binary_result,
        extended_result,
        first_count,
        second_count,
        count_ratio,
        pie_chart,
        bar_chart,
        wordcloud_image,
    )
|
|
|
# NOTE(review): the Interface below passed `examples=examples`, but `examples` was never
# defined in this file, so building the UI raised NameError at import time. This
# placeholder restores a working demo — replace it with the intended example narratives.
examples = [
    ["Last Sunday I had breakfast with my grandmother in her kitchen. She made pancakes, "
     "the window was open and I could hear the church bells. I usually skip breakfast on weekdays."],
]

# Output order must match the tuple returned by `all`.
iface = gr.Interface(
    fn=all,
    inputs=gr.Textbox(lines=5, label="Input Text", placeholder="Write about how your breakfast went or anything else that happened or might happen to you ..."),
    outputs=[
        gr.HighlightedText(label="Binary Sequence Classification",
                           color_map={
                               "External": "#6ad5bcff",
                               "Internal": "#ee8bacff"}
                           ),
        gr.HighlightedText(label="Extended Sequence Classification",
                           color_map={
                               "INTemothou": "#FF7F50",
                               "INTpercept": "#FF4500",
                               "INTtime": "#FF6347",
                               "INTplace": "#FFD700",
                               "INTevent": "#FFA500",
                               "EXTsemantic": "#4682B4",
                               "EXTrepetition": "#5F9EA0",
                               "EXTother": "#00CED1",
                           }
                           ),
        gr.Label(label="Internal Detail Count"),
        gr.Label(label="External Detail Count"),
        gr.Label(label="Approximated Internal Detail Ratio"),
        gr.Plot(label="Extended SeqClass Entity Distribution Pie Chart"),
        gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
        gr.Image(label="Entity Word Cloud")
    ],
    title="Scoring Demo",
    description="Autobiographical Memory Analysis: This demo combines two text - and two sequence classification models to showcase our automated Autobiographical Interview scoring method. Submit a narrative to see the results.",
    examples=examples,
    theme=monochrome
)

iface.launch()