from pathlib import Path

import gradio as gr
from joblib import load
# Load the trained model and its preprocessing artifacts once, at import time,
# so every request reuses the same objects.
# Paths are resolved relative to this script instead of the current working
# directory, so the app finds its artifacts no matter where it is launched from.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code; only
# ship/load artifacts from a trusted source.
_ARTIFACT_DIR = Path(__file__).resolve().parent
# Presumably a fitted scikit-learn classifier exposing .predict — confirm.
model = load(_ARTIFACT_DIR / "logistic_model.joblib")
# Presumably the fitted TF-IDF vectorizer used at training time — confirm.
tfidf_vectorizer = load(_ARTIFACT_DIR / "tfidf_vectorizer.joblib")
# Maps model output back to label names via .inverse_transform.
mlb = load(_ARTIFACT_DIR / "label_binarizer.joblib")
# Define a function to classify commit messages
def classify_commit(message):
    """Predict the category labels for a single commit message.

    Args:
        message: Commit message text from the input textbox. ``None`` or a
            blank/whitespace-only string is treated as having nothing to
            classify.

    Returns:
        The predicted labels joined with ", ", or "No labels" when the message
        is blank/missing or the model predicts no labels.
    """
    # A blank or missing message has no content to classify; answer directly
    # instead of feeding it to the vectorizer/model (None would not even be
    # processable as a document).
    if message is None or not message.strip():
        return "No labels"
    # transform() takes an iterable of documents, so wrap the single message.
    features = tfidf_vectorizer.transform([message])
    prediction = model.predict(features)
    # inverse_transform returns one entry per input row; there is only one row.
    labels = mlb.inverse_transform(prediction)[0]
    # str() each label so non-string class labels (e.g. ints) can be joined.
    return ", ".join(map(str, labels)) if labels else "No labels"
# Build the Gradio UI: one text box for the commit message going in, one text
# box for the predicted labels coming out, handled by classify_commit.
_commit_input = gr.Textbox(label="Enter Commit Message")
_labels_output = gr.Textbox(label="Predicted Labels")

demo = gr.Interface(
    fn=classify_commit,
    inputs=_commit_input,
    outputs=_labels_output,
    title="Commit Message Classifier",
    description="Enter a commit message to classify it into predefined categories.",
)

# Launch the Gradio app.
demo.launch()