from __future__ import annotations

import io
import os
import colorsys
from typing import Dict, Iterable, List, Tuple

import gradio as gr
from gradio.themes.base import Base
from gradio.themes.soft import Soft
from gradio.themes.monochrome import Monochrome
from gradio.themes.default import Default
from gradio.themes.utils import colors, fonts, sizes

import spaces
import torch
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    pipeline,
)
from wordcloud import WordCloud
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))


def rgb_to_hex(rgb_color: tuple[int, int, int]) -> str:
    return "#{:02x}{:02x}{:02x}".format(*rgb_color)


def adjust_brightness(rgb_color: tuple[int, int, int], factor: float) -> tuple[int, int, int]:
    hsv_color = colorsys.rgb_to_hsv(*[v / 255.0 for v in rgb_color])
    new_v = max(0, min(hsv_color[2] * factor, 1))
    new_rgb = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], new_v)
    return tuple(int(v * 255) for v in new_rgb)
monochrome = Monochrome()

# Hub access token (HF_TOKEN secret) used to load the models.
auth_token = os.environ['HF_TOKEN']

# Token-classification model for the binary (Internal/External) label set.
tokenizer_bin = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_token", token=auth_token)
model_bin = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-large_token", token=auth_token)
tokenizer_bin.model_max_length = 512
pipe_bin = pipeline("ner", model=model_bin, tokenizer=tokenizer_bin)

# Token-classification model for the extended (AIS) label set.
tokenizer_ext = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
model_ext = AutoModelForTokenClassification.from_pretrained("AlGe/deberta-v3-large_AIS-token", token=auth_token)
tokenizer_ext.model_max_length = 512
pipe_ext = pipeline("ner", model=model_ext, tokenizer=tokenizer_ext)

# Sequence-classification models with a single regression output each (num_labels=1).
model1 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_Int_segment", num_labels=1, token=auth_token)
tokenizer1 = AutoTokenizer.from_pretrained("AlGe/deberta-v3-large_Int_segment", token=auth_token)
model2 = AutoModelForSequenceClassification.from_pretrained("AlGe/deberta-v3-large_seq_ext", num_labels=1, token=auth_token)
def process_ner(text: str, ner_pipeline) -> dict:
    """Merge B-/I- token predictions into contiguous entity spans."""
    output = ner_pipeline(text)
    entities = []
    current_entity = None

    for token in output:
        entity_type = token['entity'][2:]
        entity_prefix = token['entity'][:1]

        # Start a new span on a label change or an explicit B- tag.
        if current_entity is None or entity_type != current_entity['entity'] or (entity_prefix == 'B' and entity_type == current_entity['entity']):
            if current_entity is not None:
                entities.append(current_entity)
            current_entity = {
                "entity": entity_type,
                "start": token['start'],
                "end": token['end'],
                "score": token['score'],
            }
        else:
            # Continuation token: extend the current span.
            current_entity['end'] = token['end']
            current_entity['score'] = max(current_entity['score'], token['score'])

    if current_entity is not None:
        entities.append(current_entity)

    # gr.HighlightedText expects both the raw text and the entity spans.
    return {"text": text, "entities": entities}
def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str, str, str]:
    inputs1 = tokenizer1(text, max_length=512, return_tensors='pt', truncation=True, padding=True)

    with torch.no_grad():
        outputs1 = model1(**inputs1)
        outputs2 = model2(**inputs1)

    prediction1 = outputs1[0].item()
    prediction2 = outputs2[0].item()
    # Share of internal details among all predicted details.
    score = prediction1 / (prediction2 + prediction1)

    return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure, np.ndarray]:
    entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
    entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]

    # Count entities for the binary classification
    entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
    bin_labels = list(entity_counts_bin.keys())
    bin_sizes = list(entity_counts_bin.values())

    # Count entities for the extended classification
    entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
    ext_labels = list(entity_counts_ext.keys())
    ext_sizes = list(entity_counts_ext.values())

    # Define color mapping
    bin_color_map = {
        "External": "#6ad5bc",
        "Internal": "#ee8bac",
    }
    ext_color_map = {
        "INTemothou": "#FF7F50",     # Coral
        "INTpercept": "#FF4500",     # OrangeRed
        "INTtime": "#FF6347",        # Tomato
        "INTplace": "#FFD700",       # Gold
        "INTevent": "#FFA500",       # Orange
        "EXTsemantic": "#4682B4",    # SteelBlue
        "EXTrepetition": "#5F9EA0",  # CadetBlue
        "EXTother": "#00CED1",       # DarkTurquoise
    }
    # Fall back to white for any label that is missing from the maps.
    bin_colors = [bin_color_map.get(label, "#FFFFFF") for label in bin_labels]
    ext_colors = [ext_color_map.get(label, "#FFFFFF") for label in ext_labels]

    # Create pie chart for extended classification
    fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3, marker=dict(colors=ext_colors))])
    fig1.update_layout(
        # title_text='Extended Sequence Classification Subclasses',
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    # Create bar chart for binary classification
    fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes, marker=dict(color=bin_colors))])
    fig2.update_layout(
        # title='Binary Sequence Classification Classes',
        xaxis_title='Entity Type',
        yaxis_title='Count',
        template='plotly_dark',
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    # Generate word cloud from the extended entities
    wordcloud_image = generate_wordcloud(ner_output_ext['entities'], ext_color_map)

    return fig1, fig2, wordcloud_image
def generate_wordcloud(entities: List[Dict], color_map: Dict[str, str]) -> np.ndarray:
    entity_texts = [entity['entity'] for entity in entities]
    entity_scores = [entity['score'] for entity in entities]

    # Frequency dictionary keyed by entity label, weighted by model score.
    word_freq = {text: score for text, score in zip(entity_texts, entity_scores)}

    if not word_freq:
        # WordCloud raises on an empty frequency dict; return a blank image instead.
        return np.zeros((400, 800, 3), dtype=np.uint8)

    def color_func(word, font_size, position, orientation, random_state=None, **kwargs):
        # The "words" here are entity labels, so they can be looked up directly.
        return color_map.get(word, "#FFFFFF")

    wordcloud = WordCloud(width=800, height=400, background_color='black', color_func=color_func).generate_from_frequencies(word_freq)

    # Render with matplotlib and convert the canvas to a numpy image array.
    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.tight_layout(pad=0)

    plt_image = plt.gcf()
    plt_image.canvas.draw()
    image_array = np.frombuffer(plt_image.canvas.tostring_rgb(), dtype=np.uint8)
    image_array = image_array.reshape(plt_image.canvas.get_width_height()[::-1] + (3,))
    plt.close()

    return image_array
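
# NOTE (assumption): `spaces` is imported above but never used in the original file. If this
# Space runs on ZeroGPU hardware, the inference entrypoint is typically decorated with
# `@spaces.GPU` so a GPU is attached for the duration of the call; on regular hardware the
# decorator is a no-op.
@spaces.GPU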
def analyze_text(text: str):
    ner_output_bin = process_ner(text, pipe_bin)
    ner_output_ext = process_ner(text, pipe_ext)
    classification_output = process_classification(text, model1, model2, tokenizer1)
    pie_chart, bar_chart, wordcloud_image = generate_charts(ner_output_bin, ner_output_ext)

    return (ner_output_bin, ner_output_ext,
            classification_output[0], classification_output[1], classification_output[2],
            pie_chart, bar_chart, wordcloud_image)
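
# NOTE (assumption): `examples` is passed to gr.Interface below but was never defined in the
# original file, which raises a NameError at startup. A minimal placeholder narrative is
# provided here; replace it with real example inputs.
examples = [
    ["Yesterday I went to the bakery around the corner. The smell of fresh bread reminded me of my grandmother's kitchen, and I felt unexpectedly happy."],
]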
iface = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(lines=5, label="Input Text", placeholder="Write about how your breakfast went or anything else that happened or might happen to you ..."),
    outputs=[
        gr.HighlightedText(label="Binary Sequence Classification",
                           color_map={
                               "External": "#6ad5bcff",
                               "Internal": "#ee8bacff"}
                           ),
        gr.HighlightedText(label="Extended Sequence Classification",
                           color_map={
                               "INTemothou": "#FF7F50",     # Coral
                               "INTpercept": "#FF4500",     # OrangeRed
                               "INTtime": "#FF6347",        # Tomato
                               "INTplace": "#FFD700",       # Gold
                               "INTevent": "#FFA500",       # Orange
                               "EXTsemantic": "#4682B4",    # SteelBlue
                               "EXTrepetition": "#5F9EA0",  # CadetBlue
                               "EXTother": "#00CED1",       # DarkTurquoise
                           }
                           ),
        gr.Label(label="Internal Detail Count"),
        gr.Label(label="External Detail Count"),
        gr.Label(label="Approximated Internal Detail Ratio"),
        gr.Plot(label="Extended SeqClass Entity Distribution Pie Chart"),
        gr.Plot(label="Binary SeqClass Entity Count Bar Chart"),
        gr.Image(label="Entity Word Cloud")
    ],
    title="Scoring Demo",
    description="Autobiographical Memory Analysis: This demo combines two token-classification and two sequence-classification models to showcase our automated Autobiographical Interview scoring method. Submit a narrative to see the results.",
    examples=examples,
    theme=monochrome
)

iface.launch()