File size: 3,532 Bytes
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd6f11
f9bc688
f625748
f9bc688
 
 
e947e04
 
 
 
f625748
e947e04
f9bc688
f625748
9daea47
f9bc688
e947e04
f625748
e947e04
f625748
e947e04
f9bc688
f625748
f9bc688
 
 
df09c16
f9bc688
 
 
 
 
df09c16
f9bc688
 
 
 
 
 
 
df09c16
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
f625748
 
 
 
 
f9bc688
f625748
 
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f625748
 
f9bc688
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Stream the British Library books corpus instead of downloading it up front,
# then shuffle the stream with a fixed seed (42) so snippets are not served in
# corpus order.
# NOTE(review): trust_remote_code=True may let the dataset repository run its
# own loading code; keep it only if that repository is trusted.
_bl_books_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    trust_remote_code=True,
    streaming=True,
    split="train",
)
dataset_iter = _bl_books_stream.shuffle(seed=42)

# GLiNER checkpoint used by ner() for entity prediction.
_GLINER_MODEL_ID = "urchade/gliner_multi-v2.1"
model = GLiNER.from_pretrained(_GLINER_MODEL_ID, trust_remote_code=True)

def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Run GLiNER entity extraction on ``text`` for the user-supplied labels.

    Args:
        text: Input passage. Only its first 384 characters are sent to the
            model.
        labels: Comma-separated entity labels, e.g. ``"Person, Location"``.
            Whitespace around each label is stripped and empty entries (from
            stray or trailing commas) are ignored.
        threshold: Minimum score for a predicted entity, forwarded to
            ``model.predict_entities``.
        nested_ner: When True, ``flat_ner=False`` is passed to the model so
            overlapping spans may be returned.

    Returns:
        A pair ``(highlighted, entities)``:

        - ``highlighted`` is a ``{"text": ..., "entities": [...]}`` dict for
          ``gr.HighlightedText``. Its offsets refer to the truncated text.
        - ``entities`` is the raw list returned by the model. It is empty when
          no labels were given.
    """
    # Strip each label and drop blanks so "Person, Location," does not send an
    # empty-string label to the model.
    labels_list = [label for label in (part.strip() for part in labels.split(",")) if label]

    # Keep the input short for the model. NOTE: this slices *characters*, not
    # tokens, so it is only a conservative proxy for a token-length limit.
    max_length = 384
    truncated_text = text[:max_length]

    if not labels_list:
        # Nothing to search for: skip the model call and show the text unannotated.
        return {"text": truncated_text, "entities": []}, []

    # Predict entities using the GLiNER model
    entities = model.predict_entities(
        truncated_text, labels_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Reshape predictions into the span format gr.HighlightedText expects.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
        for ent in entities
    ]

    # First output feeds the HighlightedText component, second the JSON viewer.
    return {"text": truncated_text, "entities": highlights}, entities

with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    # Page header shown above the controls.
    gr.Markdown(
        """
        # General Entity Recognition Demo
        This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1). 
        """
    )

    # Editable text area; "Get New Snippet" replaces its contents with a dataset sample.
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5
    )

    # Model controls: labels to search for, score cut-off, and nested/flat mode.
    with gr.Row():
        labels = gr.Textbox(
            value="Person, Location",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Initial score cut-off passed to ner() as `threshold`
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )

    # Outputs: color-coded spans plus the raw model predictions as JSON.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    def get_new_snippet():
        """Return the text of the next non-blank sample from ``dataset_iter``.

        Looping over ``dataset_iter`` with a fresh ``for`` on every call starts
        a new pass over the (fixed-seed) shuffled stream, so the same first
        record was returned on every click. Instead, a single iterator is
        created lazily, cached on this function and advanced on each call.

        Returns:
            The ``text`` field of the next sample whose text is not blank.
            Returns ``"No more snippets available."`` if the stream is
            exhausted or no such sample appears within ``max_attempts`` records.
        """
        stream = getattr(get_new_snippet, "_stream", None)
        if stream is None:
            stream = iter(dataset_iter)
            get_new_snippet._stream = stream

        max_attempts = 1000  # Prevent an unbounded scan past blank records
        for _ in range(max_attempts):
            sample = next(stream, None)
            if sample is None:
                break  # Stream exhausted
            text = sample.get("text")
            if text and text.strip():
                return text
        return "No more snippets available."  # Return this if no valid snippets are found

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities]
    )

# Only start the server when run as a script, so importing this module
# (e.g. from tests or tooling) does not block on a running web server.
if __name__ == "__main__":
    # Enable request queuing before launching the app.
    demo.queue()
    demo.launch(debug=True)