File size: 3,567 Bytes
ea7f5b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import plotly.graph_objs as go
import textwrap
import re
from collections import defaultdict
from paraphraser import generate_paraphrase
from masking_methods import mask, mask_non_stopword

def _build_tree(paraphrase_count, masked_count):
    """Return ``(levels, edges)`` for the original/paraphrase/masked tree.

    Node indices follow the order ``[original, *paraphrases, *masked]``:
    index 0 is level 0, the next ``paraphrase_count`` indices are level 1,
    and the remaining ``masked_count`` indices are level 2.

    Edges connect the root to every level-1 node, and the first level-1 node
    (the paraphrase the masked versions are derived from) to every level-2
    node.

    Args:
        paraphrase_count: number of level-1 (paraphrase) nodes.
        masked_count: number of level-2 (masked) nodes.

    Returns:
        ``levels``: dict mapping node index -> level, in index order.
        ``edges``: list of ``(parent_index, child_index)`` tuples.
    """
    first_masked = paraphrase_count + 1
    paraphrase_ids = range(1, first_masked)
    masked_ids = range(first_masked, first_masked + masked_count)

    levels = {0: 0}
    levels.update((i, 1) for i in paraphrase_ids)
    levels.update((i, 2) for i in masked_ids)

    edges = [(0, i) for i in paraphrase_ids]
    # Masked versions only make sense as children of an existing paraphrase.
    if paraphrase_count:
        edges.extend((1, i) for i in masked_ids)
    return levels, edges


def _layout_positions(levels, y_gap=4):
    """Compute ``{node_index: (x, y)}`` for a level-by-level tree layout.

    Nodes of the same level sit on one horizontal row at unit spacing,
    centred on x = 0, in node-index order. Level ``k`` is placed at
    ``y = -k * y_gap``, so deeper levels appear lower in the plot.
    """
    level_widths = defaultdict(int)
    for level in levels.values():
        level_widths[level] += 1

    # Leftmost x for each row, so that the row is centred on 0.
    next_x = {level: -(width - 1) / 2 for level, width in level_widths.items()}

    positions = {}
    for node, level in levels.items():
        positions[node] = (next_x[level], -level * y_gap)
        next_x[level] += 1
    return positions


def generate_plot(original_sentence):
    """Build a Plotly tree figure of a sentence, its paraphrases and masked variants.

    The tree has three levels:

    * level 0 (root): ``original_sentence``;
    * level 1: every sentence returned by ``generate_paraphrase``;
    * level 2: the sentences returned by
      ``mask(mask_non_stopword(first_paraphrase))``, attached to the first
      paraphrase.

    Each node is drawn as a marker with a bordered text box above it; the
    text is wrapped at 30 characters using ``<br>`` line breaks.

    Args:
        original_sentence: the sentence to paraphrase and visualise.

    Returns:
        A ``plotly.graph_objs.Figure`` containing the tree.

    Raises:
        IndexError: if ``generate_paraphrase`` returns no paraphrases (the
            first paraphrase is required to derive the masked variants).
    """
    paraphrased_sentences = generate_paraphrase(original_sentence)
    first_paraphrased_sentence = paraphrased_sentences[0]
    masked_sentence = mask_non_stopword(first_paraphrased_sentence)
    # Materialise so the variants can be counted as well as iterated.
    masked_versions = list(mask(masked_sentence))

    sentences = [original_sentence, *paraphrased_sentences, *masked_versions]
    wrapped_nodes = ['<br>'.join(textwrap.wrap(s, width=30)) for s in sentences]

    levels, edges = _build_tree(len(paraphrased_sentences), len(masked_versions))
    positions = _layout_positions(levels)

    fig = go.Figure()

    # Edges are added before nodes: Plotly renders traces in order, so this
    # keeps the node markers on top of the connecting lines.
    for parent, child in edges:
        x0, y0 = positions[parent]
        x1, y1 = positions[child]
        fig.add_trace(go.Scatter(
            x=[x0, x1],
            y=[y0, y1],
            mode='lines',
            line=dict(color='black', width=2),
            hoverinfo='none',
        ))

    for i, label in enumerate(wrapped_nodes):
        x, y = positions[i]
        fig.add_trace(go.Scatter(
            x=[x],
            y=[y],
            mode='markers',
            marker=dict(size=10, color='blue'),
            hoverinfo='none',
        ))
        fig.add_annotation(
            x=x,
            y=y,
            text=label,
            showarrow=False,
            yshift=20,  # lift the text box above the marker to avoid overlap
            align="center",
            font=dict(size=10),
            bordercolor='black',
            borderwidth=1,
            borderpad=4,
            bgcolor='white',
            width=200,
        )

    fig.update_layout(
        showlegend=False,
        margin=dict(t=50, b=50, l=50, r=50),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        width=1470,
        height=800,  # extra height leaves room for the wrapped text boxes
    )

    return fig