File size: 4,045 Bytes
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8739835
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02a3276
6614d86
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import base64
from transformers import AutoModel, AutoTokenizer
from graphviz import Digraph
import json

def display_tree(output):
    """Render a dependency tree as an SVG inside a scrollable Streamlit container.

    Args:
        output: sequence of per-word dicts, each with keys
            ``'word'`` (node label), ``'dep_head_idx'`` (index into
            ``output`` of the word's head, or -1 for the root) and
            ``'dep_func'`` (label drawn on the dependency edge).
    """
    # Graphviz `size` is "width,height" in inches; scale the width with the
    # number of words. Clamp to at least 1 so an empty input does not yield
    # a zero-width drawing.
    size = f'{max(1, len(output))},5'
    dpi = '300'
    img_format = 'svg'

    dot = Digraph(engine='dot', format=img_format)
    dot.attr('graph', rankdir='LR', rank='same', size=size, dpi=dpi)

    for i, word_info in enumerate(output):
        node_id = str(i)
        dot.node(node_id, word_info['word'])
        # Invisible edge i -> i-1: with rankdir=LR the tail is placed left of
        # the head, so each word lands left of its predecessor and the
        # sentence reads right-to-left.
        if i > 0:
            dot.edge(node_id, str(i - 1), style='invis')
        head_idx = word_info['dep_head_idx']
        if head_idx != -1:
            # constraint=false keeps dependency arcs from disturbing the
            # left/right word order established above.
            dot.edge(node_id, str(head_idx), label=word_info['dep_func'], constraint='False')

    # Render once, in memory, and embed as a data URI (no file written to disk).
    svg_b64 = base64.b64encode(dot.pipe(format=img_format)).decode()
    st.markdown(
        f"""
            <div style="height:250px; width:75vw; overflow:auto; border:1px solid #ccc; margin-left:-15vw">
                <img src="data:image/svg+xml;base64,{svg_b64}" 
                    style="display: block; margin: auto; max-height: 240px;">
            </div>
        """, unsafe_allow_html=True)
        
# Streamlit app title
st.title('DictaBERT-Joint Visualizer')

# Load Hugging Face token
hf_token = st.secrets["HF_TOKEN"]  # Assuming you've set up the token in Streamlit secrets

# Column names of a CoNLL-U token line, in order.
UD_COLUMNS = ["ID", "FORM", "LEMMA", "UPOS", "XPOS", "FEATS", "HEAD", "DEPREL", "DEPS", "MISC"]


@st.cache_resource
def load_tokenizer_and_model(token):
    """Load the DictaBERT-Joint tokenizer and model once per server process.

    Streamlit re-executes this whole script on every widget interaction;
    caching the resource avoids re-instantiating the model on each rerun.
    """
    # NOTE(review): newer transformers releases deprecate `use_auth_token`
    # in favour of `token=`; switch once the pinned version supports it.
    tok = AutoTokenizer.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token)
    mdl = AutoModel.from_pretrained('dicta-il/dictabert-joint', use_auth_token=token, trust_remote_code=True)
    mdl.eval()
    return tok, mdl


def ud_lines_to_tree(ud_lines):
    """Convert CoNLL-U token lines into ``[dict(word, dep_head_idx, dep_func)]``.

    ``dep_head_idx`` is the 0-based index of the head word (-1 for the root,
    whose HEAD column is 0). Blank lines, comment lines, multiword-token
    ranges (IDs like ``3-4``) and empty nodes (IDs like ``8.1``) are skipped,
    because they are not nodes of the basic dependency tree.
    """
    tree = []
    for line in ud_lines:
        if not line.strip() or line.startswith('#'):
            continue
        parts = line.split('\t')
        token_id = parts[0]
        if '-' in token_id or '.' in token_id:
            continue
        tree.append(dict(word=parts[1], dep_head_idx=int(parts[6]) - 1, dep_func=parts[7]))
    return tree


def ud_output_to_markdown(ud_output):
    """Build an RTL markdown table from UD output.

    The first two entries of ``ud_output`` are emitted as headings; every
    following entry is rendered as one table row, with ``_`` and ``|``
    escaped so they are not interpreted as markdown syntax.
    """
    parts = [
        "<div dir='rtl' style='text-align: right;'>\n\n",
        "##" + ud_output[0] + "\n",
        "##" + ud_output[1] + "\n",
        "| " + " | ".join(UD_COLUMNS) + " |\n",
        "| " + " | ".join(["---"] * len(UD_COLUMNS)) + " |\n",
    ]
    for line in ud_output[2:]:
        cells = line.replace('_', '\\_').replace('|', '&#124;').split('\t')
        parts.append("| " + " | ".join(cells) + " |\n")
    parts.append("</div>")
    return "".join(parts)


# Authenticate and load model (cached across reruns)
tokenizer, model = load_tokenizer_and_model(hf_token)

# Checkbox for the compute_mst parameter
compute_mst = st.checkbox('Compute Maximum Spanning Tree', value=True)

output_style = st.selectbox(
    'Output Style: ',
    ('JSON', 'UD', 'IAHLT_UD'), index=1).lower()

# User input
sentence = st.text_input('Enter a sentence to analyze:')

if sentence:
    # Display the input sentence
    st.text(sentence)

    # Model prediction
    output = model.predict([sentence], tokenizer, compute_syntax_mst=compute_mst, output_style=output_style)[0]

    if output_style in ('ud', 'iahlt_ud'):
        # UD styles return a list of lines: two header lines, then token lines.
        display_tree(ud_lines_to_tree(output[2:]))
        st.markdown(ud_output_to_markdown(output), unsafe_allow_html=True)
    else:
        # JSON style: each token carries its syntax dict under 'syntax'.
        tree = [w['syntax'] for w in output['tokens']]
        display_tree(tree)

        # and the full json
        st.markdown("```json\n" + json.dumps(output, ensure_ascii=False, indent=2) + "\n```")