# app.py — Textile Machinery NER demo (Hugging Face Space by max-long, commit cbccbc9)
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# Stream the British Library books corpus rather than downloading it;
# shuffle within a bounded buffer so each session sees varied samples.
_shuffled_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
).shuffle(buffer_size=10000, seed=42)
dataset_iter = iter(_shuffled_stream)

# Fine-tuned GLiNER checkpoint used for textile-machinery entity recognition.
model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)
def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Run the GLiNER model over *text* and highlight "textile machinery" spans.

    Args:
        text: Passage to analyse.
        labels: Comma-separated entity labels to predict.
        threshold: Minimum confidence score for a predicted entity.
        nested_ner: If True, allow overlapping (nested) entity spans.

    Returns:
        A tuple ``(highlighted_html, textile_entities)``: *text* with each
        matched span wrapped in a styled ``<span>``, and a list of dicts
        describing the matches (label, surface form, offsets, score).
    """
    # Split the comma-separated label string, dropping blanks so a trailing
    # or doubled comma (e.g. "machine, ") doesn't send "" to the model.
    label_list = [label.strip() for label in labels.split(",") if label.strip()]

    # Predict entities using the fine-tuned GLiNER model.
    entities = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Keep only "textile machinery" predictions (case-insensitive label match).
    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            "score": ent.get("score", 0),
        }
        for ent in entities
        if ent["label"].lower() == "textile machinery"
    ]

    # Insert highlight markup right-to-left so earlier character offsets
    # remain valid after each insertion.
    highlighted_text = text
    for ent in sorted(textile_entities, key=lambda x: x["start"], reverse=True):
        highlighted_text = (
            highlighted_text[:ent["start"]] +
            f"<span style='background-color: yellow; font-weight: bold;'>{highlighted_text[ent['start']:ent['end']]}</span>" +
            highlighted_text[ent["end"]:]
        )

    return highlighted_text, textile_entities
with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    # Page header / description.
    gr.Markdown(
        """
        # Textile Machinery Entity Recognition Demo
        This demo selects a random text snippet from the British Library's books dataset and identifies "textile machinery" entities using a fine-tuned GLiNER model.
        """
    )

    # Free-text input; can be filled from the dataset via the refresh button.
    text_box = gr.Textbox(
        value=" ",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )

    with gr.Row():
        labels_box = gr.Textbox(
            value="textile machinery",
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold_slider = gr.Slider(
            0,
            1,
            value=0.3,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_toggle = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Allow for nested NER?",
            scale=0,
        )

    # Output components: rendered highlights plus the raw entity records.
    highlighted_out = gr.HTML(label="Predicted Entities")
    entities_out = gr.JSON(label="Entities")

    analyze_btn = gr.Button("Analyze Random Snippet")
    new_snippet_btn = gr.Button("Get New Snippet")

    def get_new_snippet():
        """Draw from the shuffled BL stream until a title mentions 'textile'."""
        # Bounded scan (1000 draws) so an unlucky shuffle can't loop forever.
        for _ in range(1000):
            try:
                sample = next(dataset_iter)
            except StopIteration:
                # Stream exhausted before a match turned up.
                break
            candidate = sample.get('title', '')
            if candidate and 'textile' in candidate.lower():
                return candidate
        return "No more snippets available."

    # Wire the buttons to their callbacks.
    new_snippet_btn.click(fn=get_new_snippet, outputs=text_box)
    analyze_btn.click(
        fn=ner,
        inputs=[text_box, labels_box, threshold_slider, nested_toggle],
        outputs=[highlighted_out, entities_out],
    )

demo.queue()
demo.launch(debug=True)