File size: 3,256 Bytes
f9bc688
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Stream the British Library books corpus instead of downloading it up front.
# NOTE: despite its name, `dataset_iter` is an iterable dataset, not an
# iterator: every fresh `for` loop over it starts the stream from the beginning.
_blbooks_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
# Seeded shuffle so the order of streamed samples is reproducible.
dataset_iter = _blbooks_stream.shuffle(seed=42)

# Fine-tuned GLiNER checkpoint used for entity prediction.
model = GLiNER.from_pretrained(
    "max-long/textile_machines_3_oct",
    trust_remote_code=True,
)

def ner(text: str, labels: str, threshold: float):
    """Run the GLiNER model on *text* and return an HTML rendering plus raw entities.

    Args:
        text: Input text to tag. ``None`` is treated as an empty string.
        labels: Comma-separated entity labels, e.g. ``"Machine, Wool"``.
            Blank entries (from stray or trailing commas) are ignored.
        threshold: Minimum confidence passed to ``model.predict_entities``.

    Returns:
        A ``(html, entities)`` tuple. ``html`` is the HTML-escaped input, with
        every predicted span wrapped in a highlighted ``<span>``, inside a
        container that preserves the text's line breaks. ``entities`` is the
        list returned by ``model.predict_entities``. It is empty, and the model
        is not called, when the text is blank or no usable label remains.
    """
    import html  # only this function needs it

    text = text or ""
    # Drop empty labels so e.g. "Machine, Wool," does not query the model with "".
    labels_list = [
        label.strip() for label in (labels or "").split(",") if label.strip()
    ]

    if text.strip() and labels_list:
        # flat_ner=True asks GLiNER for non-nested spans; any overlap that still
        # slips through is skipped defensively in the rendering loop below.
        entities = model.predict_entities(
            text, labels_list, flat_ner=True, threshold=threshold
        )
    else:
        entities = []

    # Build the HTML left to right, escaping every piece of text: the input is
    # user-typed or OCR'd book text and may contain '<' or '&', which gr.HTML
    # would otherwise interpret as markup.
    parts = []
    cursor = 0
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            continue  # overlaps the previous highlighted span
        parts.append(html.escape(text[cursor:start]))
        parts.append(
            "<span style='background-color: yellow; font-weight: bold;'>"
            f"{html.escape(text[start:end])}</span>"
        )
        cursor = end
    parts.append(html.escape(text[cursor:]))

    # pre-wrap keeps the newlines/spacing that HTML would otherwise collapse.
    highlighted_text = (
        "<div style='white-space: pre-wrap;'>" + "".join(parts) + "</div>"
    )
    return highlighted_text, entities

with gr.Blocks(title="General NER Demo") as demo:
    gr.Markdown(
        """
        # General Entity Recognition Demo
        This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities using a fine-tuned GLiNER model. You can specify the entities you want to find.
        """
    )

    # Input text: starts with a fixed example; "Get New Snippet" replaces it
    # with text drawn from the dataset stream.
    input_text = gr.Textbox(
        value="The machine is fed by means of an endless apron, the wool entering at the smaller end...",
        label="Text input",
        placeholder="Enter your text here",
        lines=5
    )

    with gr.Row():
        labels = gr.Textbox(
            value="Machine, Wool",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Initial confidence threshold passed to ner()
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )

    # Define output components
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    # One persistent iterator over the stream. Looping over `dataset_iter`
    # afresh on each click restarts the stream (and refills the shuffle
    # buffer), so the button would keep returning the same first sample.
    snippet_stream = iter(dataset_iter)

    def get_new_snippet():
        """Return the text of the next non-blank sample from the shuffled stream.

        Advances the shared stream, skipping samples whose text is missing or
        whitespace-only, for at most ``max_attempts`` samples per call. Returns
        a fallback message when the stream is exhausted or no usable sample was
        found within that budget (a later call resumes where this one stopped).
        """
        max_attempts = 1000  # Prevent an unbounded scan through blank pages
        for _ in range(max_attempts):
            sample = next(snippet_stream, None)
            if sample is None:
                break  # stream exhausted
            text = sample["text"] or ""
            if text.strip():
                return text
        return "No more snippets available."

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=lambda text, labels, threshold: ner(text, labels, threshold),
        inputs=[input_text, labels, threshold],
        outputs=[output_highlighted, output_entities]
    )

demo.queue()
demo.launch(debug=True)