|
import html as html_lib
import json
import logging
import re
from typing import Any, Dict, List

import gradio as gr
import pandas as pd

from tasks.ner import named_entity_recognition
from utils.ner_helpers import NER_ENTITY_TYPES, DEFAULT_SELECTED_ENTITIES, is_llm_model
|
|
|
|
|
|
|
def ner_ui():
    """Build the named-entity-recognition UI section.

    Creates the input text box, model dropdown and entity-type selector,
    the tagged-HTML and table outputs, and wires their event handlers.
    Components are created for their side effect only; the function always
    returns ``None``. Presumably meant to be called inside an active
    ``gr.Blocks``/``gr.Tab`` context -- confirm against the caller.
    """
    # Every entity label declared in NER_ENTITY_TYPES (its keys). Used as the
    # checkbox choices and as the "Select All" value.
    DEFAULT_ENTITY_TYPES = list(NER_ENTITY_TYPES.keys())
|
|
|
def ner(text: str, model: str, entity_types: List[str]) -> Dict[str, Any]:
    """Extract named entities from *text*, using an LLM when *model* supports it.

    LLM-backed models (per ``is_llm_model``) receive ``entity_types`` directly.
    For other models, the returned entities are filtered against
    ``entity_types`` afterwards, when it is non-empty.

    Args:
        text: Input text. Blank or missing text short-circuits without calling
            any model.
        model: Model identifier, passed through to ``named_entity_recognition``.
        entity_types: Entity labels to request (LLM) or keep (non-LLM).

    Returns:
        ``{"entities": [...]}``. Each item has the keys ``entity``, ``word``,
        ``start``, ``end``, ``score`` and ``description``. Blank input returns
        ``{"text": "", "entities": []}``. Any error during extraction is logged
        with its traceback and yields ``{"entities": []}``.
    """
    if not text or not text.strip():
        return {"text": "", "entities": []}

    try:
        use_llm = is_llm_model(model)

        raw_entities = named_entity_recognition(
            text=text,
            model=model,
            use_llm=use_llm,
            entity_types=entity_types if use_llm else None,
        )
        if not isinstance(raw_entities, list):
            raw_entities = []

        # Skip malformed (non-dict) items instead of letting one of them abort
        # the whole request through the except branch below.
        records = [e for e in raw_entities if isinstance(e, dict)]

        if not use_llm and entity_types:
            records = [
                e for e in records
                if e.get("type", "") in entity_types or e.get("entity", "") in entity_types
            ]

        # Records may use either ``type``/``text``/``confidence`` or
        # ``entity``/``word``/``score`` key names. The filter above already
        # accepts both, so the output mapping must read both too.
        return {
            "entities": [
                {
                    "entity": e.get("type") or e.get("entity", ""),
                    "word": e.get("text") or e.get("word", ""),
                    "start": e.get("start", 0),
                    "end": e.get("end", 0),
                    "score": e.get("confidence", e.get("score", 1.0)),
                    "description": e.get("description", ""),
                }
                for e in records
            ]
        }

    except Exception:
        # UI boundary: keep the app responsive, but keep the traceback.
        logging.getLogger(__name__).exception("Error in NER")
        return {"entities": []}
|
|
|
def _coerce_offset(value: Any) -> int:
    """Return *value* as an int character offset, or -1 if missing or not integral."""
    if isinstance(value, bool):
        return -1
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    return -1


def _locate_entity_span(text: str, entity_text: str, start: int, end: int, claimed: set):
    """Return the ``(start, end)`` of *entity_text* inside *text*, or ``None`` if absent.

    Reported offsets are trusted when ``text[start:end]`` equals the entity.
    They are narrowed when that span merely begins with the entity (after
    leading whitespace). Otherwise the entity text is searched for verbatim,
    skipping start positions already in *claimed*, so repeated mentions map to
    distinct occurrences. When a start offset was reported, the nearest free
    occurrence wins.
    """
    if 0 <= start < end <= len(text):
        span = text[start:end]
        if span == entity_text:
            return start, end
        lead = len(span) - len(span.lstrip())
        if span[lead:].startswith(entity_text):
            # Offsets cover extra characters: trim them so nothing is lost when rendering.
            begin = start + lead
            return begin, begin + len(entity_text)

    candidates = [
        match.start()
        for match in re.finditer(re.escape(entity_text), text)
        if match.start() not in claimed
    ]
    if not candidates:
        return None
    if start >= 0:
        begin = min(candidates, key=lambda pos: abs(pos - start))
    else:
        begin = candidates[0]
    return begin, begin + len(entity_text)


def _drop_overlapping_spans(spans: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep spans from left to right, discarding any that overlap one already kept."""
    kept: List[Dict[str, Any]] = []
    # Sort by start; on equal starts the longer span comes first and therefore wins.
    for span in sorted(spans, key=lambda s: (s['start'], -s['end'])):
        if kept and span['start'] < kept[-1]['end']:
            continue
        kept.append(span)
    return kept


def render_ner_html(text, entities):
    """Render *text* as HTML, highlighting *entities* with colour-coded labels.

    Args:
        text: The analysed text.
        entities: Entity dicts. The label is read from ``type`` or ``entity``,
            the surface text from ``text`` or ``word``, and the optional
            character offsets from ``start``/``end``. Entities whose text cannot
            be found in *text* are skipped. Overlapping entities are resolved
            left to right.

    Returns:
        An HTML string. All text and labels are HTML-escaped. A "No named
        entities found" message is returned when *text* is blank or *entities*
        is empty.
    """
    if not text or not text.strip() or not entities:
        return "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"

    COLORS = [
        '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6', '#e0f7fa',
        '#f1f8e9', '#fce4ec', '#e8f5e9', '#f5f5f5', '#fafafa', '#e1f5fe',
        '#fff3e0', '#d7ccc8', '#f9fbe7', '#fbe9e7', '#ede7f6', '#e0f2f1'
    ]

    label_colors = {}
    placed = []
    claimed = set()  # start positions already taken by a located entity

    for ent in entities:
        if not isinstance(ent, dict):
            continue
        label = ent.get('type') or ent.get('entity')
        entity_text = ent.get('text') or ent.get('word')
        if not label or not entity_text:
            continue
        label, entity_text = str(label), str(entity_text)

        span = _locate_entity_span(
            text,
            entity_text,
            _coerce_offset(ent.get('start', -1)),
            _coerce_offset(ent.get('end', -1)),
            claimed,
        )
        if span is None:
            # The entity text does not occur in the input: skip it rather than
            # rendering it at a bogus position.
            continue
        start, end = span
        claimed.add(start)

        if label not in label_colors:
            label_colors[label] = COLORS[len(label_colors) % len(COLORS)]

        placed.append({
            'label': label,
            'color': label_colors[label],
            'start': start,
            'end': end,
        })

    parts = ["<div class='ner-highlight' style='line-height:1.6;padding:15px;border:1px solid #e0e0e0;border-radius:4px;background:#f9f9f9;white-space:pre-wrap;'>"]

    cursor = 0
    for span in _drop_overlapping_spans(placed):
        start, end = span['start'], span['end']
        if start > cursor:
            parts.append(html_lib.escape(text[cursor:start]))

        parts.append(f"<span style='background:{span['color']};border-radius:3px;padding:2px 4px;margin:0 1px;border:1px solid rgba(0,0,0,0.1);'>")
        # text[start:end] equals the entity text for every located span.
        parts.append(f"{html_lib.escape(text[start:end])} ")
        parts.append(f"<span style='font-size:0.8em;font-weight:bold;color:#555;border-radius:2px;padding:0 2px;background:rgba(255,255,255,0.7);'>{html_lib.escape(span['label'])}</span>")
        parts.append("</span>")
        cursor = end

    if cursor < len(text):
        parts.append(html_lib.escape(text[cursor:]))

    parts.append("</div>")
    return "".join(parts)
|
|
|
def update_ui(model_id: str) -> Dict:
    """Show or hide the entity-type picker for the chosen model.

    Returns a Gradio update mapping that makes ``entity_types_group`` visible
    exactly when ``is_llm_model(model_id)`` is truthy.
    """
    group_visible = is_llm_model(model_id)
    updates = {entity_types_group: gr.Group(visible=group_visible)}
    return updates
|
|
|
with gr.Row():
    with gr.Column(scale=2):
        input_text = gr.Textbox(
            label="Input Text",
            lines=8,
            placeholder="Enter text to analyze for named entities..."
        )

        gr.Examples(
            examples=[
                ["Barack Obama was born in Hawaii and became the 44th President of the United States."],
                ["Google is headquartered in Mountain View, California."]
            ],
            inputs=[input_text],
            label="Examples"
        )
        model_dropdown = gr.Dropdown(
            ["gemini-2.0-flash"],
            value="gemini-2.0-flash",
            label="Model"
        )

        # Entity-type selection only matters for LLM-backed models (update_ui
        # toggles it on model change), so start with the visibility that
        # matches the initially selected model.
        with gr.Group(visible=is_llm_model(model_dropdown.value)) as entity_types_group:
            entity_types = gr.CheckboxGroup(
                label="Entity Types to Extract",
                choices=DEFAULT_ENTITY_TYPES,
                value=DEFAULT_SELECTED_ENTITIES,
                interactive=True
            )
            with gr.Row():
                select_all_btn = gr.Button("Select All", size="sm")
                clear_all_btn = gr.Button("Clear All", size="sm")

        btn = gr.Button("Extract Entities", variant="primary")

        def select_all_entities():
            """Tick every available entity type."""
            return gr.CheckboxGroup(value=DEFAULT_ENTITY_TYPES)

        def clear_all_entities():
            """Untick every entity type."""
            return gr.CheckboxGroup(value=[])

        select_all_btn.click(
            fn=select_all_entities,
            outputs=[entity_types]
        )

        clear_all_btn.click(
            fn=clear_all_entities,
            outputs=[entity_types]
        )

    with gr.Column(scale=3):
        with gr.Tabs():
            with gr.Tab("Tagged View", id="tagged-view-ner"):
                # Placeholder shown until results (or a message) replace it.
                no_results_html = gr.HTML(
                    "<div style='text-align: center; color: #666; padding: 20px;'>"
                    "Enter text and click 'Extract Entities' to get results.</div>",
                    visible=True
                )
                output_html = gr.HTML(
                    label="NER Highlighted",
                    elem_id="ner-output-html",
                    visible=False
                )

                # NOTE(review): render_ner_html emits inline styles and only the
                # `ner-highlight` class, so these selectors look unused (they appear
                # carried over from a POS view) -- confirm before relying on them.
                # The `.pos-tag` rule was missing its closing brace, which swallowed
                # every rule after it.
                gr.HTML("""
    <style>
    #ner-output-html .pos-highlight {
        white-space: pre-wrap;
        line-height: 1.8;
        font-size: 14px;
        padding: 15px;
        border: 1px solid #e0e0e0;
        border-radius: 4px;
        background: #f9f9f9;
    }
    #ner-output-html .pos-token {
        display: inline-block;
        margin: 0 2px 4px 0;
        vertical-align: top;
        text-align: center;
    }
    #ner-output-html .token-text {
        display: block;
        padding: 2px 8px;
        background: #f0f4f8;
        border-radius: 4px 4px 0 0;
        border: 1px solid #dbe4ed;
        border-bottom: none;
        font-size: 0.9em;
    }
    #ner-output-html .pos-tag {
        display: block;
        padding: 2px 8px;
        border-radius: 0 0 4px 4px;
    }
    #ner-output-html .WORK_OF_ART { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; }
    #ner-output-html .LAW { background-color: #fce4ec; border-color: #f8bbd0; color: #880e4f; }
    #ner-output-html .LANGUAGE { background-color: #e8f5e9; border-color: #c8e6c9; color: #1b5e20; font-weight: bold; }
    #ner-output-html .DATE { background-color: #f5f5f5; border-color: #e0e0e0; color: #424242; }
    #ner-output-html .TIME { background-color: #fafafa; border-color: #f5f5f5; color: #616161; }
    #ner-output-html .PERCENT { background-color: #e1f5fe; border-color: #b3e5fc; color: #01579b; font-weight: bold; }
    #ner-output-html .MONEY { background-color: #f3e5f5; border-color: #e1bee7; color: #6a1b9a; }
    #ner-output-html .QUANTITY { background-color: #f1f8e9; border-color: #dcedc8; color: #33691e; font-style: italic; }
    #ner-output-html .ORDINAL { background-color: #fff3e0; border-color: #ffe0b2; color: #e65100; }
    #ner-output-html .CARDINAL { background-color: #ede7f6; border-color: #d1c4e9; color: #4527a0; }
    </style>
    """)

            with gr.Tab("Table View", id="table-view-ner"):
                no_results_table = gr.HTML(
                    "<div style='text-align: center; color: #666; padding: 20px;'>"
                    "Enter text and click 'Extract Entities' to get results.</div>",
                    visible=True
                )
                output_table = gr.Dataframe(
                    label="Extracted Entities",
                    headers=["Type", "Text", "Confidence", "Description"],
                    datatype=["str", "str", "number", "str"],
                    interactive=False,
                    wrap=True,
                    visible=False
                )
|
|
|
|
|
# Re-evaluate the entity-type picker's visibility whenever the model changes.
model_dropdown.change(update_ui, [model_dropdown], [entity_types_group])
|
|
|
def process_and_show_results(text: str, model: str, entity_types: List[str]):
    """Run NER on *text* and build updates for the four output components.

    Returns a list of four updates, in this order: the highlighted-HTML view,
    the tagged-view placeholder, the results table and the table-view
    placeholder. This matches the ``outputs`` passed to ``btn.click``.
    Blank input shows an error message in both placeholders. An empty
    ``entity_types`` selection means "all known types".
    """
    if not text or not text.strip():
        msg = "<div style='text-align: center; color: #f44336; padding: 20px;'>Please enter some text to analyze.</div>"
        return [
            gr.HTML(visible=False),
            gr.HTML(msg, visible=True),
            gr.DataFrame(visible=False),
            gr.HTML(msg, visible=True)
        ]

    if not entity_types:
        entity_types = list(NER_ENTITY_TYPES.keys())

    result = ner(text, model, entity_types)
    entities = result["entities"] if result and "entities" in result else []

    if entities:
        df = pd.DataFrame(entities)
        if not df.empty:
            # Order rows by position in the text. This must happen BEFORE the
            # column projection below, which drops ``start``. Non-numeric
            # offsets are coerced to NaN, so they sort last instead of raising.
            if "start" in df.columns:
                df = df.sort_values(
                    "start",
                    kind="stable",
                    key=lambda col: pd.to_numeric(col, errors="coerce"),
                ).reset_index(drop=True)
            df = df.rename(columns={
                "entity": "Type",
                "word": "Text",
                "score": "Confidence",
                "description": "Description"
            })
            display_columns = ["Type", "Text", "Confidence", "Description"]
            df = df[[col for col in display_columns if col in df.columns]]

        html = render_ner_html(text, entities)
        return [
            gr.HTML(html, visible=True),
            gr.HTML(visible=False),
            gr.DataFrame(value=df, visible=True),
            gr.HTML(visible=False)
        ]

    msg = "<div style='text-align: center; color: #666; padding: 20px;'>No named entities found in the text.</div>"
    return [
        gr.HTML(msg, visible=True),
        gr.HTML(visible=False),
        gr.DataFrame(visible=False),
        gr.HTML(msg, visible=True)
    ]
|
|
|
|
|
btn.click( |
|
fn=process_and_show_results, |
|
inputs=[input_text, model_dropdown, entity_types], |
|
outputs=[output_html, no_results_html, output_table, no_results_table] |
|
) |
|
|
|
|
|
update_ui(model_dropdown.value) |
|
|
|
return None |
|
|