Spaces:
Sleeping
Sleeping
import gradio as gr | |
from functools import partial | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from sentence_transformers import SentenceTransformer | |
import torch | |
import tqdm | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline | |
import penman | |
from collections import Counter, defaultdict | |
import networkx as nx | |
from networkx.drawing.nx_agraph import pygraphviz_layout | |
class FramingLabels:
    """Zero-shot, multi-label frame classifier built on a Hugging Face pipeline.

    Every candidate label is scored independently (``multi_label=True``).
    The full label string is what the model sees; the first word before the
    colon is used as its short display name.
    """

    def __init__(self, base_model, candidate_labels, batch_size=16):
        """Load the zero-shot pipeline and bind the fixed classification options.

        Args:
            base_model: Model name or path handed to ``pipeline``.
            candidate_labels: Label strings scored against every input.
            batch_size: Batch size forwarded to the pipeline on each call.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.base_pipeline = pipeline("zero-shot-classification", model=base_model, device=device)
        self.candidate_labels = candidate_labels
        self.classifier = partial(
            self.base_pipeline,
            candidate_labels=candidate_labels,
            multi_label=True,
            batch_size=batch_size,
        )

    def order_scores(self, dic):
        """Return the scores of one pipeline result in ``candidate_labels`` order.

        Args:
            dic: Mapping with parallel ``"labels"`` and ``"scores"`` sequences.

        Returns:
            A list of scores, one per entry of ``self.candidate_labels``.
        """
        indices_order = [dic["labels"].index(label) for label in self.candidate_labels]
        return np.array(dic["scores"])[indices_order].tolist()

    def get_ordered_scores(self, sequence_to_classify):
        """Classify one text or a list of texts.

        Returns:
            For a single text: one score per label. For a list of texts: one
            list per label, each holding that label's score for every text.
        """
        if isinstance(sequence_to_classify, list):
            # tqdm only reports progress while the results are collected.
            res = list(tqdm.tqdm(self.classifier(sequence_to_classify)))
        else:
            res = self.classifier(sequence_to_classify)
        if isinstance(res, list):
            per_text = [self.order_scores(item) for item in res]
            # Transpose from text-major to label-major.
            return [list(column) for column in zip(*per_text)]
        return self.order_scores(res)

    def get_label_names(self):
        """Return the short name of each label: the first word before the colon."""
        return [label.split(":")[0].split(" ")[0] for label in self.candidate_labels]

    def __call__(self, sequence_to_classify):
        """Return a dict mapping short label names to their score(s)."""
        scores = self.get_ordered_scores(sequence_to_classify)
        return dict(zip(self.get_label_names(), scores))

    def visualize(self, name_to_score_dict, threshold=0.5, **kwargs):
        """Draw a horizontal bar chart of label scores.

        Args:
            name_to_score_dict: Mapping of label name to a single score.
            threshold: Scores strictly above this value get the first palette
                colour, all others the second.
            **kwargs: Forwarded to ``Axes.barh``. An array-like ``xerr`` must be
                given in the same order as ``name_to_score_dict``.

        Returns:
            The ``(figure, axes)`` pair.
        """
        fig, ax = plt.subplots()
        cp = sns.color_palette()
        label_names = list(name_to_score_dict.keys())
        scores_ordered = list(name_to_score_dict.values())
        colors = [cp[0] if score > threshold else cp[1] for score in scores_ordered]
        # The bars are drawn in reverse so the first label ends up on top;
        # per-bar error values have to be reversed with them to stay aligned.
        xerr = kwargs.pop("xerr", None)
        if xerr is not None and np.ndim(xerr) > 0:
            xerr = np.asarray(xerr)[..., ::-1]
        ax.barh(label_names[::-1], scores_ordered[::-1], color=colors[::-1], xerr=xerr, **kwargs)
        ax.set_xlim(left=0)
        fig.tight_layout()
        return fig, ax
class FramingDimensions:
    """Projects texts onto bipolar axes built from sentence embeddings.

    Each axis is the embedding of its first pole minus the embedding of its
    second pole, so a positive projection leans towards the first pole.
    """

    def __init__(self, base_model, dimensions, pole_names):
        """Embed the dimension descriptions and derive one vector per axis.

        Args:
            base_model: Model name handed to ``SentenceTransformer``.
            dimensions: Description strings; the first word before the colon is
                the dimension's short name.
            pole_names: Pairs of short names ``(first, second)`` defining an axis.
        """
        self.encoder = SentenceTransformer(base_model)
        self.dimensions = dimensions
        self.dim_embs = self.encoder.encode(dimensions)
        self.pole_names = pole_names
        self.axis_names = [pair[0] + "/" + pair[1] for pair in pole_names]

        short_names = self.get_dimension_names()
        differences = []
        for first_pole, second_pole in pole_names:
            first_idx = short_names.index(first_pole)
            second_idx = short_names.index(second_pole)
            differences.append(self.dim_embs[first_idx] - self.dim_embs[second_idx])
        self.axis_embs = np.stack(differences)

    def get_dimension_names(self):
        """Return the short name (first word before the colon) of every dimension."""
        short_names = []
        for description in self.dimensions:
            head = description.split(":")[0]
            short_names.append(head.split(" ")[0])
        return short_names

    def __call__(self, sequence_to_align):
        """Return a dict mapping each pole pair to the projection(s) of the input."""
        text_embs = self.encoder.encode(sequence_to_align)
        projections = text_embs @ self.axis_embs.T
        return dict(zip(self.pole_names, projections.T))

    def visualize(self, align_scores_df, **kwargs):
        """Plot the mean projection per axis as one marker per row.

        The marker position is the column mean, its colour is blue for a
        positive mean and red otherwise, and its size grows with the column
        variance. The left tick labels show the second pole of each pair and
        the right tick labels the first.

        Args:
            align_scores_df: DataFrame whose columns are pole pairs.
            **kwargs: Accepted but not used.

        Returns:
            The matplotlib figure.
        """
        pole_pairs = align_scores_df.columns
        right_labels = pole_pairs.map(lambda pair: pair[0])
        left_labels = pole_pairs.map(lambda pair: pair[1])

        mean_score = align_scores_df.mean()
        marker_colors = ["b" if value > 0 else "r" for value in mean_score]
        # A single row has undefined variance; fall back to the minimum size.
        marker_sizes = (align_scores_df.var().fillna(0) + 0.001) * 50_000
        x_limit = mean_score.abs().max() * 1.1

        fig = plt.figure()
        ax = fig.add_subplot(111)
        plt.scatter(x=mean_score, y=left_labels, s=marker_sizes, c=marker_colors)
        plt.axvline(0)
        plt.xlim(-x_limit, x_limit)
        plt.gca().invert_yaxis()

        # Mirror the y axis on the right-hand side to show the opposite poles.
        mirrored = ax.twinx()
        mirrored.set_ylim(ax.get_ylim())
        mirrored.set_yticks(ax.get_yticks(), labels=right_labels)
        plt.tight_layout()
        return fig
class FramingStructure:
    """Turns texts into graphs with a text2text model and draws them merged.

    The model output is decoded with ``penman.decode``; the link further down
    in this file refers to these graphs as AMR.
    """

    def __init__(self, base_model, roles=None):
        """Load the text2text-generation pipeline.

        Args:
            base_model: Model name or path handed to ``pipeline``.
            roles: Currently unused; kept so existing callers keep working.
        """
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.translator = pipeline("text2text-generation", base_model, device=device, max_length=300)

    def __call__(self, sequence_to_translate):
        """Generate and decode one graph per pipeline output.

        Outputs that cannot be decoded are skipped (best effort), so the
        returned list may be shorter than the input.

        Returns:
            A list of decoded graphs.
        """
        graphs = []
        for output in self.translator(sequence_to_translate):
            try:
                graphs.append(penman.decode(output["generated_text"]))
            except Exception:
                # Generated text is not guaranteed to be well-formed; drop it.
                # (Not a bare ``except`` so Ctrl-C / SystemExit still propagate.)
                continue
        return graphs

    def visualize(self, decoded_graphs, min_node_threshold=1, **kwargs):
        """Draw all graphs merged into one weighted, directed graph.

        Identical ``(source, role, target)`` triples are counted across all
        graphs. A node's weight is the summed count of the triples it takes
        part in, and the node is drawn at 100x that weight. Edges are coloured
        by role: ARG0 yellow, ARG1 red, ARG2 green, ARG3 blue, others black.

        Args:
            decoded_graphs: Objects exposing ``.triples``.
            min_node_threshold: Nodes with a weight below this are hidden.
            **kwargs: Accepted but not used.

        Returns:
            The matplotlib figure.
        """
        triple_counts = Counter()
        for graph in decoded_graphs:
            triple_counts.update(
                # Drop the ":" of the role and everything after the first "_"
                # of the node names (presumably a suffix that only tells
                # otherwise identical names apart -- confirm).
                (source.split("_")[0], role.replace(":", ""), target.split("_")[0])
                for source, role, target in graph.triples
                if target is not None
            )

        role_colors = {"ARG0": "y", "ARG1": "r", "ARG2": "g", "ARG3": "b"}
        G = nx.DiGraph()
        for (source, role, target), count in triple_counts.items():
            for node in (source, target):
                if not G.has_node(node):
                    G.add_node(node, weight=0)
                G.nodes[node]["weight"] += count
            G.add_edge(source, target, role=role, weight=count, color=role_colors.get(role, "k"))

        G_sub = nx.subgraph_view(G, filter_node=lambda n: G.nodes[n]["weight"] >= min_node_threshold)
        fig = plt.figure()
        if G_sub.number_of_nodes() == 0:
            # Nothing decoded or everything filtered out: return an empty figure.
            return fig

        node_sizes = [weight * 100 for weight in nx.get_node_attributes(G_sub, "weight").values()]
        edge_colors = list(nx.get_edge_attributes(G_sub, "color").values())
        pos = pygraphviz_layout(G_sub, prog="dot")
        # draw_networkx already renders the node labels (with_labels defaults
        # to True), so no separate label pass is needed.
        nx.draw_networkx(G_sub, pos, node_size=node_sizes, edge_color=edge_colors)
        nx.draw_networkx_edge_labels(G_sub, pos, edge_labels=nx.get_edge_attributes(G_sub, "role"))
        plt.tight_layout()
        return fig
# Pretrained checkpoints, one per analysis:
#   base_model_1 -> zero-shot classification of the frame labels
#   base_model_2 -> sentence embeddings for the framing dimensions
#   base_model_3 -> text-to-graph generation for the framing structure
base_model_1 = "facebook/bart-large-mnli"
base_model_2 = "all-mpnet-base-v2"
base_model_3 = "Iseratho/model_parse_xfm_bart_base-v0_1_0"

# Frame labels in "Name: description" form. Source:
# https://homes.cs.washington.edu/~nasmith/papers/card+boydstun+gross+resnik+smith.acl15.pdf
candidate_labels = [
    "Economic: costs, benefits, or other financial implications",
    (
        "Capacity and resources: availability of physical, human or financial resources, "
        "and capacity of current systems"
    ),
    "Morality: religious or ethical implications",
    "Fairness and equality: balance or distribution of rights, responsibilities, and resources",
    (
        "Legality, constitutionality and jurisprudence: rights, freedoms, and authority "
        "of individuals, corporations, and government"
    ),
    "Policy prescription and evaluation: discussion of specific policies aimed at addressing problems",
    "Crime and punishment: effectiveness and implications of laws and their enforcement",
    "Security and defense: threats to welfare of the individual, community, or nation",
    "Health and safety: health care, sanitation, public safety",
    "Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being",
    "Cultural identity: traditions, customs, or values of a social group in relation to a policy issue",
    "Public opinion: attitudes and opinions of the general public, including polling and demographics",
    (
        "Political: considerations related to politics and politicians, including lobbying, "
        "elections, and attempts to sway voters"
    ),
    "External regulation and reputation: international reputation or foreign policy of the U.S.",
    "Other: any coherent group of frames not covered by the above categories",
]

# Dimension descriptions in "Name: description" form. Source:
# https://osf.io/xakyw
dimensions = [
    "Care: ...acted with kindness, compassion, or empathy, or nurtured another person.",
    "Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.",
    "Fairness: ...acted in a fair manner, promoting equality, justice, or rights.",
    "Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.",
    "Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.",
    "Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.",
    "Authority: ...obeyed, or acted with respect for authority or tradition.",
    "Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.",
    "Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.",
    "Degradation: ...was depraved, degrading, impure, or unnatural.",
]

# Opposing dimension names; each pair forms one axis of the dimension plot.
pole_names = [
    ("Care", "Harm"),
    ("Fairness", "Cheating"),
    ("Loyalty", "Betrayal"),
    ("Authority", "Subversion"),
    ("Sanctity", "Degradation"),
]
# The three analysis models are built once at import time and shared by
# every request handler below.
framing_label_model = FramingLabels(base_model=base_model_1, candidate_labels=candidate_labels)
framing_dimen_model = FramingDimensions(
    base_model=base_model_2,
    dimensions=dimensions,
    pole_names=pole_names,
)
framing_struc_model = FramingStructure(base_model=base_model_3)
def framing_multi(texts, min_node_threshold=1):
    """Run all three analyses on several texts and aggregate each into one plot.

    Args:
        texts: List of texts to analyse together.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        A ``(label_figure, dimension_figure, structure_figure)`` tuple.
    """
    # Labels: mean score per label, with the standard error passed as xerr.
    label_scores = pd.DataFrame(framing_label_model(texts))
    mean_by_label = label_scores.mean().to_dict()
    label_fig, _ = framing_label_model.visualize(mean_by_label, xerr=label_scores.sem())

    # Dimensions: one row of projections per text.
    dimension_scores = pd.DataFrame(framing_dimen_model(texts))
    dimension_fig = framing_dimen_model.visualize(dimension_scores)

    # Structure: all decoded graphs merged into one drawing.
    graphs = framing_struc_model(texts)
    structure_fig = framing_struc_model.visualize(graphs, min_node_threshold=min_node_threshold)

    return label_fig, dimension_fig, structure_fig
def framing_single(text, min_node_threshold=1):
    """Run all three analyses on a single text.

    Args:
        text: The text to analyse as one unit.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        A ``(label_figure, dimension_figure, structure_figure)`` tuple.
    """
    label_scores = framing_label_model(text)
    label_fig, _ = framing_label_model.visualize(label_scores)

    # Wrap each projection in a list so the DataFrame has exactly one row.
    single_row = {}
    for pole_pair, projection in framing_dimen_model(text).items():
        single_row[pole_pair] = [projection]
    dimension_fig = framing_dimen_model.visualize(pd.DataFrame(single_row))

    graphs = framing_struc_model(text)
    structure_fig = framing_struc_model.visualize(graphs, min_node_threshold=min_node_threshold)

    return label_fig, dimension_fig, structure_fig
async def framing_textbox(text, split, min_node_threshold):
    """Gradio handler for the text box tab.

    Args:
        text: The text entered by the user.
        split: When true, every non-blank line is analysed as its own text and
            the results are aggregated.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        The three figures produced by ``framing_multi`` or ``framing_single``.
    """
    if split:
        # Skip blank lines (e.g. from a trailing newline) so that empty
        # strings are not sent to the models and averaged into the result.
        texts = [line for line in text.split("\n") if line.strip()]
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
    return framing_single(text, min_node_threshold)
async def framing_file(file_obj, split, min_node_threshold):
    """Gradio handler for the file tab.

    Args:
        file_obj: The uploaded file, either as an object exposing the path in
            ``.name`` or as the path itself.
        split: When true, every non-blank line is analysed as its own text and
            the results are aggregated.
        min_node_threshold: Minimum node weight shown in the structure plot.

    Returns:
        The three figures produced by ``framing_multi`` or ``framing_single``.
    """
    path = getattr(file_obj, "name", file_obj)
    # Read everything first so the file is closed before the models run.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    if split:
        texts = [line for line in text.split("\n") if line.strip()]
        if len(texts) > 1:
            return framing_multi(texts, min_node_threshold)
        # Zero or one usable line: analyse as a single text. An empty file
        # yields "" here instead of failing on texts[0].
        text = texts[0] if texts else ""
    return framing_single(text, min_node_threshold)
# Example rows for the text box tab: [text, split on newlines?, min node threshold].
example_list = [
    ["In 2010, CFCs were banned internationally due to their harmful effect on the ozone layer.", False, 1],
    ["In 2021, doctors prevented the spread of the virus by vaccinating with Pfizer.", False, 1],
    ["We must fight for our freedom.", False, 1],
    ["The government prevents our freedom.", False, 1],
    ["They prevent the spread.", False, 1],
    ["We fight the virus.", False, 1],
    ["I believe that we should act now.\nThere is no time to waste.", True, 1],
]

# Short description passed to both interfaces.
description = """A simple tool that helps you find (discover and detect) frames in text.
Note that due to the computation time required for underlying Transformer models, only short texts are recommended."""

# Longer article text (Markdown/HTML) passed to both interfaces.
# The literal used to open with four quote characters, which put a stray '"'
# in front of the text.
article = """Check out the preliminary article in the [Web Conference Symposium](https://dl.acm.org/doi/pdf/10.1145/3543873.3587534), will be updated to currently in review article after publication.
<details>
<summary>Explanation of labels:</summary>
<ul>
<li>Economic: costs, benefits, or other financial implications</li>
<li>Capacity and resources: availability of physical, human or financial resources, and capacity of current systems</li>
<li>Morality: religious or ethical implications</li>
<li>Fairness and equality: balance or distribution of rights, responsibilities, and resources</li>
<li>Legality, constitutionality and jurisprudence: rights, freedoms, and authority of individuals, corporations, and government</li>
<li>Policy prescription and evaluation: discussion of specific policies aimed at addressing problems</li>
<li>Crime and punishment: effectiveness and implications of laws and their enforcement</li>
<li>Security and defense: threats to welfare of the individual, community, or nation</li>
<li>Health and safety: health care, sanitation, public safety</li>
<li>Quality of life: threats and opportunities for the individual’s wealth, happiness, and well-being</li>
<li>Cultural identity: traditions, customs, or values of a social group in relation to a policy issue</li>
<li>Public opinion: attitudes and opinions of the general public, including polling and demographics</li>
<li>Political: considerations related to politics and politicians, including lobbying, elections, and attempts to sway voters</li>
<li>External regulation and reputation: international reputation or foreign policy of the U.S.</li>
<li>Other: any coherent group of frames not covered by the above categories</li>
</ul>
</details>
<details>
<summary>Explanation of dimensions: </summary>
<ul>
<li>Care: ...acted with kindness, compassion, or empathy, or nurtured another person.</li>
<li>Harm: ...acted with cruelty, or hurt or harmed another person/animal and caused suffering.</li>
<li>Fairness: ...acted in a fair manner, promoting equality, justice, or rights.</li>
<li>Cheating: ...was unfair or cheated, or caused an injustice or engaged in fraud.</li>
<li>Loyalty: ...acted with fidelity, or as a team player, or was loyal or patriotic.</li>
<li>Betrayal: ...acted disloyal, betrayed someone, was disloyal, or was a traitor.</li>
<li>Authority: ...obeyed, or acted with respect for authority or tradition.</li>
<li>Subversion: ...disobeyed or showed disrespect, or engaged in subversion or caused chaos.</li>
<li>Sanctity: ...acted in a way that was wholesome or sacred, or displayed purity or sanctity.</li>
<li>Degradation: ...was depraved, degrading, impure, or unnatural.</li>
</ul>
</details>
Document of structure (AMR) explanation: [AMR Specification](https://github.com/amrisi/amr-guidelines/blob/master/amr.md)
"""
# Tab 1: analyse text typed into a text box.
textbox_inferface = gr.Interface(
    fn=framing_textbox,
    inputs=[
        gr.Textbox(label="Text to analyze."),
        gr.Checkbox(True, label="Split on newlines? (To enter newlines type shift+Enter)"),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    outputs=[
        gr.Plot(label="Label"),
        gr.Plot(label="Dimensions"),
        gr.Plot(label="Structure"),
    ],
    examples=example_list,
    description=description,
    article=article,
)

# Tab 2: analyse the contents of an uploaded file.
file_interface = gr.Interface(
    fn=framing_file,
    inputs=[
        gr.File(label="File of texts to analyze."),
        gr.Checkbox(True, label="Split on newlines?"),
        gr.Number(1, label="Min node threshold for framing structure."),
    ],
    outputs=[
        gr.Plot(label="Label"),
        gr.Plot(label="Dimensions"),
        gr.Plot(label="Structure"),
    ],
    description=description,
    article=article,
)

# Combine both tabs into one app and start serving it.
demo = gr.TabbedInterface(
    [textbox_inferface, file_interface],
    tab_names=["Single Mode", "File Mode"],
    title="FrameFinder",
)
demo.launch()