|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline |
|
import pandas as pd |
|
from spacy import displacy |
|
|
|
|
|
|
|
|
|
def clean_and_group_entities(ner_results, min_score=0.40):
    """
    Combine adjacent fragments of the same entity and drop low-confidence ones.

    Each item of ``ner_results`` is read through the keys ``entity_group``,
    ``word``, ``start``, ``end`` and ``score``. Two consecutive results are
    merged when they share an ``entity_group`` and the second starts exactly
    where the first ends; the merged entity keeps the lowest score of its
    parts. WordPiece ``##`` markers are stripped from every word.

    Args:
        ner_results: Iterable of entity dicts, in text order.
        min_score: Results scoring below this (or with a NaN score) are
            discarded; a discarded result also ends the entity being built.

    Returns:
        A new list of entity dicts; the input dicts are not modified.
    """
    grouped_entities = []
    current_entity = None

    for result in ner_results:
        # ``not >=`` (rather than ``<``) also rejects NaN scores, which would
        # otherwise slip past the filter and be merged into neighbours.
        if not result["score"] >= min_score:
            # A low-confidence fragment breaks the current run.
            if current_entity is not None:
                grouped_entities.append(current_entity)
                current_entity = None
            continue

        word = result["word"].replace("##", "")

        if (current_entity is not None
                and result["entity_group"] == current_entity["entity_group"]
                and result["start"] == current_entity["end"]):
            current_entity["word"] += word
            current_entity["end"] = result["end"]
            # Both scores passed the filter, so the minimum stays >= min_score.
            current_entity["score"] = min(current_entity["score"], result["score"])
        else:
            if current_entity is not None:
                grouped_entities.append(current_entity)
            current_entity = {
                "entity_group": result["entity_group"],
                "word": word,
                "start": result["start"],
                "end": result["end"],
                "score": result["score"],
            }

    if current_entity is not None:
        grouped_entities.append(current_entity)

    return grouped_entities
|
|
|
|
|
|
|
|
|
# Sidebar display name -> model repository id handed to ``load_model``.
MODELS = {
    "ModernBERT Base": "disham993/electrical-ner-modernbert-base",
    "BERT Base": "disham993/electrical-ner-bert-base",
    "ModernBERT Large": "disham993/electrical-ner-modernbert-large",
    "BERT Large": "disham993/electrical-ner-bert-large",
    "DistilBERT Base": "disham993/electrical-ner-distilbert-base",
}

# Highlight colour per entity label for the displacy rendering. Every label
# gets a visually distinct colour (PRODUCT previously used a pale green that
# differed from DESIGN_PARAM's by a single channel step of 4/255).
ENTITY_COLORS = {
    "COMPONENT": "#FFB6C1",
    "DESIGN_PARAM": "#98FB98",
    "MATERIAL": "#DDA0DD",
    "EQUIPMENT": "#87CEEB",
    "TECHNOLOGY": "#F0E68C",
    "SOFTWARE": "#FFD700",
    "STANDARD": "#FFA07A",
    "VENDOR": "#E6E6FA",
    "PRODUCT": "#D2B48C",
}

# Sample sentences offered in the example picker.
EXAMPLES = [
    "Texas Instruments LM358 op-amp requires dual power supply.",
    "Using a Multimeter, the technician measured the 10 kΩ resistance of a Copper wire in the circuit.",
    "To improve the reliability of the circuit, the engineer tested a 10k Ohm resistor with a multimeter from Fluke.",
    "During the circuit design, we measured the current flow using a Fluke multimeter to ensure it was within the 10A specification.",
]
|
|
|
@st.cache_resource
def _build_pipeline(model_name):
    """
    Build (and cache) the token-classification pipeline for ``model_name``.

    Errors are deliberately not caught here: ``st.cache_resource`` only stores
    values that are actually returned, so a failed load is retried on the next
    rerun instead of a ``None`` result being cached.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )


def load_model(model_name):
    """
    Load and return a token classification pipeline with an aggregation strategy of 'simple'.

    Args:
        model_name: Model id or path passed to ``from_pretrained``.

    Returns:
        The pipeline on success, or ``None`` if loading raised; in that case
        an error message is written to the page.
    """
    try:
        return _build_pipeline(model_name)
    except Exception as e:
        # Not cached, so the error is reported (and loading retried) per rerun.
        st.error(f"Error loading model: {e}")
        return None
|
|
|
def get_base_entity_type(entity_label):
    """
    Return the entity label with a leading BIO tag removed.

    A ``"B-"`` or ``"I-"`` prefix is dropped; any other label comes back
    unchanged.
    """
    tagged = entity_label.startswith(("B-", "I-"))
    return entity_label[2:] if tagged else entity_label
|
|
|
def create_displacy_data(text, entities):
    """
    Render ``entities`` over ``text`` as displacy entity-highlight HTML.

    Args:
        text: The analysed text; entity ``start``/``end`` offsets index into it.
        entities: Dicts with at least ``entity_group``, ``start`` and ``end``.
            BIO prefixes on ``entity_group`` are stripped before rendering.

    Returns:
        The HTML string from ``displacy.render`` in manual mode. The ``ents``
        option limits highlighting to the labels in ``ENTITY_COLORS``.
    """
    # displacy walks the spans left to right, so order them by offset.
    ents = sorted(
        (
            {
                "start": entity["start"],
                "end": entity["end"],
                "label": get_base_entity_type(entity["entity_group"]),
            }
            for entity in entities
        ),
        key=lambda ent: (ent["start"], ent["end"]),
    )

    options = {"ents": list(ENTITY_COLORS), "colors": dict(ENTITY_COLORS)}
    doc_data = {"text": text, "ents": ents, "title": None}

    # jupyter=False: always return the markup rather than auto-detecting a
    # notebook environment and displaying it there instead.
    return displacy.render(
        doc_data, style="ent", options=options, manual=True, jupyter=False
    )
|
|
|
|
|
|
|
|
|
def _example_preview(index, limit=100):
    """Return the label shown in the example picker, truncated to ``limit`` chars."""
    example = EXAMPLES[index]
    # Only add an ellipsis when something was actually cut off.
    return example if len(example) <= limit else example[:limit] + "..."


def _render_results(model, text_input, score_threshold):
    """Run ``model`` on ``text_input`` and render the highlight, table and chart."""
    with st.spinner("Analyzing text..."):
        try:
            raw_entities = model(text_input)
            cleaned_entities = clean_and_group_entities(raw_entities, min_score=score_threshold)

            st.subheader("Identified Entities")
            html_content = create_displacy_data(text_input, cleaned_entities)
            st.markdown(html_content, unsafe_allow_html=True)

            if cleaned_entities:
                df = pd.DataFrame(cleaned_entities).round({"score": 3})
                df = df.rename(columns={
                    "entity_group": "Entity Type",
                    "word": "Text",
                    "score": "Confidence",
                    "start": "Start",
                    "end": "End",
                })
                # Use the same base label as the highlighted view above.
                df["Entity Type"] = df["Entity Type"].map(get_base_entity_type)

                st.subheader("Entity Details")
                st.dataframe(df)

                st.subheader("Entity Distribution")
                entity_counts = df["Entity Type"].value_counts()
                st.bar_chart(entity_counts)
            else:
                st.info("No entities detected in the text (or all below threshold).")

        except Exception as e:
            st.error(f"Error processing text: {str(e)}")


def _render_entity_legend():
    """Write the list of supported entity types to the sidebar."""
    st.sidebar.title("Entity Types")
    st.sidebar.markdown("""
    - 🔧 **COMPONENT**: Circuit elements
    - 📏 **DESIGN_PARAM**: Values, measurements
    - 🧱 **MATERIAL**: Physical materials
    - 🔬 **EQUIPMENT**: Testing equipment
    - 💻 **TECHNOLOGY**: Tech implementations
    - 💾 **SOFTWARE**: Software tools
    - 📋 **STANDARD**: Technical standards
    - 🏢 **VENDOR**: Manufacturers
    - 📦 **PRODUCT**: Specific products
    """)


def main():
    """Build the Streamlit page: model settings, text input, and NER results."""
    st.set_page_config(page_title="Electrical Engineering NER", page_icon="⚡", layout="wide")

    st.title("⚡ Electrical Engineering Named Entity Recognition")
    st.markdown("""
    This application identifies technical entities in electrical engineering text using a fine-tuned BERT model.
    It can recognize components, parameters, materials, equipment, and more.
    """)

    # Sidebar: model choice and confidence threshold.
    st.sidebar.title("Model Configuration")
    selected_model_name = st.sidebar.selectbox(
        "Select Model",
        list(MODELS.keys()),
        help="Choose which model to use for entity recognition",
    )

    with st.sidebar.expander("Model Details"):
        st.write(f"**Model Path:** {MODELS[selected_model_name]}")
        st.write("This model is fine-tuned specifically for the electrical engineering domain.")

    score_threshold = st.sidebar.slider(
        'Minimum confidence threshold',
        min_value=0.0,
        max_value=1.0,
        value=0.40,
        step=0.01,
    )

    model = load_model(MODELS[selected_model_name])
    if model is None:
        st.error("Failed to load model. Please try selecting a different model.")
        return

    st.subheader("Try an example or enter your own text")
    # Kept outside the form: widgets inside a form only report changes on
    # submit, so choosing an example would not refresh the text area below
    # until "Analyze" was pressed.
    example_idx = st.selectbox(
        "Select an example:",
        range(len(EXAMPLES)),
        format_func=_example_preview,
    )

    with st.form(key="text_form"):
        text_input = st.text_area(
            "Enter text for analysis:",
            value=EXAMPLES[example_idx],
            height=100,
        )
        submit_button = st.form_submit_button(label="Analyze")

    if submit_button:
        if text_input.strip():
            _render_results(model, text_input, score_threshold)
        else:
            st.warning("Please enter some text to analyze.")

    _render_entity_legend()


if __name__ == "__main__":
    main()