File size: 3,532 Bytes
f9bc688 7fd6f11 f9bc688 f625748 f9bc688 e947e04 f625748 e947e04 f9bc688 f625748 9daea47 f9bc688 e947e04 f625748 e947e04 f625748 e947e04 f9bc688 f625748 f9bc688 df09c16 f9bc688 df09c16 f9bc688 df09c16 f9bc688 f625748 f9bc688 f625748 f9bc688 f625748 f9bc688 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# Open the British Library books corpus in streaming mode so records are
# fetched lazily instead of downloading the full dataset up front.
_blbooks_stream = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)

# Shuffle the stream with a fixed seed (42).
# NOTE(review): despite its name this is presumably a re-iterable dataset
# object, not a one-shot iterator - confirm against the `datasets` docs.
dataset_iter = _blbooks_stream.shuffle(seed=42)

# GLiNER checkpoint used for all entity predictions in this app.
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", trust_remote_code=True)
def ner(text: str, labels: str, threshold: float, nested_ner: bool, max_length: int = 384):
    """Predict entities in ``text`` for the requested labels using GLiNER.

    Args:
        text: Input passage. Only its first ``max_length`` characters are
            analysed and returned.
        labels: Comma-separated entity labels, e.g. ``"Person, Location"``.
            Surrounding whitespace is stripped and blank entries are dropped.
        threshold: Score cut-off forwarded to ``model.predict_entities``.
        nested_ner: When true, the model is called with ``flat_ner=False``.
        max_length: Maximum number of characters of ``text`` to keep.

    Returns:
        A 2-tuple ``(highlighted, entities)``. ``highlighted`` is a dict with
        keys ``"text"`` (the truncated text) and ``"entities"`` (a list of
        ``{"start", "end", "entity"}`` dicts) for ``gr.HighlightedText``;
        ``entities`` is the raw list returned by the model, for ``gr.JSON``.
        When there is no usable text or no usable label, both entity lists are
        empty and the model is not called.
    """
    # Drop blank labels so inputs such as "", "Person,," or a trailing comma
    # never reach the model as empty label strings.
    labels_list = [label.strip() for label in (labels or "").split(",") if label.strip()]

    # NOTE: this is a character slice, not a token count. The original comment
    # cited a 384-token model limit; a character cut stays below such a limit
    # but may end mid-word - confirm the real limit against the model docs.
    truncated_text = (text or "")[:max_length]

    # Nothing to predict: return an empty, well-formed result for both outputs.
    if not labels_list or not truncated_text.strip():
        return {"text": truncated_text, "entities": []}, []

    entities = model.predict_entities(
        truncated_text,
        labels_list,
        flat_ner=not nested_ner,
        threshold=threshold,
    )

    # Reshape model output into the span format gr.HighlightedText expects.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
        for ent in entities
    ]
    return {"text": truncated_text, "entities": highlights}, entities
# One persistent iterator over the shuffled stream, shared by every click.
# Looping over `dataset_iter` directly inside the handler would start a fresh
# pass on each call and therefore keep returning the same first sample.
# NOTE(review): the iterator is shared by all sessions and is not locked;
# add a lock if handlers can run concurrently in your Gradio configuration.
_snippet_iter = iter(dataset_iter)


def get_new_snippet():
    """Return the next non-blank text snippet from the dataset stream.

    Advances the shared stream iterator, skipping samples whose ``text`` is
    empty or whitespace-only, for at most ``max_attempts`` samples per call.

    Returns:
        The sample's ``text`` field, or ``"No more snippets available."`` when
        the stream is exhausted or no non-blank sample was found in time.
    """
    max_attempts = 1000  # Upper bound on samples inspected per click
    for _ in range(max_attempts):
        sample = next(_snippet_iter, None)
        if sample is None:
            break  # Stream exhausted
        text = sample["text"]
        if text and text.strip():
            return text
    return "No more snippets available."


with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        """
        # General Entity Recognition Demo
        This demo selects a random text snippet from the British Library's books dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1).
        """
    )

    # Text to analyse; filled by the "Get New Snippet" button or typed by hand.
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )

    with gr.Row() as row:
        labels = gr.Textbox(
            value="Person, Location",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Default score cut-off passed to ner()
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )

    # Outputs: colour-coded spans plus the raw entity list as JSON.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    # Load a fresh snippet into the input box.
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Run entity recognition on the current input.
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)