|
import plotly.graph_objects as go |
|
import textwrap |
|
import re |
|
from collections import defaultdict |
|
|
|
_MARKED_SPAN = re.compile(r'\{\{(.*?)\}\}')


def _normalize_phrase(text):
    """Return a lookup key for *text*: ``<br>`` and whitespace runs collapsed, case folded.

    Lets a highlighted phrase be matched back to its color regardless of the
    casing in the sentence or of where textwrap inserted a line break.
    """
    return ' '.join(text.replace('<br>', ' ').split()).casefold()


def _build_highlight_pattern(words):
    """Compile one case-insensitive regex matching any of *words* as a whole word.

    Words are ``re.escape``d so regex metacharacters are matched literally, and
    longer words are tried first so a phrase wins over a shorter word it
    contains. Empty words are ignored. Returns None when nothing is left to
    highlight (an empty alternation would match the empty string everywhere).
    """
    literals = sorted((w for w in words if w), key=len, reverse=True)
    if not literals:
        return None
    alternation = '|'.join(map(re.escape, literals))
    # Lookarounds instead of \b: identical for words that start/end with a
    # word character, but still correct for words like "c++".
    return re.compile(rf'(?<!\w)(?:{alternation})(?!\w)', re.IGNORECASE)


def _mark_highlights(sentence, pattern):
    """Wrap each match of *pattern* in ``{{...}}``, keeping the matched text unchanged."""
    if pattern is None:
        return sentence
    # A callable replacement keeps backslashes in the text from being treated
    # as group references/escapes, and preserves the sentence's own casing.
    return pattern.sub(lambda m: '{{' + m.group(0) + '}}', sentence)


def _color_marked_spans(text, color_lookup):
    """Replace every ``{{...}}`` span in *text* with a colored HTML ``<span>``.

    *color_lookup* maps ``_normalize_phrase`` keys to colors; unknown phrases
    fall back to ``'black'``.
    """
    def to_span(match):
        phrase = match.group(1)
        color = color_lookup.get(_normalize_phrase(phrase), 'black')
        return f"<span style='color: {color};'>{phrase}</span>"

    return _MARKED_SPAN.sub(to_span, text)


def _tree_edges(num_scheme, num_sampled, samples_per_scheme):
    """Return ``(parent, child)`` node-index pairs for the three-level tree.

    Node 0 is the root, nodes ``1..num_scheme`` are level 1, and the following
    ``num_sampled`` nodes are level 2. The root links to every level-1 node.
    The i-th level-1 node links to the i-th consecutive group of
    *samples_per_scheme* level-2 nodes. Level-2 nodes beyond
    ``num_scheme * samples_per_scheme`` get no parent edge.
    """
    level1 = range(1, num_scheme + 1)
    level2_start = num_scheme + 1
    level2_end = level2_start + num_sampled
    edges = [(0, node) for node in level1]
    for position, parent in enumerate(level1):
        first = level2_start + position * samples_per_scheme
        last = min(first + samples_per_scheme, level2_end)
        edges.extend((parent, child) for child in range(first, last))
    return edges


def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence,
                     highlight_info, samples_per_scheme=4):
    """Build a plotly tree figure: root sentence -> scheme sentences -> sampled sentences.

    Nodes are laid out left to right by level: the root at x=0, scheme
    sentences at x=2, sampled sentences at x=4. Each level is vertically
    centred on y=0. Every node is drawn as a marker plus a bordered annotation
    holding its text, wrapped at 30 characters with ``<br>``. Words listed in
    *highlight_info* are colored wherever they occur as whole words (case
    insensitive); the sentence's own casing is kept.

    Args:
        paraphrased_sentence: Text of the root node.
        scheme_sentences: List of level-1 sentences; each is linked to the root.
        sampled_sentence: List of level-2 sentences, grouped consecutively
            under the scheme sentences (see *samples_per_scheme*).
        highlight_info: Anything ``dict()`` accepts, mapping word/phrase to a
            color string that is inserted into an inline ``style='color: ...'``
            (presumably a CSS color name or hex code).
        samples_per_scheme: How many consecutive sampled sentences hang under
            each scheme sentence (default 4). Sampled sentences past
            ``len(scheme_sentences) * samples_per_scheme`` are drawn but not
            connected.

    Returns:
        plotly.graph_objects.Figure: the rendered tree.

    Raises:
        ValueError: if *samples_per_scheme* is less than 1.
    """
    if samples_per_scheme < 1:
        raise ValueError(f"samples_per_scheme must be >= 1, got {samples_per_scheme}")

    wrap_width = 30
    x_gap = 2
    l1_y_gap = 10
    other_y_gap = 6  # used for the root and the sampled (level-2) sentences

    nodes = [paraphrased_sentence] + scheme_sentences + sampled_sentence
    num_scheme = len(scheme_sentences)
    num_sampled = len(sampled_sentence)

    # --- Node labels: mark highlights, wrap, then turn markers into <span>s.
    color_map = dict(highlight_info)
    pattern = _build_highlight_pattern(color_map)
    color_lookup = {_normalize_phrase(word): color
                    for word, color in color_map.items() if word}
    labels = []
    for sentence in nodes:
        marked = _mark_highlights(sentence, pattern)
        wrapped = '<br>'.join(textwrap.wrap(marked, width=wrap_width))
        labels.append(_color_marked_spans(wrapped, color_lookup))

    # --- Layout: node index order is root, level 1, level 2, which matches the
    # order of ``nodes``; within a level, nodes are centred around y=0.
    positions = []
    for level, size in enumerate((1, num_scheme, num_sampled)):
        y_gap = l1_y_gap if level == 1 else other_y_gap
        for slot in range(size):
            positions.append((level * x_gap, (slot - (size - 1) / 2) * y_gap))

    edges = _tree_edges(num_scheme, num_sampled, samples_per_scheme)

    # --- Drawing.
    fig = go.Figure()
    for label, (x, y) in zip(labels, positions):
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none'
        ))
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=8),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=150
        )

    for parent, child in edges:
        (x0, y0), (x1, y1) = positions[parent], positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1)
        ))

    fig.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=1000
    )

    return fig