from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Stream the British Library books dataset
dataset_iter = iter(
    load_dataset(
        "TheBritishLibrary/blbooks",
        split="train",
        streaming=True,
        trust_remote_code=True
    ).shuffle(buffer_size=10000, seed=42)  # shuffle the stream so snippets vary between runs
)

# Load the model
model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)

def ner(text: str, labels: str, threshold: float, nested_ner: bool):
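    """Run GLiNER over `text` and return (highlighted HTML, list of matched entities).

    Only predictions labelled "textile machinery" (case-insensitive) are kept.
    """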
    # Split and clean labels
    labels = [label.strip() for label in labels.split(",")]
    
    # Predict entities using the fine-tuned GLiNER model
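    # flat_ner=True forbids overlapping spans, so enabling nested NER maps to flat_ner=False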
    entities = model.predict_entities(text, labels, flat_ner=not nested_ner, threshold=threshold)
    
    # Filter for "textile machinery" entities
    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            "score": ent.get("score", 0),
        }
        for ent in entities
        if ent["label"].lower() == "textile machinery"
    ]
    
    # Build an HTML string with the matched spans highlighted
    highlighted_text = text
    # Insert markup from the end of the text first so earlier character offsets stay valid
    for ent in sorted(textile_entities, key=lambda x: x['start'], reverse=True):
        highlighted_text = (
            highlighted_text[:ent['start']] +
            f"<span style='background-color: yellow; font-weight: bold;'>{highlighted_text[ent['start']:ent['end']]}</span>" +
            highlighted_text[ent['end']:]
        )
    
    return highlighted_text, textile_entities

with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    gr.Markdown(
        """
        # Textile Machinery Entity Recognition Demo
        Click **Get New Snippet** to pull a random book title mentioning "textile" from the British Library books dataset (or paste your own text), then click **Analyze Random Snippet** to highlight "textile machinery" entities found by a fine-tuned GLiNER model.
        """
    )
    
    # Text input, filled by the "Get New Snippet" button or typed by the user
    input_text = gr.Textbox(
        value="",
        label="Text input",
        placeholder="Enter your text here",
        lines=5
    )
    
    with gr.Row():
        labels = gr.Textbox(
            value="textile machinery",
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.3,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Allow for nested NER?",
            scale=0,
        )
    
    # Define output components
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")
    
    submit_btn = gr.Button("Analyze Random Snippet")
    refresh_btn = gr.Button("Get New Snippet")
    
    def get_new_snippet():
        """Return the next book title from the stream that mentions "textile"."""
        attempts = 0
        max_attempts = 1000  # Prevent an unbounded scan if few titles match
        while attempts < max_attempts:
            try:
                sample = next(dataset_iter)
            except StopIteration:
                break  # The stream is exhausted
            title = sample.get('title', '')
            if title and 'textile' in title.lower():
                return title
            attempts += 1
        return "No more snippets available."
    
    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
    
    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities]
    )

demo.queue()
demo.launch(debug=True)