# Uploaded by jgyasu to Hugging Face ("Upload folder using huggingface_hub"),
# commit ee305a4 (verified), 5.72 kB.
import plotly.graph_objects as go
import textwrap
import re
from collections import defaultdict
def _build_highlighter(color_map):
    """Compile a case-insensitive matcher for the words/phrases in *color_map*.

    Returns ``(pattern, colors)``:

    * ``pattern`` matches any key as a whole word/phrase. It is ``None`` when
      there are no non-blank keys.
    * ``colors`` maps each key to its colour. Keys are lower-cased and have
      internal whitespace collapsed to single spaces.

    Keys are regex-escaped, so punctuation in a key cannot break or widen the
    pattern. Whitespace inside a key matches any whitespace run, including the
    newlines inserted by line wrapping.
    """
    colors = {}
    keys = []
    for word, color in color_map.items():
        key = " ".join(str(word).split())
        if key:  # a blank key would match at every word boundary
            keys.append(key)
            colors[key.lower()] = color
    if not keys:
        return None, colors
    # Regex alternation takes the first alternative that matches, so longer
    # phrases must come before any shorter key they contain.
    keys.sort(key=len, reverse=True)
    alternatives = (
        r"\s+".join(re.escape(part) for part in key.split(" ")) for key in keys
    )
    pattern = re.compile(r"\b(?:" + "|".join(alternatives) + r")\b", re.IGNORECASE)
    return pattern, colors


def _render_node_text(sentence, pattern, colors, width=30):
    """Wrap *sentence* to *width* columns and colour its highlighted words.

    Returns Plotly annotation text. Lines are joined with ``<br>``. Each match
    of *pattern* is wrapped in ``<span style='color: ...;'>``, falling back to
    ``black`` if no colour is found. The matched text keeps its original
    casing. A phrase split across lines gets one span per line, so no span
    straddles a ``<br>``.
    """
    # Wrap the plain text first so the width is measured on what is displayed.
    text = "\n".join(textwrap.wrap(sentence, width=width))
    if pattern is not None:
        def colorize(match):
            matched = match.group(0)
            color = colors.get(" ".join(matched.split()).lower(), "black")
            return "<br>".join(
                f"<span style='color: {color};'>{segment}</span>"
                for segment in matched.split("\n")
            )
        text = pattern.sub(colorize, text)
    # After substitution, newlines remain only in un-highlighted text.
    return text.replace("\n", "<br>")


def _tree_edges(n_schemes, n_samples, samples_per_scheme):
    """Return ``(parent, child)`` node-index pairs for the three-level tree.

    Node 0 is the root and nodes ``1..n_schemes`` are scheme nodes. The next
    *n_samples* nodes are samples. The root parents every scheme node. Scheme
    ``i`` (0-based) parents samples ``i*samples_per_scheme`` up to
    ``(i+1)*samples_per_scheme - 1``. Samples past
    ``n_schemes*samples_per_scheme`` get no parent edge.
    """
    edges = [(0, 1 + i) for i in range(n_schemes)]
    first_sample = 1 + n_schemes
    for i in range(n_schemes):
        start = i * samples_per_scheme
        stop = min(start + samples_per_scheme, n_samples)
        edges.extend((1 + i, first_sample + j) for j in range(start, stop))
    return edges


def _tree_positions(level_sizes, x_gap, y_gaps):
    """Lay out nodes level by level, in node-index order.

    Level ``k`` sits at ``x = k * x_gap``. Its ``level_sizes[k]`` nodes are
    spaced ``y_gaps[k]`` apart and centred vertically on ``y = 0``.
    """
    positions = []
    for level, (count, y_gap) in enumerate(zip(level_sizes, y_gaps)):
        centre = (count - 1) / 2
        positions.extend(
            (level * x_gap, (rank - centre) * y_gap) for rank in range(count)
        )
    return positions


def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence,
                     highlight_info, *, samples_per_scheme=4):
    """Draw the paraphrase -> scheme -> sample tree as a Plotly figure.

    Args:
        paraphrased_sentence: Text of the root node (level 0, at x=0).
        scheme_sentences: List of texts for level-1 nodes. Each one is a child
            of the root.
        sampled_sentence: List of texts for level-2 nodes. Consecutive groups
            of *samples_per_scheme* become children of successive scheme
            nodes. Any extras are drawn with no incoming edge.
        highlight_info: Anything ``dict()`` accepts as word/phrase -> colour.
            Matches are case-insensitive and whole-word, and are coloured in
            every node. The colour string is inserted verbatim into a CSS
            ``color:`` style.
        samples_per_scheme: Number of level-2 children per scheme node
            (keyword-only). The default of 4 keeps the original layout.

    Returns:
        A ``plotly.graph_objects.Figure`` of fixed size 1200x1000. Each node
        is a blue marker with a bordered text annotation, and edges are
        black lines.
    """
    sentences = [paraphrased_sentence, *scheme_sentences, *sampled_sentence]
    n_schemes = len(scheme_sentences)
    n_samples = len(sampled_sentence)

    pattern, colors = _build_highlighter(dict(highlight_info))
    edges = _tree_edges(n_schemes, n_samples, samples_per_scheme)
    # Vertical spacing: 10 between scheme nodes, 6 between samples. The root
    # is a single node at y=0, so its gap has no effect.
    positions = _tree_positions((1, n_schemes, n_samples), x_gap=2, y_gaps=(6, 10, 6))

    fig = go.Figure()

    # Add edges before markers so the lines are drawn underneath the nodes.
    # All edges share one trace; a None entry breaks the line between segments.
    if edges:
        edge_x, edge_y = [], []
        for parent, child in edges:
            (x0, y0), (x1, y1) = positions[parent], positions[child]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]
        fig.add_trace(go.Scatter(
            x=edge_x,
            y=edge_y,
            mode='lines',
            line=dict(color='black', width=1),
            hoverinfo='none',
        ))

    fig.add_trace(go.Scatter(
        x=[x for x, _ in positions],
        y=[y for _, y in positions],
        mode='markers',
        marker=dict(size=10, color='blue'),
        hoverinfo='none',
    ))

    for (x, y), sentence in zip(positions, sentences):
        fig.add_annotation(
            x=x,
            y=y,
            text=_render_node_text(sentence, pattern, colors),
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=8),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=150,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=1000,
    )
    return fig