File size: 1,674 Bytes
925e416
 
 
cbd3f4c
0f34ca3
 
 
 
 
 
438fdb3
925e416
2aa7ddb
21aa3e9
2aa7ddb
438fdb3
c8d5deb
438fdb3
5a7c4c9
0f34ca3
438fdb3
0f34ca3
438fdb3
0f34ca3
 
 
 
 
 
 
 
 
 
 
48b6db6
 
0f34ca3
925e416
87806b7
c8d5deb
87806b7
2aa7ddb
925e416
 
733f064
7862c87
 
 
8b60d18
925e416
c71ee2b
 
733f064
925e416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Imports
import gradio as gr
from sklearn.linear_model import LogisticRegression
import pickle5 as pickle
import re
import string
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer

# Path to the pickled classifier (the interface below describes it as a
# Logistic Regression model).
lr_filename = 'lr_021223.pkl'

# Load model from pickle file. The `with` block guarantees the file handle is
# closed once loading finishes.
# SECURITY: unpickling can execute arbitrary code embedded in the file — only
# ever load model files from a trusted source.
with open(lr_filename, 'rb') as _model_file:
    model = pickle.load(_model_file)


# English stopwords from NLTK, stored as a frozenset so the per-token
# membership test done while cleaning input text is O(1) rather than a linear
# scan of a list.
stop = frozenset(stopwords.words('english'))
def process_text(text, stop_words=None):
    """Clean raw input text before vectorization.

    Steps, in order: drop whitespace-separated tokens that appear verbatim in
    the stopword collection, lowercase, replace every ASCII punctuation
    character with a space, and collapse whitespace runs to single spaces.

    Args:
        text: Raw input string.
        stop_words: Optional iterable of tokens to drop. Defaults to the
            module-level NLTK English stopwords ``stop``.

    Returns:
        The cleaned, single-space-separated string ("" for empty input).
    """
    excluded = frozenset(stop if stop_words is None else stop_words)
    # Stopword matching is case-sensitive and happens before lowercasing, so
    # capitalized stopwords such as "The" survive.
    # NOTE(review): kept in this order on the assumption that the model was
    # trained with identical preprocessing — confirm before lowercasing first.
    kept = [word for word in text.split() if word not in excluded]
    # Join tokens directly instead of going through str(list): repr() escapes
    # non-printable characters (e.g. "\u200b"), and after punctuation removal
    # those escapes leak into the text as junk like "u200b".
    lowered = " ".join(kept).lower()
    no_punct = re.sub(f"[{re.escape(string.punctuation)}]", " ", lowered)
    return " ".join(no_punct.split())

# Vectorizer mapping cleaned text to token-count features.
# NOTE(review): `transform` only works once this vectorizer has been fitted on
# the training corpus, because the model's feature columns come from that fit.
# TODO: replace this with the training-time vectorizer (e.g. unpickled next to
# the model); an unfitted CountVectorizer raises NotFittedError on transform.
vectorizer = CountVectorizer()


def vectorize_text(text):
    """Convert raw text into the feature matrix passed to the model.

    Args:
        text: Raw input string; cleaned with ``process_text`` first.

    Returns:
        A one-row document-term matrix produced by ``vectorizer.transform``.

    Raises:
        sklearn.exceptions.NotFittedError: if ``vectorizer`` was never fitted.
    """
    cleaned = process_text(text)
    # transform, not fit_transform: refitting on a single document builds a
    # brand-new vocabulary whose columns do not correspond to the features the
    # model was trained on (wrong feature count, or meaningless features).
    return vectorizer.transform([cleaned])

def predict(text):
    """Classify ``text`` and return the predicted label as a string.

    Args:
        text: Raw input string from the UI.

    Returns:
        The single predicted class label, converted with ``str``.
        NOTE(review): the UI label describes codes Other: 1, Healthcare: 2,
        Technology: 3 — confirm these match ``model.classes_``.
    """
    features = vectorize_text(text)
    # model.predict returns one entry per input row and we pass exactly one
    # row; unwrap it so the Textbox shows "2" instead of the array repr "[2]".
    label = model.predict(features)[0]
    return str(label)


# Build the Gradio UI: a 10-line text input wired through `predict` to a
# 2-line text output. Flagging is turned off.
_input_box = gr.Textbox(
    lines=10,
    placeholder='Input text here...',
    label="Input Text",
)
_output_box = gr.Textbox(
    label="Predicted Label: Other: 1, Healthcare: 2, Technology: 3",
    lines=2,
    placeholder='Predicted label will appear here...',
)
demo = gr.Interface(
    fn=predict,
    inputs=_input_box,
    outputs=_output_box,
    title="Text Classification Demo",
    description="This is a demo of a text classification model using Logistic Regression.",
    allow_flagging='never',
)

# Start the web app immediately when this module runs.
demo.launch()