# Page header captured together with this file (not part of the program):
#   max-long's picture / Update app.py / 999a2cb verified / raw / history blame / 4.99 kB
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# Open the "train" split of the British Library books dataset in streaming
# mode; `get_new_snippet` later pulls records from this object.
# NOTE(review): trust_remote_code=True lets code shipped with the repository
# run locally — confirm the dataset repo is trusted.
subset_dataset = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,
    trust_remote_code=True,
)

# Fetch the fine-tuned GLiNER checkpoint that `ner` uses for prediction.
# NOTE(review): it is not visible from here whether GLiNER.from_pretrained
# acts on trust_remote_code — verify against the gliner documentation.
model = GLiNER.from_pretrained(
    "max-long/textile_machines_3_oct",
    trust_remote_code=True,
)
# Define the NER function
def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Predict entities in ``text`` and keep only the "textile machinery" ones.

    Args:
        text: Passage to analyse.
        labels: Comma-separated label names handed to the model. Blank items
            (for example from a trailing comma) are dropped.
        threshold: Score threshold forwarded to ``model.predict_entities``.
        nested_ner: When true the model is called with ``flat_ner=False``.

    Returns:
        A 2-tuple ``(highlight, entities)``. ``highlight`` is a dict with the
        keys ``"text"`` and ``"entities"``, the payload shape accepted by
        ``gr.HighlightedText``. ``entities`` is the list of kept entities,
        each a dict with ``entity``, ``word``, ``start``, ``end`` and
        ``score``.
    """
    label_list = [label.strip() for label in labels.split(",") if label.strip()]

    # Nothing to analyse: skip the model call and return an empty result.
    if not text or not text.strip() or not label_list:
        return {"text": text or "", "entities": []}, []

    entities = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Only "textile machinery" predictions are reported, whatever other
    # labels were requested.
    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            "score": ent.get("score", 0),
        }
        for ent in entities
        if ent["label"].lower() == "textile machinery"
    ]

    # Hand the raw text plus character offsets to the highlighting component
    # instead of splicing <span> tags into the string: splicing left user text
    # unescaped and shifted the offsets of overlapping (nested) entities.
    return {"text": text, "entities": textile_entities}, textile_entities
# Build Gradio interface
with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    # --- Header and usage notes -------------------------------------------
    gr.Markdown(
        """
# Textile Machinery Entity Recognition Demo
This demo selects a random text snippet from the British Library's books dataset and identifies "textile machinery" entities using a fine-tuned GLiNER model.
"""
    )
    with gr.Accordion("How to run this model locally", open=False):
        gr.Markdown(
            """
## Installation
To use this model, you must install the GLiNER Python library:
```
!pip install gliner
```
## Usage
Once you've downloaded the GLiNER library, you can import the GLiNER class. You can then load this model using `GLiNER.from_pretrained` and predict entities with `predict_entities`.
"""
        )
        # The loop body is indented so the snippet shown to users is valid
        # Python.
        gr.Code(
            '''
from gliner import GLiNER
model = GLiNER.from_pretrained("max-long/textile_machines_3_oct")
text = "Your sample text here."
labels = ["textile machinery"]
entities = model.predict_entities(text, labels)
for entity in entities:
    print(entity["text"], "=>", entity["label"])
''',
            language="python",
        )
        gr.Code(
            """
Textile Machine 1 => textile machinery
Textile Machine 2 => textile machinery
"""
        )

    # --- Inputs -------------------------------------------------------------
    input_text = gr.Textbox(
        value="Amelia Earhart flew her single engine Lockheed Vega 5B across the Atlantic to Paris.",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    with gr.Row():
        labels = gr.Textbox(
            value="textile machinery",
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.3,
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Allow for nested NER?",
            scale=0,
        )

    # --- Outputs and actions --------------------------------------------------
    output = gr.HighlightedText(label="Predicted Entities")
    submit_btn = gr.Button("Analyze Random Snippet")
    refresh_btn = gr.Button("Get New Snippet")
    # One shared JSON panel: creating gr.JSON(...) inline in each `outputs=`
    # list rendered two separate "Entities" components.
    entities_json = gr.JSON(label="Entities")

    # Number of streamed records buffered before one is drawn at random.
    snippet_shuffle_buffer = 100

    def get_new_snippet():
        """Return the text of a randomly drawn record from ``subset_dataset``.

        Returns:
            The first non-empty ``"text"`` value found in a freshly shuffled
            view of the stream, or ``"No more snippets available."`` when the
            stream yields no such record.
        """
        # iter(subset_dataset) restarts at the first record on every call, so
        # taking next() of it always produced the same snippet. Reshuffle with
        # a new seed per click to get a different record each time.
        shuffled = subset_dataset.shuffle(
            seed=random.randrange(2**32),
            buffer_size=snippet_shuffle_buffer,
        )
        for record in shuffled:
            snippet = record.get("text")
            if snippet and snippet.strip():
                return snippet
        return "No more snippets available."

    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output, entities_json],
    )

    examples = [
        [
            "However, both models lack other frequent DM symptoms including the fibre-type dependent atrophy, myotonia, cataract and male-infertility.",
            "textile machinery",
            0.3,
            False,
        ],
        # Add more examples as needed
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output, entities_json],
        fn=ner,
        label="Examples",
        cache_examples=True,
    )
# Turn on request queuing, then start the app server with debug output.
demo.queue().launch(debug=True)