import re
import textwrap
from collections import defaultdict

import plotly.graph_objects as go

# Maximum characters per display line when wrapping node text.
_WRAP_WIDTH = 30
# Temporary ``{{word}}`` marker placed around a highlighted term before wrapping.
_MARKER_RE = re.compile(r'\{\{(.*?)\}\}')
# Line-break tag understood by Plotly text/annotation rendering.
_LINE_BREAK = '<br>'


def _compile_highlight_pattern(words):
    """Return one case-insensitive regex matching any of *words*, or None.

    Terms are escaped so they are matched literally, and ordered longest
    first so that a longer phrase wins over a shorter term it contains.
    """
    terms = sorted((w for w in words if w), key=len, reverse=True)
    if not terms:
        return None
    alternation = '|'.join(re.escape(term) for term in terms)
    # Look-arounds act like \b for ordinary words but also work for terms
    # that start or end with a non-word character.
    return re.compile(r'(?<!\w)(?:' + alternation + r')(?!\w)', re.IGNORECASE)


def _format_label(sentence, pattern, color_lookup):
    """Mark highlighted terms, wrap the sentence, and colour the marked terms.

    Marking happens before wrapping and colouring after it, so that
    ``textwrap`` never sees (and never splits) the ``<span ...>`` markup.
    """
    if pattern is not None:
        # Keep the matched text as written; only surround it with markers.
        sentence = pattern.sub(lambda m: '{{' + m.group(0) + '}}', sentence)
    wrapped = _LINE_BREAK.join(textwrap.wrap(sentence, width=_WRAP_WIDTH))

    def to_span(match):
        word = match.group(1)
        # A multi-word term may have been wrapped across lines; undo the
        # break for the lookup only.
        key = word.replace(_LINE_BREAK, ' ').lower()
        color = color_lookup.get(key, 'black')
        return f'<span style="color:{color}">{word}</span>'

    return _MARKER_RE.sub(to_span, wrapped)


def _build_edges(n_schemes, n_samples, samples_per_scheme):
    """Return (parent, child) index pairs for the three-level tree.

    Node 0 is the root, nodes ``1..n_schemes`` are level 1, and the
    following ``n_samples`` nodes are level 2. Level-2 nodes are assigned to
    level-1 parents in consecutive groups of *samples_per_scheme*; samples
    beyond the last group get no edge.
    """
    first_scheme = 1
    first_sample = 1 + n_schemes
    edges = [(0, first_scheme + i) for i in range(n_schemes)]
    for i in range(n_schemes):
        start = i * samples_per_scheme
        stop = min(start + samples_per_scheme, n_samples)
        edges.extend(
            (first_scheme + i, first_sample + j) for j in range(start, stop)
        )
    return edges


def _layout_positions(levels, x_gap=2, scheme_y_gap=10, other_y_gap=6):
    """Return an (x, y) plot coordinate for each node.

    x grows with the level; nodes of one level are spread vertically around
    y = 0, with level 1 using *scheme_y_gap* and every other level using
    *other_y_gap* as the spacing.
    """
    counts = defaultdict(int)
    for level in levels:
        counts[level] += 1
    # Slot of the next node to place on each level, centred on zero.
    next_slot = {level: -(count - 1) / 2 for level, count in counts.items()}

    positions = []
    for level in levels:
        y_gap = scheme_y_gap if level == 1 else other_y_gap
        positions.append((level * x_gap, next_slot[level] * y_gap))
        next_slot[level] += 1
    return positions


def generate_subplot(paraphrased_sentence, scheme_sentences, sampled_sentence,
                     highlight_info, *, samples_per_scheme=4):
    """Build a Plotly figure showing a three-level sentence tree.

    The paraphrased sentence is the root (level 0), each scheme sentence is
    a level-1 child of the root, and the sampled sentences are level-2
    children attached to the scheme sentences in consecutive groups.

    Args:
        paraphrased_sentence: Text of the root node.
        scheme_sentences: Iterable of level-1 sentence strings.
        sampled_sentence: Iterable of level-2 sentence strings, ordered so
            that each run of ``samples_per_scheme`` belongs to the next
            scheme sentence.
        highlight_info: Word-to-colour mapping, or anything ``dict()``
            accepts (e.g. ``(word, colour)`` pairs). Matching is
            case-insensitive and on whole terms.
        samples_per_scheme: Number of sampled sentences linked to each
            scheme sentence. Defaults to 4.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker and one boxed
        annotation per sentence and one line trace per edge.

    Raises:
        ValueError: If ``samples_per_scheme`` is less than 1.
    """
    if samples_per_scheme < 1:
        raise ValueError('samples_per_scheme must be at least 1')

    scheme_sentences = list(scheme_sentences)
    sampled_sentence = list(sampled_sentence)
    sentences = [paraphrased_sentence, *scheme_sentences, *sampled_sentence]
    levels = [0] + [1] * len(scheme_sentences) + [2] * len(sampled_sentence)

    color_map = dict(highlight_info)
    # Lower-cased keys so the colour is found whatever the casing in the text.
    color_lookup = {word.lower(): color for word, color in color_map.items()}
    pattern = _compile_highlight_pattern(color_map)

    labels = [_format_label(s, pattern, color_lookup) for s in sentences]
    edges = _build_edges(
        len(scheme_sentences), len(sampled_sentence), samples_per_scheme
    )
    positions = _layout_positions(levels)

    fig = go.Figure()

    # One marker plus one boxed text annotation per node.
    for label, (x, y) in zip(labels, positions):
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none',
        ))
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            xshift=15,
            align="center",
            font=dict(size=8),
            bordercolor='black',
            borderwidth=1,
            borderpad=2,
            bgcolor='white',
            width=150,
        )

    # One straight line per parent/child edge.
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=1),
        ))

    fig.update_layout(
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1200,
        height=1000,
    )

    return fig