# Author: max-long
# Last commit: 4008dee "Update app.py" (verified)
import pandas as pd
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset
# --- Data source -----------------------------------------------------------
# Path to the catalogue CSV; replace with your actual CSV file path.
CATALOGUE_CSV = "1921_catalogue_SMG.csv"
df = pd.read_csv(CATALOGUE_CSV)
# Name of the CSV column that holds the free text to sample and tag.
text_column = "Description"

# --- Model -----------------------------------------------------------------
# Fine-tuned GLiNER checkpoint used by ner(); loaded once at import time.
MODEL_ID = "max-long/textile_machines_ner_5_oct"
model = GLiNER.from_pretrained(MODEL_ID, trust_remote_code=True)
def get_new_snippet(frame=None, column=None):
    """Return the text of one randomly chosen, non-blank row.

    This is the callback for the "Get New Snippet" button. Gradio passes no
    inputs and fills defaulted parameters with their defaults, so in the UI
    both arguments fall back to the module-level ``df`` and ``text_column``.

    Args:
        frame: DataFrame to sample from. Defaults to the module-level ``df``.
        column: Name of the text column. Defaults to ``text_column``.

    Returns:
        The selected text as a ``str``. If the frame is empty or the column
        holds no usable text, returns the placeholder message instead.
    """
    if frame is None:
        frame = df
    if column is None:
        column = text_column

    if frame.empty:
        return "No more snippets available."  # Return this if the CSV file is empty

    # pandas reads empty CSV cells as NaN (a float). Drop those and
    # whitespace-only strings so the textbox never shows "nan" or blank
    # text, and ner() always receives a str.
    candidates = frame[column].dropna().astype(str)
    candidates = candidates[candidates.str.strip() != ""]
    if candidates.empty:
        return "No more snippets available."

    return candidates.sample(n=1).iloc[0]
def ner(text: str):
    """Find "Textile Machinery" entities in *text* with the fine-tuned GLiNER model.

    Args:
        text: Free text from the input textbox.

    Returns:
        A 2-tuple matching the two Gradio outputs:

        - a ``{"text": ..., "entities": [...]}`` dict in the format used by
          ``gr.HighlightedText``, where each entity has ``start``, ``end`` and
          ``entity``;
        - a list of entity dicts (``entity``, ``word``, ``start``, ``end``,
          ``score``) for the ``gr.JSON`` panel.

        Empty or whitespace-only input yields no entities, and the model is
        not called.
    """
    target_label = "Textile Machinery"
    labels = [target_label]
    threshold = 0.5

    # Nothing to tag in blank input, so skip the model call entirely.
    if not text or not text.strip():
        return {"text": text or "", "entities": []}, []

    # Predict entities using the fine-tuned GLiNER model.
    entities = model.predict_entities(text, labels, flat_ner=True, threshold=threshold)

    textile_entities = [
        {
            "entity": ent["label"],
            "word": ent["text"],
            "start": ent["start"],
            "end": ent["end"],
            # Fall back to 0 if a prediction carries no score.
            "score": ent.get("score", 0),
        }
        for ent in entities
        # Keep only the requested label, as a defensive filter.
        if ent["label"] == target_label
    ]

    # Character spans for colour-coded display in gr.HighlightedText.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["entity"]}
        for ent in textile_entities
    ]
    return {"text": text, "entities": highlights}, textile_entities
# Gradio Interface
with gr.Blocks(title="Textile Machinery NER Demo") as demo:
    gr.Markdown(
        """
        # Textile Machinery Entity Recognition Demo
        This demo selects a random text snippet from the Science Museum's 1921 catalogue and identifies "Textile Machinery" entities using a fine-tuned GLiNER model developed by the Congruence Engine project.
        """
    )
    # The box starts empty so the placeholder is visible and instructions are
    # never submitted to the model as text. It is filled with a random
    # catalogue snippet on page load (demo.load below).
    input_text = gr.Textbox(
        label="Text input",
        placeholder='Enter your text here, or click "Get New Snippet"',
        lines=5,
    )
    refresh_btn = gr.Button("Get New Snippet")
    # HighlightedText renders the {"text", "entities"} dict returned first by ner().
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")
    submit_btn = gr.Button("Find Textile Machinery!")

    refresh_btn.click(fn=get_new_snippet, outputs=input_text)
    submit_btn.click(
        fn=ner,
        inputs=[input_text],
        outputs=[output_highlighted, output_entities],
    )
    # Populate the textbox with a catalogue snippet when the page opens.
    demo.load(fn=get_new_snippet, outputs=input_text)

demo.queue()

# Launch only when run as a script, so importing this module (for example
# from tests or the `gradio` reload CLI) does not start a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)