|
from transformers import pipeline |
|
import streamlit as st |
|
import streamlit.components.v1 as components |
|
|
|
|
|
@st.cache_resource
def _load_pipelines():
    """Build both text-classification pipelines and return them as a pair.

    Streamlit re-executes this script on every widget interaction; caching
    the loader keeps the two models from being re-instantiated on each rerun.
    """
    first = pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health")
    second = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health")
    return first, second


# The module-level names are kept because ensemble_predict reads them as globals.
pipe_1, pipe_2 = _load_pipelines()
|
|
|
|
|
def ensemble_predict(text):
    """Combine the predictions of ``pipe_1`` and ``pipe_2`` for ``text``.

    Each pipeline is called on ``text`` and is expected to return an iterable
    of dicts carrying ``'label'`` and ``'score'`` keys. Every returned score is
    divided by the number of pipelines and accumulated per label, so a label
    reported by both models ends up with the mean of the two scores, while a
    label reported by only one model keeps half of its score.

    Args:
        text: The input handed to both pipelines.

    Returns:
        A ``(predicted_label, confidence)`` tuple: the label with the highest
        accumulated score, and that accumulated score.
    """
    # truncation=True keeps inputs longer than the model's maximum sequence
    # length from failing inside the pipeline; the UI explicitly invites long text.
    all_results = [pipe(text, truncation=True) for pipe in (pipe_1, pipe_2)]
    n_models = len(all_results)

    ensemble_scores = {}
    for results in all_results:
        for result in results:
            label = result['label']
            # dict.get makes a separate zero-initialisation pass unnecessary.
            ensemble_scores[label] = ensemble_scores.get(label, 0) + result['score'] / n_models

    predicted_label = max(ensemble_scores, key=ensemble_scores.get)
    return predicted_label, ensemble_scores[predicted_label]
|
|
|
|
|
# --- Streamlit page: title, input box, prediction trigger and disclaimer ---
st.title('Mental Illness Prediction')

sentence = st.text_area("Enter the long sentence to predict your mental illness state:")

if st.button('Predict'):
    if not sentence.strip():
        # Without this guard an empty box would still be sent to both models
        # and a label/confidence would be displayed for no input at all.
        st.warning("Please enter some text before requesting a prediction.")
    else:
        predicted_label, confidence = ensemble_predict(sentence)

        # Built line by line so the rendered HTML does not depend on the
        # indentation of a triple-quoted block.
        results_html = "\n".join([
            "<h2 style='text-align: center; color: #1E90FF;'>Prediction Results</h2>",
            f"<p style='font-size: 24px; font-weight: bold;'>Result: <span style='color: #1E90FF;'>{predicted_label}</span></p>",
            f"<p style='font-size: 24px; font-weight: bold;'>Confidence: <span style='color: #1E90FF;'>{confidence:.2f}</span></p>",
        ])
        st.markdown(results_html, unsafe_allow_html=True)

# NOTE(review): the disclaimer is rendered unconditionally so it is visible
# before and after a prediction - confirm this matches the intended layout.
st.info("Remember: This prediction is not a diagnosis. Our method is designed to support, not replace, mental health professionals. The model's predictions should be used as a reference, and the final diagnosis should be made by a qualified professional to avoid potential biases and inaccuracies.")
|
|