File size: 3,075 Bytes
4289090
 
 
 
 
 
 
b26b502
4289090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26b502
4289090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import random

import gradio as gr
import spaces

from lib.graph_extract import triplextract, parse_triples
from lib.visualize import create_bokeh_plot #, create_plotly_plot
from lib.samples import snippets

WORD_LIMIT = 300

def process_text(text, entity_types, predicates):
    """Extract a knowledge graph from ``text`` and prepare it for display.

    Args:
        text: Free-form input text. It must contain at least one word and no
            more than ``WORD_LIMIT`` whitespace-separated words.
        entity_types: Comma-separated entity type names.
        predicates: Comma-separated predicate names.

    Returns:
        A ``(figure, message)`` tuple. On success ``figure`` is whatever
        ``create_bokeh_plot`` returns and ``message`` lists the parsed
        entities, the relationships and the raw extractor output. On any
        validation or processing failure ``figure`` is ``None`` and
        ``message`` explains the problem.
    """
    # Whitespace-only input contains no words, so treat it like empty input
    # instead of sending a blank prompt to the extractor.
    if not text or not text.strip():
        return None, "Please enter some text."

    word_count = len(text.split())
    if word_count > WORD_LIMIT:
        return None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {word_count}"

    # `or ""` keeps a missing (None) field on the friendly validation path
    # below rather than raising AttributeError outside the try block.
    entity_types = [et.strip() for et in (entity_types or "").split(",") if et.strip()]
    predicates = [p.strip() for p in (predicates or "").split(",") if p.strip()]

    if not entity_types:
        return None, "Please enter at least one entity type."
    if not predicates:
        return None, "Please enter at least one predicate."

    # This is the UI boundary: any failure in extraction, parsing or plotting
    # is reported back to the user as text instead of propagating.
    try:
        prediction = triplextract(text, entity_types, predicates)
        # The extractor's result is checked for an "Error" prefix and, if
        # present, passed through to the user verbatim.
        if prediction.startswith("Error"):
            return None, prediction

        entities, relationships = parse_triples(prediction)

        if not entities and not relationships:
            return (
                None,
                "No entities or relationships found. Try different text or check your input.",
            )

        fig = create_bokeh_plot(entities, relationships)
        return (
            fig,
            f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
        )
    except Exception as e:
        print(f"Error in process_text: {e}")
        return None, f"An error occurred: {e}"

def update_inputs(sample_name):
    """Return the text, entity types and predicates of the named sample.

    Looks ``sample_name`` up in ``snippets`` and returns the sample's
    ``text_input``, ``entity_types`` and ``predicates`` attributes as a
    3-tuple, in that order.
    """
    chosen = snippets[sample_name]
    form_values = (
        chosen.text_input,
        chosen.entity_types,
        chosen.predicates,
    )
    return form_values

# Build the Gradio UI. Components are created inside nested context managers
# (Blocks > Row > Column), so statement order here defines the page layout.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# Knowledge Graph Extractor")
    
    # Pick a random sample so the form starts out pre-filled.
    default_sample_name = random.choice(list(snippets.keys()))
    default_sample = snippets[default_sample_name]
    
    with gr.Row():
        # First column (scale=1): sample picker, the three inputs and the
        # submit button.
        with gr.Column(scale=1):
            sample_dropdown = gr.Dropdown(
                choices=list(snippets.keys()),
                label="Select Sample",
                value=default_sample_name
            )
            input_text = gr.Textbox(
                label="Input Text",
                lines=5,
                value=default_sample.text_input
            )
            entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
            predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
            submit_btn = gr.Button("Extract Knowledge Graph")
        # Second column (scale=2): the rendered graph and the text output.
        with gr.Column(scale=2):
            output_graph = gr.Plot(label="Knowledge Graph")
            # Despite its name, this box receives both error messages and the
            # textual summary that process_text returns on success.
            error_message = gr.Textbox(label="Textual Output")

    # Changing the selected sample refills the three input fields.
    sample_dropdown.change(
        update_inputs,
        inputs=[sample_dropdown],
        outputs=[input_text, entity_types, predicates]
    )

    # Clicking the button runs process_text; its (figure, message) result is
    # routed to the plot and the text box respectively.
    submit_btn.click(
        process_text,
        inputs=[input_text, entity_types, predicates],
        outputs=[output_graph, error_message],
    )

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()