import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import pandas as pd
from spacy import displacy


###########################
# Utility Function for Cleanup
###########################
def clean_and_group_entities(ner_results, min_score=0.40):
    """
    Combines tokens for the same entity and filters out entities below the score threshold.
    """
    grouped_entities = []
    current_entity = None

    for result in ner_results:
        # Skip entities with a score below threshold
        if result["score"] < min_score:
            if current_entity:
                # If the current entity meets threshold, add it
                if current_entity["score"] >= min_score:
                    grouped_entities.append(current_entity)
                current_entity = None
            continue

        # Remove any subword prefix "##"
        word = result["word"].replace("##", "")

        # Check if this result continues the current entity
        if (current_entity
                and result["entity_group"] == current_entity["entity_group"]
                and result["start"] == current_entity["end"]):
            # Update the current entity
            current_entity["word"] += word
            current_entity["end"] = result["end"]
            # Keep the minimum score as the "weakest link"
            current_entity["score"] = min(current_entity["score"], result["score"])

            # If combined score now drops below threshold, discard the entity
            if current_entity["score"] < min_score:
                current_entity = None
        else:
            # Finalize the previous entity if valid
            if current_entity and current_entity["score"] >= min_score:
                grouped_entities.append(current_entity)

            # Start a new entity
            current_entity = {
                "entity_group": result["entity_group"],
                "word": word,
                "start": result["start"],
                "end": result["end"],
                "score": result["score"]
            }

    # Add the last entity if it meets threshold
    if current_entity and current_entity["score"] >= min_score:
        grouped_entities.append(current_entity)

    return grouped_entities
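
# A hedged, illustrative sketch of what clean_and_group_entities does. The input
# records mimic the shape produced by a transformers pipeline with
# aggregation_strategy="simple"; the words, offsets, and scores are made up for
# illustration rather than taken from a real model run.
#
#   raw = [
#       {"entity_group": "COMPONENT", "word": "LM", "start": 18, "end": 20, "score": 0.98},
#       {"entity_group": "COMPONENT", "word": "##358", "start": 20, "end": 23, "score": 0.95},
#   ]
#   clean_and_group_entities(raw)
#   # -> [{"entity_group": "COMPONENT", "word": "LM358", "start": 18, "end": 23,
#   #      "score": 0.95}]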


###########################
# Constants and Setup
###########################
MODELS = {
    "ModernBERT Base": "disham993/electrical-ner-modernbert-base",
    "BERT Base": "disham993/electrical-ner-bert-base",
    "ModernBERT Large": "disham993/electrical-ner-modernbert-large",
    "BERT Large": "disham993/electrical-ner-bert-large",
    "DistilBERT Base": "disham993/electrical-ner-distilbert-base"
}

ENTITY_COLORS = {
    "COMPONENT": "#FFB6C1",
    "DESIGN_PARAM": "#98FB98",
    "MATERIAL": "#DDA0DD",
    "EQUIPMENT": "#87CEEB",
    "TECHNOLOGY": "#F0E68C",
    "SOFTWARE": "#FFD700",
    "STANDARD": "#FFA07A",
    "VENDOR": "#E6E6FA",
    "PRODUCT": "#98FF98"
}

EXAMPLES = [
    "Texas Instruments LM358 op-amp requires dual power supply.",
    "Using a Multimeter, the technician measured the 10 kΩ resistance of a Copper wire in the circuit.",
    "To improve the reliability of the circuit, the engineer tested a 10k Ohm resistor with a multimeter from Fluke.",
    "During the circuit design, we measured the current flow using a Fluke multimeter to ensure it was within the 10A specification."
]


@st.cache_resource
def load_model(model_name):
    """
    Load and return a token classification pipeline with an aggregation strategy of 'simple'.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        return pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple"  # Merge subword tokens into whole-word entities
        )
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None
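
# Minimal usage sketch for load_model (assumes the Hub model id below resolves
# and its weights can be downloaded in this environment):
#
#   ner = load_model("disham993/electrical-ner-bert-base")
#   if ner is not None:
#       print(ner("Texas Instruments LM358 op-amp requires dual power supply."))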


def get_base_entity_type(entity_label):
    """
    Strips off 'B-' or 'I-' prefix if present.
    """
    if entity_label.startswith("B-") or entity_label.startswith("I-"):
        return entity_label[2:]
    return entity_label


def create_displacy_data(text, entities):
    """
    Create data for spaCy's displacy visualizer.
    """
    ents = []
    for entity in entities:
        base_type = get_base_entity_type(entity["entity_group"])
        ents.append({
            "start": entity["start"],
            "end": entity["end"],
            "label": base_type
        })

    colors = {entity_type: color for entity_type, color in ENTITY_COLORS.items()}
    options = {"ents": list(ENTITY_COLORS.keys()), "colors": colors}
    doc_data = {
        "text": text,
        "ents": ents,
        "title": None
    }

    # Render with manual mode = True
    html_content = displacy.render(doc_data, style="ent", options=options, manual=True)
    return html_content


###########################
# Main Streamlit App
###########################
def main():
    st.set_page_config(page_title="Electrical Engineering NER", page_icon="⚡", layout="wide")

    st.title("⚡ Electrical Engineering Named Entity Recognition")
    st.markdown("""
This application identifies technical entities in electrical engineering text using fine-tuned transformer models.
It can recognize components, parameters, materials, equipment, and more.
""")

    # Sidebar - Model Selection
    st.sidebar.title("Model Configuration")
    selected_model_name = st.sidebar.selectbox(
        "Select Model",
        list(MODELS.keys()),
        help="Choose which model to use for entity recognition"
    )

    with st.sidebar.expander("Model Details"):
        st.write(f"**Model Path:** {MODELS[selected_model_name]}")
        st.write("This model is fine-tuned specifically for the electrical engineering domain.")

    # Confidence threshold
    score_threshold = st.sidebar.slider(
        "Minimum confidence threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.40,
        step=0.01
    )

    # Load selected model
    model = load_model(MODELS[selected_model_name])
    if model is None:
        st.error("Failed to load model. Please try selecting a different model.")
        return

    # Create a form to collect user text and an Analyze button
    with st.form(key="text_form"):
        st.subheader("Try an example or enter your own text")
        example_idx = st.selectbox(
            "Select an example:",
            range(len(EXAMPLES)),
            format_func=lambda x: EXAMPLES[x][:100] + ("..." if len(EXAMPLES[x]) > 100 else "")
        )
        text_input = st.text_area(
            "Enter text for analysis:",
            value=EXAMPLES[example_idx],
            height=100
        )
        # This button triggers form submission
        submit_button = st.form_submit_button(label="Analyze")

    # Only run inference after the user clicks "Analyze"
    if submit_button and text_input.strip():
        with st.spinner("Analyzing text..."):
            try:
                raw_entities = model(text_input)
                cleaned_entities = clean_and_group_entities(raw_entities, min_score=score_threshold)

                # Visualization
                st.subheader("Identified Entities")
                html_content = create_displacy_data(text_input, cleaned_entities)
                st.markdown(html_content, unsafe_allow_html=True)

                # Create DataFrame
                if cleaned_entities:
                    df = pd.DataFrame(cleaned_entities).round({"score": 3})
                    df = df.rename(columns={
                        "entity_group": "Entity Type",
                        "word": "Text",
                        "score": "Confidence",
                        "start": "Start",
                        "end": "End"
                    })

                    st.subheader("Entity Details")
                    st.dataframe(df)

                    st.subheader("Entity Distribution")
                    entity_counts = df["Entity Type"].value_counts()
                    st.bar_chart(entity_counts)
                else:
                    st.info("No entities detected in the text (or all below threshold).")
            except Exception as e:
                st.error(f"Error processing text: {str(e)}")

    # Entity type legend
    st.sidebar.title("Entity Types")
    st.sidebar.markdown("""
- 🔧 **COMPONENT**: Circuit elements
- 📊 **DESIGN_PARAM**: Values, measurements
- 🧱 **MATERIAL**: Physical materials
- 🔌 **EQUIPMENT**: Testing equipment
- 💻 **TECHNOLOGY**: Tech implementations
- 💾 **SOFTWARE**: Software tools
- 📜 **STANDARD**: Technical standards
- 🏢 **VENDOR**: Manufacturers
- 📦 **PRODUCT**: Specific products
""")


if __name__ == "__main__":
    main()
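
# To run the app locally (assuming this file is saved as app.py and streamlit,
# transformers, torch, pandas, and spacy are installed):
#
#   streamlit run app.py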