import html as html_lib
import json
import re
from typing import Dict, List, Any

import gradio as gr
import pandas as pd

from tasks.ner import named_entity_recognition
from utils.ner_helpers import NER_ENTITY_TYPES, DEFAULT_SELECTED_ENTITIES, is_llm_model


# The ner_ui function and related logic moved from app.py
def ner_ui() -> None:
    """Build the Named Entity Recognition panel in the active Gradio context.

    Creates the input column (text box, examples, model dropdown, entity-type
    selector, extract button) and the output column (tagged view and table
    view), then wires the event handlers. Returns ``None``; the components are
    registered with whatever Gradio layout context is open when this is called.
    """
    # Default entity types for the multi-select
    DEFAULT_ENTITY_TYPES = list(NER_ENTITY_TYPES.keys())
    # Only allow gemini-2.0-flash for now
    DEFAULT_MODEL = "gemini-2.0-flash"

    def ner(text: str, model: str, entity_types: List[str]) -> Dict[str, Any]:
        """Extract named entities, automatically using LLM for supported models.

        Returns a dict whose ``"entities"`` value is a list of dicts with the
        keys ``entity``, ``word``, ``start``, ``end``, ``score`` and
        ``description``. Any exception raised by the backend is printed and an
        empty entity list is returned, so the UI handler never raises.
        """
        if not text or not text.strip():
            return {"text": "", "entities": []}

        try:
            use_llm = is_llm_model(model)
            # Call the enhanced NER function; the type filter is only passed
            # through when the LLM path is used.
            entities = named_entity_recognition(
                text=text,
                model=model,
                use_llm=use_llm,
                entity_types=entity_types if use_llm else None,
            )

            # Convert to the format expected by the UI
            if not isinstance(entities, list):
                entities = []

            # Non-LLM path: filter by the selected types after the fact.
            if not use_llm and entity_types:
                entities = [
                    e
                    for e in entities
                    if e.get("type", "") in entity_types
                    or e.get("entity", "") in entity_types
                ]

            return {
                "entities": [
                    {
                        "entity": e.get("type", ""),
                        "word": e.get("text", ""),
                        "start": e.get("start", 0),
                        "end": e.get("end", 0),
                        "score": e.get("confidence", 1.0),
                        "description": e.get("description", ""),
                    }
                    for e in entities
                ]
            }
        except Exception as e:
            # Deliberate best-effort boundary: report and show "no entities".
            print(f"Error in NER: {str(e)}")
            return {"entities": []}

    def render_ner_html(text, entities):
        """Render ``text`` with each located entity followed by its label.

        ``entities`` is a list of dicts carrying a label under ``type`` or
        ``entity`` and the surface text under ``text`` or ``word``; ``start`` /
        ``end`` offsets are optional. Entities that cannot be placed in
        ``text`` are skipped, and an entity that overlaps an earlier one is
        dropped. All text is HTML-escaped.
        """
        if not text or not text.strip() or not entities:
            return "\nNo named entities found in the text.\n"

        COLORS = [
            '#e3f2fd', '#e8f5e9', '#fff8e1', '#f3e5f5', '#e8eaf6',
            '#e0f7fa', '#f1f8e9', '#fce4ec', '#e8f5e9', '#f5f5f5',
            '#fafafa', '#e1f5fe', '#fff3e0', '#d7ccc8', '#f9fbe7',
            '#fbe9e7', '#ede7f6', '#e0f2f1'
        ]

        def as_offset(value):
            # None / non-integer offsets are treated as "no position given".
            return value if isinstance(value, int) else -1

        # Clean up entities and extract necessary data
        clean_entities = []
        label_colors = {}
        for ent in entities:
            label = ent.get('type') or ent.get('entity')
            entity_text = ent.get('text') or ent.get('word')
            if not label or not entity_text:
                continue  # Skip entities without a label or without text

            start = as_offset(ent.get('start', -1))
            end = as_offset(ent.get('end', -1))

            # Trust the supplied span only when it lies inside the text and
            # actually covers the entity text.
            span_in_bounds = 0 <= start < end <= len(text)
            span_matches = span_in_bounds and (
                text[start:end] == entity_text
                or text[start:end].strip().startswith(entity_text)
            )
            if not span_matches:
                # Fall back to the first occurrence of the entity text.
                found_at = text.find(entity_text)
                if found_at >= 0:
                    start, end = found_at, found_at + len(entity_text)
                elif not span_in_bounds:
                    # No usable position at all. Keeping a -1 offset here
                    # would corrupt the slicing below and drop source text.
                    continue

            # Assign color based on label
            if label not in label_colors:
                label_colors[label] = COLORS[len(label_colors) % len(COLORS)]

            clean_entities.append({
                'text': entity_text,
                'label': label,
                'color': label_colors[label],
                'start': start,
                'end': end,
            })

        # Sort entities by position (important for proper rendering)
        clean_entities.sort(key=lambda item: item['start'])

        # Resolve overlaps by keeping the earlier entity and skipping any
        # entity that starts before the previously kept one ends.
        non_overlapping = []
        for candidate in clean_entities:
            if non_overlapping and candidate['start'] < non_overlapping[-1]['end']:
                continue
            non_overlapping.append(candidate)

        # NOTE(review): the wrapper strings around each entity are empty in
        # this file, so the per-label colour computed above is never emitted.
        # Presumably highlight markup belongs here - confirm.
        parts = ["\n"]
        last_pos = 0
        for entity in non_overlapping:
            start, end = entity['start'], entity['end']

            # Add text before entity
            if start > last_pos:
                parts.append(html_lib.escape(text[last_pos:start]))

            # Add the entity, a space, then its label
            parts.append(f"{html_lib.escape(entity['text'])} ")
            parts.append(html_lib.escape(entity['label']))

            last_pos = end

        # Add any remaining text
        if last_pos < len(text):
            parts.append(html_lib.escape(text[last_pos:]))

        parts.append("\n")
        return "".join(parts)

    def update_ui(model_id: str) -> Dict:
        """Update the UI based on the selected model."""
        use_llm = is_llm_model(model_id)
        return {
            entity_types_group: gr.Group(visible=use_llm)
        }

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Text",
                lines=8,
                placeholder="Enter text to analyze for named entities..."
            )
            gr.Examples(
                examples=[
                    ["Barack Obama was born in Hawaii and became the 44th President of the United States."],
                    ["Google is headquartered in Mountain View, California."]
                ],
                inputs=[input_text],
                label="Examples"
            )
            model_dropdown = gr.Dropdown(
                [DEFAULT_MODEL],
                value=DEFAULT_MODEL,
                label="Model"
            )

            # Initial visibility is set here. Calling update_ui() at build
            # time has no effect because its return value is only applied
            # when it runs as an event handler.
            with gr.Group(visible=is_llm_model(DEFAULT_MODEL)) as entity_types_group:
                entity_types = gr.CheckboxGroup(
                    label="Entity Types to Extract",
                    choices=DEFAULT_ENTITY_TYPES,
                    value=DEFAULT_SELECTED_ENTITIES,
                    interactive=True
                )
                with gr.Row():
                    select_all_btn = gr.Button("Select All", size="sm")
                    clear_all_btn = gr.Button("Clear All", size="sm")

            btn = gr.Button("Extract Entities", variant="primary")

            # Button handlers for entity selection
            def select_all_entities():
                return gr.CheckboxGroup(value=DEFAULT_ENTITY_TYPES)

            def clear_all_entities():
                return gr.CheckboxGroup(value=[])

            select_all_btn.click(
                fn=select_all_entities,
                outputs=[entity_types]
            )
            clear_all_btn.click(
                fn=clear_all_entities,
                outputs=[entity_types]
            )

        with gr.Column(scale=3):
            # Output with tabs
            with gr.Tabs() as output_tabs:
                with gr.Tab("Tagged View", id="tagged-view-ner"):
                    no_results_html = gr.HTML(
                        "\n"
                        "Enter text and click 'Extract Entities' to get results.\n",
                        visible=True
                    )
                    output_html = gr.HTML(
                        label="NER Highlighted",
                        elem_id="ner-output-html",
                        visible=False
                    )
                    # Add CSS for NER tags (scoped to this component)
                    # NOTE(review): the stylesheet body is blank here - confirm.
                    gr.HTML(""" """)

                with gr.Tab("Table View", id="table-view-ner"):
                    no_results_table = gr.HTML(
                        "\n"
                        "Enter text and click 'Extract Entities' to get results.\n",
                        visible=True
                    )
                    output_table = gr.Dataframe(
                        label="Extracted Entities",
                        headers=["Type", "Text", "Confidence", "Description"],
                        datatype=["str", "str", "number", "str"],
                        interactive=False,
                        wrap=True,
                        visible=False
                    )

    # Update the UI when the model changes
    model_dropdown.change(
        fn=update_ui,
        inputs=[model_dropdown],
        outputs=[entity_types_group]
    )

    def process_and_show_results(text: str, model: str, entity_types: List[str]):
        """Process NER and return both the results and UI state.

        Returns four component updates, in the order of the click handler's
        outputs: output_html, no_results_html, output_table, no_results_table.
        """
        if not text or not text.strip():
            msg = "\nPlease enter some text to analyze.\n"
            return [
                gr.HTML(visible=False),        # output_html
                gr.HTML(msg, visible=True),    # no_results_html
                gr.Dataframe(visible=False),   # output_table
                gr.HTML(msg, visible=True)     # no_results_table
            ]

        # An empty selection means "extract every known type".
        if not entity_types:
            entity_types = list(NER_ENTITY_TYPES.keys())

        result = ner(text, model, entity_types)
        entities = result["entities"] if result and "entities" in result else []

        if entities:
            # DataFrame for table view
            df = pd.DataFrame(entities)

            # Order rows by position BEFORE projecting to the display columns;
            # afterwards the 'start' column no longer exists. Offsets are
            # coerced so a malformed value cannot break the sort, and a stable
            # sort keeps the original order for ties.
            if 'start' in df.columns:
                df = df.sort_values(
                    'start',
                    key=lambda col: pd.to_numeric(col, errors='coerce'),
                    kind='stable',
                )

            df = df.rename(columns={
                "entity": "Type",
                "word": "Text",
                "score": "Confidence",
                "description": "Description"
            })
            display_columns = ["Type", "Text", "Confidence", "Description"]
            df = df[[col for col in display_columns if col in df.columns]]

            rendered = render_ner_html(text, entities)
            return [
                gr.HTML(rendered, visible=True),       # output_html
                gr.HTML(visible=False),                # no_results_html
                gr.Dataframe(value=df, visible=True),  # output_table
                gr.HTML(visible=False)                 # no_results_table
            ]

        # No entities found
        msg = "\nNo named entities found in the text.\n"
        return [
            gr.HTML(msg, visible=True),    # output_html
            gr.HTML(visible=False),        # no_results_html
            gr.Dataframe(visible=False),   # output_table
            gr.HTML(msg, visible=True)     # no_results_table
        ]

    # Set up the button click handler
    btn.click(
        fn=process_and_show_results,
        inputs=[input_text, model_dropdown, entity_types],
        outputs=[output_html, no_results_html, output_table, no_results_table]
    )

    return None