# Source: Hugging Face Spaces app page (Space status when captured: Sleeping)
import html
import random

import gradio as gr
from datasets import load_dataset
from gliner import GLiNER
# Open the British Library books corpus in streaming mode.
_blbooks_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)

# Shuffle through a 10,000-example buffer with a fixed seed, then keep a single
# module-level iterator; get_new_snippet() advances it with next().
dataset_iter = iter(_blbooks_stream.shuffle(buffer_size=10000, seed=42))

# Fine-tuned GLiNER checkpoint used by ner() for entity prediction.
model = GLiNER.from_pretrained(
    "max-long/textile_machines_3_oct",
    trust_remote_code=True,
)
def _highlight_entities(text, entities):
    """Return *text* as HTML-escaped markup with entity spans highlighted.

    Overlapping spans (possible when nested NER is enabled) are merged into a
    single highlight so that inserting markup never shifts the offsets of
    another span.
    """
    merged = []
    for start, end in sorted((ent["start"], ent["end"]) for ent in entities):
        if merged and start < merged[-1][1]:
            # Overlaps the previous span: extend it instead of nesting tags.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    pieces = []
    cursor = 0
    for start, end in merged:
        # Escape every slice of the source text; only our own <span> is raw HTML.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(
            "<span style='background-color: yellow; font-weight: bold;'>"
            f"{html.escape(text[start:end])}</span>"
        )
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Find "textile machinery" entities in *text* and highlight them as HTML.

    Args:
        text: Input text to analyse.
        labels: Comma-separated entity labels forwarded to the model. Empty
            entries (e.g. from a trailing comma) are ignored.
        threshold: Score threshold forwarded to ``model.predict_entities``.
        nested_ner: When true the model is called with ``flat_ner=False``.

    Returns:
        A ``(highlighted_html, textile_entities)`` tuple. ``highlighted_html``
        is the HTML-escaped input with each "textile machinery" span wrapped
        in a highlighting ``<span>``. ``textile_entities`` is a list of dicts
        with the keys ``entity``, ``word``, ``start``, ``end`` and ``score``.
    """
    # Split and clean labels, dropping blanks left by stray commas.
    label_list = [label.strip() for label in labels.split(",") if label.strip()]

    # Nothing to predict on blank text or without labels; skip the model call.
    if not text or not text.strip() or not label_list:
        return html.escape(text or ""), []

    # Predict entities using the fine-tuned GLiNER model.
    entities = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Only "textile machinery" entities are reported, whatever labels were requested.
    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            "score": ent.get("score", 0),
        }
        for ent in entities
        if ent["label"].lower() == "textile machinery"
    ]

    return _highlight_entities(text, textile_entities), textile_entities
def get_new_snippet():
    """Return the next dataset title that mentions "textile".

    Pulls at most 1000 samples from the shared ``dataset_iter`` per call and
    returns the first non-empty ``title`` containing "textile"
    (case-insensitive). If the iterator is exhausted or no sample matches
    within that budget, a fixed fallback message is returned instead.
    """
    # NOTE(review): the returned "snippet" is the record's title field, not
    # page text -- confirm this is the intended field.
    max_attempts = 1000  # Upper bound on samples inspected per call
    for _ in range(max_attempts):
        try:
            sample = next(dataset_iter)
        except StopIteration:
            break
        candidate = sample.get('title', '')
        if candidate and 'textile' in candidate.lower():
            return candidate
    return "No more snippets available."


with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    # Page header / description
    gr.Markdown(
        """
        # Textile Machinery Entity Recognition Demo
        This demo selects a random text snippet from the British Library's books dataset and identifies "textile machinery" entities using a fine-tuned GLiNER model.
        """
    )

    # Text to analyse; filled in by the "Get New Snippet" button or typed by hand
    input_text = gr.Textbox(
        label="Text input",
        placeholder="Enter your text here",
        value=" ",
        lines=5,
    )

    # Model options, laid out side by side
    with gr.Row():
        labels = gr.Textbox(
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            value="textile machinery",
            scale=2,
        )
        threshold = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.3,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            label="Nested NER",
            info="Allow for nested NER?",
            value=False,
            scale=0,
        )

    # Result panes: highlighted HTML plus the raw entity list
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Action buttons
    submit_btn = gr.Button("Analyze Random Snippet")
    refresh_btn = gr.Button("Get New Snippet")

    # "Get New Snippet" writes a fresh title into the input box
    refresh_btn.click(get_new_snippet, outputs=input_text)

    # "Analyze" runs NER on the current input and fills both result panes
    submit_btn.click(
        ner,
        [input_text, labels, threshold, nested_ner],
        [output_highlighted, output_entities],
    )
# Enable request queuing, then start the app in debug mode.
demo.queue().launch(debug=True)