# -*- coding: utf-8 -*-
"""
Created on Tue Sep 17 19:03:17 2024
@author: SABARI
"""
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers.models.roberta.modeling_roberta import RobertaForTokenClassification
import spacy
from spacy.tokens import Doc, Span
from spacy import displacy
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class JapaneseNER:
    def __init__(self, model_path, model_name="xlm-roberta-base"):
        # Mapping between label ids and the entity tags used at training time
        self._index_to_tag = {0: 'O',
                              1: 'PER',
                              2: 'ORG',
                              3: 'ORG-P',
                              4: 'ORG-O',
                              5: 'LOC',
                              6: 'INS',
                              7: 'PRD',
                              8: 'EVT'}
        self._tag_to_index = {v: k for k, v in self._index_to_tag.items()}
        self._tag_feature_num_classes = len(self._index_to_tag)
        self._model_name = model_name
        self._model_path = model_path

        xlmr_config = AutoConfig.from_pretrained(
            self._model_name,
            num_labels=self._tag_feature_num_classes,
            id2label=self._index_to_tag,
            label2id=self._tag_to_index
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        self.model = (RobertaForTokenClassification
                      .from_pretrained(self._model_path, config=xlmr_config)
                      .to(device))
        self.model.eval()  # inference only: disable dropout
    def prepare(self):
        # Sanity check: run the model once on a couple of sample sentences
        sample_encoding = self.tokenizer([
            "鈴木は4月の陽気の良い日に、鈴をつけて熊本県の阿蘇山に登った",
            "中国では、中国共産党による一党統治が続く",
        ], truncation=True, max_length=512, padding=True, return_tensors="pt")
        sample_encoding = {k: v.to(device) for k, v in sample_encoding.items()}

        # Perform prediction
        with torch.no_grad():
            output = self.model(**sample_encoding)
        predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]
        print("predicted labels:", predicted_label_id)
    def predict(self, text):
        encoding = self.tokenizer([text], truncation=True, max_length=512, return_tensors="pt")
        encoding = {k: v.to(device) for k, v in encoding.items()}

        # Perform prediction
        with torch.no_grad():
            output = self.model(**encoding)

        # Get the predicted label id for every subword token
        predicted_label_id = torch.argmax(output.logits, dim=-1).cpu().numpy()[0]

        # convert_ids_to_tokens keeps tokens and predictions aligned 1:1;
        # decoding and re-tokenizing can drift out of alignment
        tokens = self.tokenizer.convert_ids_to_tokens(encoding["input_ids"][0].tolist())

        # Map the predicted label ids to their corresponding tags
        predictions = [self._index_to_tag[label_id] for label_id in predicted_label_id]
        return tokens, predictions
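
# The model labels each subword independently, so a multi-token entity comes
# back as a run of identical tags. A minimal sketch of how such runs could be
# merged into (start, end, label) spans before visualization; `merge_spans` is
# a hypothetical helper, not part of the original app, and is not wired into
# `ner_inference` below.
def merge_spans(labels):
    spans = []
    start = None
    for i, label in enumerate(labels):
        # Close the open run when the tag changes
        if start is not None and label != labels[start]:
            spans.append((start, i, labels[start]))
            start = None
        # Open a new run on any non-'O' tag
        if start is None and label != 'O':
            start = i
    if start is not None:
        spans.append((start, len(labels), labels[start]))
    return spans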
# Instantiate the NER model
model_path = "path_to_your_saved_model"
ner_model = JapaneseNER(model_path)
ner_model.prepare()
# Function to integrate with spaCy displacy for visualization
def ner_inference(text):
    # Get per-token tags from the model
    tokens, predictions = ner_model.predict(text)

    # Build a spaCy Doc so displacy can render the entities
    nlp = spacy.blank("ja")  # blank Japanese pipeline; only the vocab is used
    doc = Doc(nlp.vocab, words=tokens)

    # Turn every token tagged with something other than 'O' into a
    # single-token entity span
    ents = []
    for i, label in enumerate(predictions):
        if label != 'O':
            ents.append(Span(doc, i, i + 1, label=label))
    doc.ents = ents

    # Render the entities as HTML with displacy
    html = displacy.render(doc, style="ent", jupyter=False)
    return html
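
# Quick local check (a sketch only; the sentence and the output filename are
# arbitrary examples, not part of the original app):
# html = ner_inference("鈴木は熊本県の阿蘇山に登った")
# with open("ner_output.html", "w", encoding="utf-8") as f:
#     f.write(html)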
# Create Gradio interface
import gradio as gr

iface = gr.Interface(
    fn=ner_inference,  # function to call for prediction
    inputs=gr.Textbox(lines=5, placeholder="Enter Japanese text for NER..."),
    outputs="html",  # displacy output is raw HTML
    title="Japanese Named Entity Recognition (NER)",
    description="Enter Japanese text and see the named entities highlighted in the output."
)

# Launch the interface
iface.launch()
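
# When running outside a managed host such as Hugging Face Spaces, a temporary
# public URL can be requested instead: iface.launch(share=True)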