# app.py by karthik55 — last commit e8c8b7d ("updated app.py", verified), 1.34 kB.
# (Header reconstructed from the hosting page; the file contents follow.)
from pathlib import Path

import gradio as gr
import joblib
# Load every model eagerly at import time, so a missing or corrupt artifact
# fails on startup instead of on the first prediction.
# Paths are resolved relative to this file rather than the current working
# directory, so the app also works when launched from another folder.
# NOTE: joblib.load unpickles its input — only load trusted model files.
_MODEL_DIR = Path(__file__).resolve().parent / "models"

# Dropdown label -> fitted model. Insertion order is the order shown in the UI.
models = {
    "Logistic Regression": joblib.load(_MODEL_DIR / "best_model.joblib"),
    "Random Forest": joblib.load(_MODEL_DIR / "random_forest_model.joblib"),
    "SVM (Linear)": joblib.load(_MODEL_DIR / "svm_model_linear.joblib"),
    "SVM (Polynomial)": joblib.load(_MODEL_DIR / "svm_model_polynomial.joblib"),
    "SVM (RBF)": joblib.load(_MODEL_DIR / "svm_model_rbf.joblib"),
    "KNN": joblib.load(_MODEL_DIR / "trained_knn_model.joblib"),
}
# Prediction callback wired to the Gradio interface.
def predict(review, model_name):
    """Classify *review* with the model registered under *model_name*.

    Args:
        review: Free-text review typed into the UI. ``None`` is treated as "".
        model_name: Key into the module-level ``models`` dict (the dropdown value).

    Returns:
        A JSON-serialisable dict. ``"Predicted Class"`` always holds the
        stringified prediction. ``"Class Probabilities"`` is included only when
        the model exposes ``predict_proba``; it maps ``"Class <label>"`` to a
        plain ``float`` for every class the model knows about.

    Raises:
        KeyError: if *model_name* is not a key of ``models``.
    """
    model = models[model_name]
    # Estimators/pipelines expect an iterable of documents, not a bare string.
    batch = [review if review is not None else ""]
    result = {"Predicted Class": str(model.predict(batch)[0])}

    # SVC fitted without probability=True (and LinearSVC) do not provide
    # predict_proba; hasattr() is False for them instead of the call crashing.
    if hasattr(model, "predict_proba"):
        probabilities = model.predict_proba(batch)[0]
        # Label columns with the model's own class order when available,
        # so labels stay correct for non-0/1 or multi-class targets.
        labels = getattr(model, "classes_", range(len(probabilities)))
        result["Class Probabilities"] = {
            # float(): numpy scalars such as float32 are not JSON-serialisable.
            f"Class {label}": float(p)
            for label, p in zip(labels, probabilities)
        }
    return result
# Create the Gradio interface: a review text box and a model selector feed
# predict(), and its dict result is rendered as JSON.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Review Comment"),
        # Pre-select the first model explicitly. Some Gradio releases start a
        # Dropdown with no selection, and predict() would then receive None
        # and fail the models[...] lookup.
        gr.Dropdown(
            choices=list(models),
            value=next(iter(models)),
            label="Model",
        ),
    ],
    outputs=gr.JSON(label="Prediction Results"),
    title="Text Classification Models",
    description="Choose a model and provide a review to see the predicted sentiment class.",
)
def main():
    """Start the Gradio web server for the module-level ``interface``."""
    interface.launch()


# Launch only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()