Timmyafolami commited on
Commit
6c17133
·
verified ·
1 Parent(s): 545a39d

Upload 35 files

Browse files
.dockerignore ADDED
File without changes
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/cleaned_data.csv filter=lfs diff=lfs merge=lfs -text
37
+ data/Combined_Data.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ samh_venv
2
+ .env
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Log straight to stdout/stderr so `docker logs` shows output immediately
ENV PYTHONUNBUFFERED=1

# Set the working directory
WORKDIR /app

# Copy and install requirements first so this layer is cached across code changes
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Copy the current directory contents into the container at /app
COPY . /app

# Bake the NLTK corpora into the image so the app does not download them at runtime
RUN python -c "import nltk; nltk.download('stopwords'); nltk.download('wordnet')"

# Make port 8000 available to the world outside this container
EXPOSE 8000

# Run the entrypoint script
CMD ["sh", "./entrypoint.sh"]
README.md CHANGED
@@ -1,11 +1,42 @@
1
- ---
2
- title: SAMH
3
- emoji: 📚
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sentiment Analysis API
2
+ ![Screenshot of the Sentiment Analysis API](image.png)
3
+
4
+ This project provides a sentiment analysis API using FastAPI and a machine learning model trained on textual data.
5
+
6
+ ## Features
7
+
8
+ - Data ingestion and preprocessing
9
+ - Model training and saving
10
+ - FastAPI application for serving predictions
11
+ - Dockerized for easy deployment
12
+
13
+ ## Setup
14
+
15
+ ### Prerequisites
16
+
17
+ - Docker installed on your system
18
+
19
+ ### Build and Run
20
+
21
+ 1. Build the Docker image:
22
+
23
+ ```bash
24
+ docker build -t sentiment-analysis-api .
25
+ ```
26
+
27
+ 2. Run the Docker container:
28
+
29
+ ```bash
30
+ docker run -p 8000:8000 sentiment-analysis-api
31
+ ```
32
+
33
+ 3. Access the API:
34
+
35
+ - Home: [http://localhost:8000](http://localhost:8000)
36
+ - Health Check: [http://localhost:8000/health](http://localhost:8000/health)
37
+ - Predict Sentiment: Use a POST request to [http://localhost:8000/predict_sentiment](http://localhost:8000/predict_sentiment) with a JSON body.
38
+
39
+ ## Example cURL Command
40
+
41
+ ```bash
42
+ curl -X POST "http://localhost:8000/predict_sentiment" -H "Content-Type: application/json" -d '{"text": "I love programming in Python."}'
data/Combined_Data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0700996d814af3ec77ef31870b68c6cdf991217eb76e259c7196df7f2e0e27ba
3
+ size 31469552
data/cleaned_data.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8b9a23caf50bd71eb2e02f6b49447f247791e66b0936f0cb47e479736b0c17e
3
+ size 49456310
data_pipeline/__init__.py ADDED
File without changes
data_pipeline/data_ingestion.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import requests
4
+
5
+ # Add the root directory to sys.path
6
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from logging_config.logger_config import get_logger
9
+
10
+ # Get the logger
11
+ logger = get_logger(__name__)
12
+
13
def download_data(url, save_path, timeout=30):
    """Download a file from *url* and write it to *save_path*.

    Args:
        url: HTTP(S) location of the dataset.
        save_path: destination file path; parent directories are created.
        timeout: seconds to wait for the server before giving up.

    Returns:
        True when the file was downloaded and written, False otherwise.
    """
    # Ensure the save directory exists (skip when save_path has no directory part).
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Send a GET request to the URL; stream so the multi-MB CSV is not
    # held fully in memory, and bound the wait with a timeout.
    logger.info(f"Sending GET request to {url}")
    try:
        response = requests.get(url, timeout=timeout, stream=True)
    except requests.RequestException as exc:
        logger.error(f"Request to {url} failed: {exc}")
        return False

    # Check if the request was successful
    if response.status_code == 200:
        # Write the content to the specified file in chunks.
        with open(save_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=65536):
                file.write(chunk)
        logger.info(f"Data downloaded successfully and saved to {save_path}")
        return True

    logger.error(f"Failed to download data. Status code: {response.status_code}")
    return False
29
+
30
if __name__ == "__main__":
    # Source location of the raw combined dataset on GitHub.
    dataset_url = (
        "https://raw.githubusercontent.com/timothyafolami/"
        "SAMH-Sentiment-Analysis-For-Mental-Health/master/data/Combined_Data.csv"
    )

    # Local destination for the raw CSV.
    destination = os.path.join("./data", "Combined_Data.csv")

    # Fetch the dataset into ./data.
    download_data(dataset_url, destination)
data_pipeline/data_preprocessor.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import string
5
+ import pandas as pd
6
+ import nltk
7
+ from nltk.corpus import stopwords
8
+ from nltk.stem import WordNetLemmatizer
9
+
10
+ # Add the root directory to sys.path
11
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
12
+
13
+ from logging_config.logger_config import get_logger
14
+
15
+
16
+ # Download necessary NLTK data files
17
+ nltk.download('stopwords')
18
+ nltk.download('wordnet')
19
+
20
+ # Get the logger
21
+ logger = get_logger(__name__)
22
+
23
+ # Custom Preprocessor Class
24
# Custom Preprocessor Class
class TextPreprocessor:
    """Cleans raw statements for sentiment-analysis modelling.

    Pipeline: lowercase -> strip punctuation -> strip digits ->
    whitespace tokenize -> drop English stopwords -> lemmatize -> rejoin.
    """

    def __init__(self):
        # Build the stopword set once; membership tests are O(1).
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        # Hoisted translation table: removes the same characters as the
        # old per-call regex over string.punctuation, in one C-level pass.
        self._punct_table = str.maketrans('', '', string.punctuation)
        logger.info("TextPreprocessor initialized.")

    def preprocess_text(self, text):
        """Return a cleaned, lemmatized version of *text*.

        Args:
            text: raw statement string.

        Returns:
            Space-joined cleaned tokens (may be empty for all-noise input).
        """
        # Lowercase the text
        text = text.lower()

        # Remove punctuation (same character set as string.punctuation)
        text = text.translate(self._punct_table)

        # Remove numbers
        text = re.sub(r'\d+', '', text)

        # Tokenize on whitespace
        words = text.split()

        # Remove stopwords and apply lemmatization
        words = [self.lemmatizer.lemmatize(word) for word in words if word not in self.stop_words]

        # Join words back into a single string
        return ' '.join(words)
57
+
58
def load_and_preprocess_data(file_path):
    """Load the raw CSV, clean its 'statement' column, save cleaned_data.csv.

    Args:
        file_path: path to the raw combined dataset CSV.

    Returns:
        Path of the cleaned CSV on success, or None when the required
        'statement' column is missing (callers can now distinguish
        success from failure).
    """
    # Load the data
    logger.info(f"Loading data from {file_path}")
    df = pd.read_csv(file_path)

    # Drop rows with any missing values before cleaning.
    logger.info("Dropping missing values")
    df.dropna(inplace=True)

    # Guard: without 'statement' there is nothing to preprocess.
    if 'statement' not in df.columns:
        logger.error("The required column 'statement' is missing from the dataset.")
        return None

    # Initialize the text preprocessor
    preprocessor = TextPreprocessor()

    # Apply the preprocessing to the 'statement' column
    logger.info("Starting text preprocessing...")
    df['cleaned_statement'] = df['statement'].apply(preprocessor.preprocess_text)
    logger.info("Text preprocessing completed.")

    # Save the cleaned data to a new CSV file
    cleaned_file_path = os.path.join('./data', 'cleaned_data.csv')
    df.to_csv(cleaned_file_path, index=False)
    logger.info(f"Cleaned data saved to {cleaned_file_path}")
    return cleaned_file_path
83
+
84
if __name__ == "__main__":
    # Clean the previously downloaded raw dataset.
    raw_csv_path = os.path.join("./data", "Combined_Data.csv")
    load_and_preprocess_data(raw_csv_path)
db_connection.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os, sys
from supabase import create_client, Client

# Add the root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from logging_config.logger_config import get_logger

# Get the logger
logger = get_logger(__name__)


# Connect to the database. Fail fast with a clear message when the
# credentials are not configured, instead of letting create_client fail
# opaquely on None values.
url: str = os.environ.get("SUPABASE_PROJECT_URL")
key: str = os.environ.get("SUPABASE_API_KEY")
if not url or not key:
    logger.error("SUPABASE_PROJECT_URL and SUPABASE_API_KEY must be set in the environment.")
    raise RuntimeError("Missing Supabase credentials in environment.")
supabase: Client = create_client(url, key)
16
+
17
+ # creating a function to update the database
18
# creating a function to update the database
def insert_db(data: dict, table='Interaction History'):
    """Insert *data* as one row into *table*.

    Returns the supabase response object, or None when the insert fails.
    """
    logger.info(f"Inserting data into the database: {data}")
    try:
        result = supabase.table(table).insert(data).execute()
    except Exception as e:
        logger.error(f"Error inserting data into the database: {e}")
        return None
    logger.info(f"Data inserted successfully: {result}")
    return result
27
+
28
if __name__ == "__main__":
    # Smoke-test the insert with a sample interaction record.
    sample_row = {
        "Input_text" : "I feel incredibly anxious about everything and can't stop worrying",
        "Model_prediction" : "Anxiety",
        "Llama_3_Prediction" : "Anxiety",
        "Llama_3_Explanation" : "After my analysis, i concluded that the user is suffering from anxiety",
        "User Rating" : 5,
    }
    print(insert_db(sample_row))
40
+
entrypoint.sh ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/sh

# Exit immediately if a command exits with a non-zero status
set -e

# Step 1: Data Ingestion
echo "Running data ingestion..."
python data_pipeline/data_ingestion.py

# Step 2: Data Preprocessing
echo "Running data preprocessing..."
python data_pipeline/data_preprocessor.py

# Step 3: Model Training
echo "Running model training..."
python model_pipeline/model_trainer.py

# Step 4: Run FastAPI App
# exec replaces this shell so uvicorn becomes PID 1 and receives SIGTERM
# directly on `docker stop`, allowing a graceful shutdown.
echo "Starting FastAPI app..."
exec uvicorn fastapi_app.main:app --host 0.0.0.0 --port 8000
experiment.ipynb ADDED
@@ -0,0 +1,1871 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 13,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd\n",
10
+ "import numpy as np"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 16,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "# data loading\n",
20
+ "data = pd.read_csv('data//Combined_Data.csv')"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 17,
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "data": {
30
+ "text/html": [
31
+ "<div>\n",
32
+ "<style scoped>\n",
33
+ " .dataframe tbody tr th:only-of-type {\n",
34
+ " vertical-align: middle;\n",
35
+ " }\n",
36
+ "\n",
37
+ " .dataframe tbody tr th {\n",
38
+ " vertical-align: top;\n",
39
+ " }\n",
40
+ "\n",
41
+ " .dataframe thead th {\n",
42
+ " text-align: right;\n",
43
+ " }\n",
44
+ "</style>\n",
45
+ "<table border=\"1\" class=\"dataframe\">\n",
46
+ " <thead>\n",
47
+ " <tr style=\"text-align: right;\">\n",
48
+ " <th></th>\n",
49
+ " <th>Unnamed: 0</th>\n",
50
+ " <th>statement</th>\n",
51
+ " <th>status</th>\n",
52
+ " </tr>\n",
53
+ " </thead>\n",
54
+ " <tbody>\n",
55
+ " <tr>\n",
56
+ " <th>0</th>\n",
57
+ " <td>0</td>\n",
58
+ " <td>oh my gosh</td>\n",
59
+ " <td>Anxiety</td>\n",
60
+ " </tr>\n",
61
+ " <tr>\n",
62
+ " <th>1</th>\n",
63
+ " <td>1</td>\n",
64
+ " <td>trouble sleeping, confused mind, restless hear...</td>\n",
65
+ " <td>Anxiety</td>\n",
66
+ " </tr>\n",
67
+ " <tr>\n",
68
+ " <th>2</th>\n",
69
+ " <td>2</td>\n",
70
+ " <td>All wrong, back off dear, forward doubt. Stay ...</td>\n",
71
+ " <td>Anxiety</td>\n",
72
+ " </tr>\n",
73
+ " <tr>\n",
74
+ " <th>3</th>\n",
75
+ " <td>3</td>\n",
76
+ " <td>I've shifted my focus to something else but I'...</td>\n",
77
+ " <td>Anxiety</td>\n",
78
+ " </tr>\n",
79
+ " <tr>\n",
80
+ " <th>4</th>\n",
81
+ " <td>4</td>\n",
82
+ " <td>I'm restless and restless, it's been a month n...</td>\n",
83
+ " <td>Anxiety</td>\n",
84
+ " </tr>\n",
85
+ " </tbody>\n",
86
+ "</table>\n",
87
+ "</div>"
88
+ ],
89
+ "text/plain": [
90
+ " Unnamed: 0 statement status\n",
91
+ "0 0 oh my gosh Anxiety\n",
92
+ "1 1 trouble sleeping, confused mind, restless hear... Anxiety\n",
93
+ "2 2 All wrong, back off dear, forward doubt. Stay ... Anxiety\n",
94
+ "3 3 I've shifted my focus to something else but I'... Anxiety\n",
95
+ "4 4 I'm restless and restless, it's been a month n... Anxiety"
96
+ ]
97
+ },
98
+ "execution_count": 17,
99
+ "metadata": {},
100
+ "output_type": "execute_result"
101
+ }
102
+ ],
103
+ "source": [
104
+ "data.head()"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 21,
110
+ "metadata": {},
111
+ "outputs": [
112
+ {
113
+ "data": {
114
+ "text/plain": [
115
+ "'I recently watched my dad die a gruesome death due to cancer this week, and I am sure something similar is in my future, I do not have any real friends and I do not have a home, I have been living in a hotel the past 6 months. I do not want to live anymore I just want to see my dad again and I do not want to suffer like he did I do not want to live anymore'"
116
+ ]
117
+ },
118
+ "execution_count": 21,
119
+ "metadata": {},
120
+ "output_type": "execute_result"
121
+ }
122
+ ],
123
+ "source": [
124
+ "data['statement'].values[19230]"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 19,
130
+ "metadata": {},
131
+ "outputs": [
132
+ {
133
+ "data": {
134
+ "text/html": [
135
+ "<div>\n",
136
+ "<style scoped>\n",
137
+ " .dataframe tbody tr th:only-of-type {\n",
138
+ " vertical-align: middle;\n",
139
+ " }\n",
140
+ "\n",
141
+ " .dataframe tbody tr th {\n",
142
+ " vertical-align: top;\n",
143
+ " }\n",
144
+ "\n",
145
+ " .dataframe thead th {\n",
146
+ " text-align: right;\n",
147
+ " }\n",
148
+ "</style>\n",
149
+ "<table border=\"1\" class=\"dataframe\">\n",
150
+ " <thead>\n",
151
+ " <tr style=\"text-align: right;\">\n",
152
+ " <th></th>\n",
153
+ " <th>statement</th>\n",
154
+ " <th>status</th>\n",
155
+ " </tr>\n",
156
+ " </thead>\n",
157
+ " <tbody>\n",
158
+ " <tr>\n",
159
+ " <th>0</th>\n",
160
+ " <td>oh my gosh</td>\n",
161
+ " <td>Anxiety</td>\n",
162
+ " </tr>\n",
163
+ " <tr>\n",
164
+ " <th>1</th>\n",
165
+ " <td>trouble sleeping, confused mind, restless hear...</td>\n",
166
+ " <td>Anxiety</td>\n",
167
+ " </tr>\n",
168
+ " <tr>\n",
169
+ " <th>2</th>\n",
170
+ " <td>All wrong, back off dear, forward doubt. Stay ...</td>\n",
171
+ " <td>Anxiety</td>\n",
172
+ " </tr>\n",
173
+ " <tr>\n",
174
+ " <th>3</th>\n",
175
+ " <td>I've shifted my focus to something else but I'...</td>\n",
176
+ " <td>Anxiety</td>\n",
177
+ " </tr>\n",
178
+ " <tr>\n",
179
+ " <th>4</th>\n",
180
+ " <td>I'm restless and restless, it's been a month n...</td>\n",
181
+ " <td>Anxiety</td>\n",
182
+ " </tr>\n",
183
+ " </tbody>\n",
184
+ "</table>\n",
185
+ "</div>"
186
+ ],
187
+ "text/plain": [
188
+ " statement status\n",
189
+ "0 oh my gosh Anxiety\n",
190
+ "1 trouble sleeping, confused mind, restless hear... Anxiety\n",
191
+ "2 All wrong, back off dear, forward doubt. Stay ... Anxiety\n",
192
+ "3 I've shifted my focus to something else but I'... Anxiety\n",
193
+ "4 I'm restless and restless, it's been a month n... Anxiety"
194
+ ]
195
+ },
196
+ "execution_count": 19,
197
+ "metadata": {},
198
+ "output_type": "execute_result"
199
+ }
200
+ ],
201
+ "source": [
202
+ "# selecting needed columns\n",
203
+ "df = data[['statement', 'status']]\n",
204
+ "df.head()"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 5,
210
+ "metadata": {},
211
+ "outputs": [
212
+ {
213
+ "data": {
214
+ "text/plain": [
215
+ "status\n",
216
+ "Normal 16351\n",
217
+ "Depression 15404\n",
218
+ "Suicidal 10653\n",
219
+ "Anxiety 3888\n",
220
+ "Bipolar 2877\n",
221
+ "Stress 2669\n",
222
+ "Personality disorder 1201\n",
223
+ "Name: count, dtype: int64"
224
+ ]
225
+ },
226
+ "execution_count": 5,
227
+ "metadata": {},
228
+ "output_type": "execute_result"
229
+ }
230
+ ],
231
+ "source": [
232
+ "# value counts for the status\n",
233
+ "df['status'].value_counts()"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 6,
239
+ "metadata": {},
240
+ "outputs": [
241
+ {
242
+ "data": {
243
+ "text/plain": [
244
+ "(53043, 2)"
245
+ ]
246
+ },
247
+ "execution_count": 6,
248
+ "metadata": {},
249
+ "output_type": "execute_result"
250
+ }
251
+ ],
252
+ "source": [
253
+ "df.shape"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 7,
259
+ "metadata": {},
260
+ "outputs": [
261
+ {
262
+ "data": {
263
+ "text/plain": [
264
+ "statement 362\n",
265
+ "status 0\n",
266
+ "dtype: int64"
267
+ ]
268
+ },
269
+ "execution_count": 7,
270
+ "metadata": {},
271
+ "output_type": "execute_result"
272
+ }
273
+ ],
274
+ "source": [
275
+ "# checking for nan values\n",
276
+ "df.isnull().sum()"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 8,
282
+ "metadata": {},
283
+ "outputs": [
284
+ {
285
+ "data": {
286
+ "text/plain": [
287
+ "statement 0\n",
288
+ "status 0\n",
289
+ "dtype: int64"
290
+ ]
291
+ },
292
+ "execution_count": 8,
293
+ "metadata": {},
294
+ "output_type": "execute_result"
295
+ }
296
+ ],
297
+ "source": [
298
+ "# dropping nan values\n",
299
+ "df_1 = df.dropna()\n",
300
+ "df_1.isna().sum()"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": 9,
306
+ "metadata": {},
307
+ "outputs": [
308
+ {
309
+ "name": "stderr",
310
+ "output_type": "stream",
311
+ "text": [
312
+ "[nltk_data] Downloading package stopwords to\n",
313
+ "[nltk_data] C:\\Users\\timmy\\AppData\\Roaming\\nltk_data...\n",
314
+ "[nltk_data] Package stopwords is already up-to-date!\n",
315
+ "[nltk_data] Downloading package wordnet to\n",
316
+ "[nltk_data] C:\\Users\\timmy\\AppData\\Roaming\\nltk_data...\n",
317
+ "[nltk_data] Package wordnet is already up-to-date!\n"
318
+ ]
319
+ },
320
+ {
321
+ "data": {
322
+ "text/plain": [
323
+ "True"
324
+ ]
325
+ },
326
+ "execution_count": 9,
327
+ "metadata": {},
328
+ "output_type": "execute_result"
329
+ }
330
+ ],
331
+ "source": [
332
+ "import re\n",
333
+ "import string\n",
334
+ "import nltk\n",
335
+ "from nltk.corpus import stopwords\n",
336
+ "from nltk.stem import PorterStemmer, WordNetLemmatizer\n",
337
+ "\n",
338
+ "# Download necessary NLTK data files\n",
339
+ "nltk.download('stopwords')\n",
340
+ "nltk.download('wordnet')"
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": 10,
346
+ "metadata": {},
347
+ "outputs": [
348
+ {
349
+ "name": "stdout",
350
+ "output_type": "stream",
351
+ "text": [
352
+ "example sentence demonstrate text preprocessing python includes number like punctuation\n"
353
+ ]
354
+ }
355
+ ],
356
+ "source": [
357
+ "# creating a cleaning pipeline for the statement column\n",
358
+ "def preprocess_text(text, use_stemming=False, use_lemmatization=True):\n",
359
+ " # Lowercase the text\n",
360
+ " text = text.lower()\n",
361
+ " \n",
362
+ " # Remove punctuation\n",
363
+ " text = re.sub(f'[{re.escape(string.punctuation)}]', '', text)\n",
364
+ " \n",
365
+ " # Remove numbers\n",
366
+ " text = re.sub(r'\\d+', '', text)\n",
367
+ " \n",
368
+ " # Tokenize the text\n",
369
+ " words = text.split()\n",
370
+ " \n",
371
+ " # Remove stopwords\n",
372
+ " stop_words = set(stopwords.words('english'))\n",
373
+ " words = [word for word in words if word not in stop_words]\n",
374
+ " \n",
375
+ " # Initialize stemmer and lemmatizer\n",
376
+ " stemmer = PorterStemmer()\n",
377
+ " lemmatizer = WordNetLemmatizer()\n",
378
+ " \n",
379
+ " if use_stemming:\n",
380
+ " # Apply stemming\n",
381
+ " words = [stemmer.stem(word) for word in words]\n",
382
+ " elif use_lemmatization:\n",
383
+ " # Apply lemmatization\n",
384
+ " words = [lemmatizer.lemmatize(word) for word in words]\n",
385
+ " \n",
386
+ " # Join words back into a single string\n",
387
+ " cleaned_text = ' '.join(words)\n",
388
+ " \n",
389
+ " return cleaned_text\n",
390
+ "\n",
391
+ "# Example usage\n",
392
+ "text = \"This is an example sentence to demonstrate text preprocessing in Python. It includes numbers like 123 and punctuation!\"\n",
393
+ "cleaned_text = preprocess_text(text)\n",
394
+ "print(cleaned_text)\n"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": 11,
400
+ "metadata": {},
401
+ "outputs": [
402
+ {
403
+ "name": "stderr",
404
+ "output_type": "stream",
405
+ "text": [
406
+ "C:\\Users\\timmy\\AppData\\Local\\Temp\\ipykernel_4184\\637849828.py:2: SettingWithCopyWarning: \n",
407
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
408
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
409
+ "\n",
410
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
411
+ " df_1['cleaned_statement'] = df_1['statement'].apply(preprocess_text)\n"
412
+ ]
413
+ }
414
+ ],
415
+ "source": [
416
+ "# implementing on the statement column\n",
417
+ "df_1['cleaned_statement'] = df_1['statement'].apply(preprocess_text)"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": 12,
423
+ "metadata": {},
424
+ "outputs": [
425
+ {
426
+ "data": {
427
+ "text/html": [
428
+ "<div>\n",
429
+ "<style scoped>\n",
430
+ " .dataframe tbody tr th:only-of-type {\n",
431
+ " vertical-align: middle;\n",
432
+ " }\n",
433
+ "\n",
434
+ " .dataframe tbody tr th {\n",
435
+ " vertical-align: top;\n",
436
+ " }\n",
437
+ "\n",
438
+ " .dataframe thead th {\n",
439
+ " text-align: right;\n",
440
+ " }\n",
441
+ "</style>\n",
442
+ "<table border=\"1\" class=\"dataframe\">\n",
443
+ " <thead>\n",
444
+ " <tr style=\"text-align: right;\">\n",
445
+ " <th></th>\n",
446
+ " <th>statement</th>\n",
447
+ " <th>status</th>\n",
448
+ " <th>cleaned_statement</th>\n",
449
+ " </tr>\n",
450
+ " </thead>\n",
451
+ " <tbody>\n",
452
+ " <tr>\n",
453
+ " <th>0</th>\n",
454
+ " <td>oh my gosh</td>\n",
455
+ " <td>Anxiety</td>\n",
456
+ " <td>oh gosh</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <th>1</th>\n",
460
+ " <td>trouble sleeping, confused mind, restless hear...</td>\n",
461
+ " <td>Anxiety</td>\n",
462
+ " <td>trouble sleeping confused mind restless heart ...</td>\n",
463
+ " </tr>\n",
464
+ " <tr>\n",
465
+ " <th>2</th>\n",
466
+ " <td>All wrong, back off dear, forward doubt. Stay ...</td>\n",
467
+ " <td>Anxiety</td>\n",
468
+ " <td>wrong back dear forward doubt stay restless re...</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <th>3</th>\n",
472
+ " <td>I've shifted my focus to something else but I'...</td>\n",
473
+ " <td>Anxiety</td>\n",
474
+ " <td>ive shifted focus something else im still worried</td>\n",
475
+ " </tr>\n",
476
+ " <tr>\n",
477
+ " <th>4</th>\n",
478
+ " <td>I'm restless and restless, it's been a month n...</td>\n",
479
+ " <td>Anxiety</td>\n",
480
+ " <td>im restless restless month boy mean</td>\n",
481
+ " </tr>\n",
482
+ " </tbody>\n",
483
+ "</table>\n",
484
+ "</div>"
485
+ ],
486
+ "text/plain": [
487
+ " statement status \\\n",
488
+ "0 oh my gosh Anxiety \n",
489
+ "1 trouble sleeping, confused mind, restless hear... Anxiety \n",
490
+ "2 All wrong, back off dear, forward doubt. Stay ... Anxiety \n",
491
+ "3 I've shifted my focus to something else but I'... Anxiety \n",
492
+ "4 I'm restless and restless, it's been a month n... Anxiety \n",
493
+ "\n",
494
+ " cleaned_statement \n",
495
+ "0 oh gosh \n",
496
+ "1 trouble sleeping confused mind restless heart ... \n",
497
+ "2 wrong back dear forward doubt stay restless re... \n",
498
+ "3 ive shifted focus something else im still worried \n",
499
+ "4 im restless restless month boy mean "
500
+ ]
501
+ },
502
+ "execution_count": 12,
503
+ "metadata": {},
504
+ "output_type": "execute_result"
505
+ }
506
+ ],
507
+ "source": [
508
+ "df_1.head()"
509
+ ]
510
+ },
511
+ {
512
+ "cell_type": "code",
513
+ "execution_count": 13,
514
+ "metadata": {},
515
+ "outputs": [
516
+ {
517
+ "data": {
518
+ "text/html": [
519
+ "<div>\n",
520
+ "<style scoped>\n",
521
+ " .dataframe tbody tr th:only-of-type {\n",
522
+ " vertical-align: middle;\n",
523
+ " }\n",
524
+ "\n",
525
+ " .dataframe tbody tr th {\n",
526
+ " vertical-align: top;\n",
527
+ " }\n",
528
+ "\n",
529
+ " .dataframe thead th {\n",
530
+ " text-align: right;\n",
531
+ " }\n",
532
+ "</style>\n",
533
+ "<table border=\"1\" class=\"dataframe\">\n",
534
+ " <thead>\n",
535
+ " <tr style=\"text-align: right;\">\n",
536
+ " <th></th>\n",
537
+ " <th>cleaned_statement</th>\n",
538
+ " <th>status</th>\n",
539
+ " </tr>\n",
540
+ " </thead>\n",
541
+ " <tbody>\n",
542
+ " <tr>\n",
543
+ " <th>0</th>\n",
544
+ " <td>oh gosh</td>\n",
545
+ " <td>Anxiety</td>\n",
546
+ " </tr>\n",
547
+ " <tr>\n",
548
+ " <th>1</th>\n",
549
+ " <td>trouble sleeping confused mind restless heart ...</td>\n",
550
+ " <td>Anxiety</td>\n",
551
+ " </tr>\n",
552
+ " <tr>\n",
553
+ " <th>2</th>\n",
554
+ " <td>wrong back dear forward doubt stay restless re...</td>\n",
555
+ " <td>Anxiety</td>\n",
556
+ " </tr>\n",
557
+ " <tr>\n",
558
+ " <th>3</th>\n",
559
+ " <td>ive shifted focus something else im still worried</td>\n",
560
+ " <td>Anxiety</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <th>4</th>\n",
564
+ " <td>im restless restless month boy mean</td>\n",
565
+ " <td>Anxiety</td>\n",
566
+ " </tr>\n",
567
+ " </tbody>\n",
568
+ "</table>\n",
569
+ "</div>"
570
+ ],
571
+ "text/plain": [
572
+ " cleaned_statement status\n",
573
+ "0 oh gosh Anxiety\n",
574
+ "1 trouble sleeping confused mind restless heart ... Anxiety\n",
575
+ "2 wrong back dear forward doubt stay restless re... Anxiety\n",
576
+ "3 ive shifted focus something else im still worried Anxiety\n",
577
+ "4 im restless restless month boy mean Anxiety"
578
+ ]
579
+ },
580
+ "execution_count": 13,
581
+ "metadata": {},
582
+ "output_type": "execute_result"
583
+ }
584
+ ],
585
+ "source": [
586
+ "df_2 = df_1[['cleaned_statement', 'status']]\n",
587
+ "df_2.head()"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": 14,
593
+ "metadata": {},
594
+ "outputs": [
595
+ {
596
+ "name": "stderr",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "C:\\Users\\timmy\\AppData\\Local\\Temp\\ipykernel_4184\\858368390.py:4: SettingWithCopyWarning: \n",
600
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n",
601
+ "Try using .loc[row_indexer,col_indexer] = value instead\n",
602
+ "\n",
603
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
604
+ " df_2['status'] = encoder.fit_transform(df_2['status'])\n"
605
+ ]
606
+ }
607
+ ],
608
+ "source": [
609
+ "# encoding the status column\n",
610
+ "from sklearn.preprocessing import LabelEncoder\n",
611
+ "encoder = LabelEncoder()\n",
612
+ "df_2['status'] = encoder.fit_transform(df_2['status'])"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "code",
617
+ "execution_count": 15,
618
+ "metadata": {},
619
+ "outputs": [
620
+ {
621
+ "data": {
622
+ "text/plain": [
623
+ "array(['Anxiety', 'Bipolar', 'Depression', 'Normal',\n",
624
+ " 'Personality disorder', 'Stress', 'Suicidal'], dtype=object)"
625
+ ]
626
+ },
627
+ "execution_count": 15,
628
+ "metadata": {},
629
+ "output_type": "execute_result"
630
+ }
631
+ ],
632
+ "source": [
633
+ "encoder.classes_"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": 16,
639
+ "metadata": {},
640
+ "outputs": [
641
+ {
642
+ "data": {
643
+ "text/plain": [
644
+ "{'Anxiety': np.int64(0),\n",
645
+ " 'Bipolar': np.int64(1),\n",
646
+ " 'Depression': np.int64(2),\n",
647
+ " 'Normal': np.int64(3),\n",
648
+ " 'Personality disorder': np.int64(4),\n",
649
+ " 'Stress': np.int64(5),\n",
650
+ " 'Suicidal': np.int64(6)}"
651
+ ]
652
+ },
653
+ "execution_count": 16,
654
+ "metadata": {},
655
+ "output_type": "execute_result"
656
+ }
657
+ ],
658
+ "source": [
659
+ "label_mapping = dict(zip(encoder.classes_, encoder.transform(encoder.classes_)))\n",
660
+ "label_mapping"
661
+ ]
662
+ },
663
+ {
664
+ "cell_type": "code",
665
+ "execution_count": 17,
666
+ "metadata": {},
667
+ "outputs": [
668
+ {
669
+ "data": {
670
+ "text/html": [
671
+ "<div>\n",
672
+ "<style scoped>\n",
673
+ " .dataframe tbody tr th:only-of-type {\n",
674
+ " vertical-align: middle;\n",
675
+ " }\n",
676
+ "\n",
677
+ " .dataframe tbody tr th {\n",
678
+ " vertical-align: top;\n",
679
+ " }\n",
680
+ "\n",
681
+ " .dataframe thead th {\n",
682
+ " text-align: right;\n",
683
+ " }\n",
684
+ "</style>\n",
685
+ "<table border=\"1\" class=\"dataframe\">\n",
686
+ " <thead>\n",
687
+ " <tr style=\"text-align: right;\">\n",
688
+ " <th></th>\n",
689
+ " <th>cleaned_statement</th>\n",
690
+ " <th>status</th>\n",
691
+ " </tr>\n",
692
+ " </thead>\n",
693
+ " <tbody>\n",
694
+ " <tr>\n",
695
+ " <th>0</th>\n",
696
+ " <td>oh gosh</td>\n",
697
+ " <td>0</td>\n",
698
+ " </tr>\n",
699
+ " <tr>\n",
700
+ " <th>1</th>\n",
701
+ " <td>trouble sleeping confused mind restless heart ...</td>\n",
702
+ " <td>0</td>\n",
703
+ " </tr>\n",
704
+ " <tr>\n",
705
+ " <th>2</th>\n",
706
+ " <td>wrong back dear forward doubt stay restless re...</td>\n",
707
+ " <td>0</td>\n",
708
+ " </tr>\n",
709
+ " <tr>\n",
710
+ " <th>3</th>\n",
711
+ " <td>ive shifted focus something else im still worried</td>\n",
712
+ " <td>0</td>\n",
713
+ " </tr>\n",
714
+ " <tr>\n",
715
+ " <th>4</th>\n",
716
+ " <td>im restless restless month boy mean</td>\n",
717
+ " <td>0</td>\n",
718
+ " </tr>\n",
719
+ " </tbody>\n",
720
+ "</table>\n",
721
+ "</div>"
722
+ ],
723
+ "text/plain": [
724
+ " cleaned_statement status\n",
725
+ "0 oh gosh 0\n",
726
+ "1 trouble sleeping confused mind restless heart ... 0\n",
727
+ "2 wrong back dear forward doubt stay restless re... 0\n",
728
+ "3 ive shifted focus something else im still worried 0\n",
729
+ "4 im restless restless month boy mean 0"
730
+ ]
731
+ },
732
+ "execution_count": 17,
733
+ "metadata": {},
734
+ "output_type": "execute_result"
735
+ }
736
+ ],
737
+ "source": [
738
+ "df_2.head()"
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "execution_count": 20,
744
+ "metadata": {},
745
+ "outputs": [],
746
+ "source": [
747
+ "# splitting the data \n",
748
+ "from sklearn.model_selection import train_test_split\n",
749
+ "X = df_2['cleaned_statement']\n",
750
+ "y = df_2['status']\n",
751
+ "\n",
752
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)"
753
+ ]
754
+ },
755
+ {
756
+ "cell_type": "code",
757
+ "execution_count": 21,
758
+ "metadata": {},
759
+ "outputs": [],
760
+ "source": [
761
+ "# creating vectors for the cleaned_statement column\n",
762
+ "from sklearn.feature_extraction.text import TfidfVectorizer\n",
763
+ "\n",
764
+ "# Vectorize the text using TF-IDF\n",
765
+ "vectorizer = TfidfVectorizer()\n",
766
+ "X_train_tfidf = vectorizer.fit_transform(X_train)\n",
767
+ "X_test_tfidf = vectorizer.transform(X_test)\n"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": 26,
773
+ "metadata": {},
774
+ "outputs": [
775
+ {
776
+ "data": {
777
+ "text/html": [
778
+ "<style>#sk-container-id-2 {\n",
779
+ " /* Definition of color scheme common for light and dark mode */\n",
780
+ " --sklearn-color-text: black;\n",
781
+ " --sklearn-color-line: gray;\n",
782
+ " /* Definition of color scheme for unfitted estimators */\n",
783
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
784
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
785
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
786
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
787
+ " /* Definition of color scheme for fitted estimators */\n",
788
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
789
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
790
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
791
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
792
+ "\n",
793
+ " /* Specific color for light theme */\n",
794
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
795
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
796
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
797
+ " --sklearn-color-icon: #696969;\n",
798
+ "\n",
799
+ " @media (prefers-color-scheme: dark) {\n",
800
+ " /* Redefinition of color scheme for dark theme */\n",
801
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
802
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
803
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
804
+ " --sklearn-color-icon: #878787;\n",
805
+ " }\n",
806
+ "}\n",
807
+ "\n",
808
+ "#sk-container-id-2 {\n",
809
+ " color: var(--sklearn-color-text);\n",
810
+ "}\n",
811
+ "\n",
812
+ "#sk-container-id-2 pre {\n",
813
+ " padding: 0;\n",
814
+ "}\n",
815
+ "\n",
816
+ "#sk-container-id-2 input.sk-hidden--visually {\n",
817
+ " border: 0;\n",
818
+ " clip: rect(1px 1px 1px 1px);\n",
819
+ " clip: rect(1px, 1px, 1px, 1px);\n",
820
+ " height: 1px;\n",
821
+ " margin: -1px;\n",
822
+ " overflow: hidden;\n",
823
+ " padding: 0;\n",
824
+ " position: absolute;\n",
825
+ " width: 1px;\n",
826
+ "}\n",
827
+ "\n",
828
+ "#sk-container-id-2 div.sk-dashed-wrapped {\n",
829
+ " border: 1px dashed var(--sklearn-color-line);\n",
830
+ " margin: 0 0.4em 0.5em 0.4em;\n",
831
+ " box-sizing: border-box;\n",
832
+ " padding-bottom: 0.4em;\n",
833
+ " background-color: var(--sklearn-color-background);\n",
834
+ "}\n",
835
+ "\n",
836
+ "#sk-container-id-2 div.sk-container {\n",
837
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
838
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
839
+ " so we also need the `!important` here to be able to override the\n",
840
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
841
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
842
+ " display: inline-block !important;\n",
843
+ " position: relative;\n",
844
+ "}\n",
845
+ "\n",
846
+ "#sk-container-id-2 div.sk-text-repr-fallback {\n",
847
+ " display: none;\n",
848
+ "}\n",
849
+ "\n",
850
+ "div.sk-parallel-item,\n",
851
+ "div.sk-serial,\n",
852
+ "div.sk-item {\n",
853
+ " /* draw centered vertical line to link estimators */\n",
854
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
855
+ " background-size: 2px 100%;\n",
856
+ " background-repeat: no-repeat;\n",
857
+ " background-position: center center;\n",
858
+ "}\n",
859
+ "\n",
860
+ "/* Parallel-specific style estimator block */\n",
861
+ "\n",
862
+ "#sk-container-id-2 div.sk-parallel-item::after {\n",
863
+ " content: \"\";\n",
864
+ " width: 100%;\n",
865
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
866
+ " flex-grow: 1;\n",
867
+ "}\n",
868
+ "\n",
869
+ "#sk-container-id-2 div.sk-parallel {\n",
870
+ " display: flex;\n",
871
+ " align-items: stretch;\n",
872
+ " justify-content: center;\n",
873
+ " background-color: var(--sklearn-color-background);\n",
874
+ " position: relative;\n",
875
+ "}\n",
876
+ "\n",
877
+ "#sk-container-id-2 div.sk-parallel-item {\n",
878
+ " display: flex;\n",
879
+ " flex-direction: column;\n",
880
+ "}\n",
881
+ "\n",
882
+ "#sk-container-id-2 div.sk-parallel-item:first-child::after {\n",
883
+ " align-self: flex-end;\n",
884
+ " width: 50%;\n",
885
+ "}\n",
886
+ "\n",
887
+ "#sk-container-id-2 div.sk-parallel-item:last-child::after {\n",
888
+ " align-self: flex-start;\n",
889
+ " width: 50%;\n",
890
+ "}\n",
891
+ "\n",
892
+ "#sk-container-id-2 div.sk-parallel-item:only-child::after {\n",
893
+ " width: 0;\n",
894
+ "}\n",
895
+ "\n",
896
+ "/* Serial-specific style estimator block */\n",
897
+ "\n",
898
+ "#sk-container-id-2 div.sk-serial {\n",
899
+ " display: flex;\n",
900
+ " flex-direction: column;\n",
901
+ " align-items: center;\n",
902
+ " background-color: var(--sklearn-color-background);\n",
903
+ " padding-right: 1em;\n",
904
+ " padding-left: 1em;\n",
905
+ "}\n",
906
+ "\n",
907
+ "\n",
908
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
909
+ "clickable and can be expanded/collapsed.\n",
910
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
911
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
912
+ "*/\n",
913
+ "\n",
914
+ "/* Pipeline and ColumnTransformer style (default) */\n",
915
+ "\n",
916
+ "#sk-container-id-2 div.sk-toggleable {\n",
917
+ " /* Default theme specific background. It is overwritten whether we have a\n",
918
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
919
+ " background-color: var(--sklearn-color-background);\n",
920
+ "}\n",
921
+ "\n",
922
+ "/* Toggleable label */\n",
923
+ "#sk-container-id-2 label.sk-toggleable__label {\n",
924
+ " cursor: pointer;\n",
925
+ " display: block;\n",
926
+ " width: 100%;\n",
927
+ " margin-bottom: 0;\n",
928
+ " padding: 0.5em;\n",
929
+ " box-sizing: border-box;\n",
930
+ " text-align: center;\n",
931
+ "}\n",
932
+ "\n",
933
+ "#sk-container-id-2 label.sk-toggleable__label-arrow:before {\n",
934
+ " /* Arrow on the left of the label */\n",
935
+ " content: \"▸\";\n",
936
+ " float: left;\n",
937
+ " margin-right: 0.25em;\n",
938
+ " color: var(--sklearn-color-icon);\n",
939
+ "}\n",
940
+ "\n",
941
+ "#sk-container-id-2 label.sk-toggleable__label-arrow:hover:before {\n",
942
+ " color: var(--sklearn-color-text);\n",
943
+ "}\n",
944
+ "\n",
945
+ "/* Toggleable content - dropdown */\n",
946
+ "\n",
947
+ "#sk-container-id-2 div.sk-toggleable__content {\n",
948
+ " max-height: 0;\n",
949
+ " max-width: 0;\n",
950
+ " overflow: hidden;\n",
951
+ " text-align: left;\n",
952
+ " /* unfitted */\n",
953
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
954
+ "}\n",
955
+ "\n",
956
+ "#sk-container-id-2 div.sk-toggleable__content.fitted {\n",
957
+ " /* fitted */\n",
958
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
959
+ "}\n",
960
+ "\n",
961
+ "#sk-container-id-2 div.sk-toggleable__content pre {\n",
962
+ " margin: 0.2em;\n",
963
+ " border-radius: 0.25em;\n",
964
+ " color: var(--sklearn-color-text);\n",
965
+ " /* unfitted */\n",
966
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
967
+ "}\n",
968
+ "\n",
969
+ "#sk-container-id-2 div.sk-toggleable__content.fitted pre {\n",
970
+ " /* unfitted */\n",
971
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
972
+ "}\n",
973
+ "\n",
974
+ "#sk-container-id-2 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
975
+ " /* Expand drop-down */\n",
976
+ " max-height: 200px;\n",
977
+ " max-width: 100%;\n",
978
+ " overflow: auto;\n",
979
+ "}\n",
980
+ "\n",
981
+ "#sk-container-id-2 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
982
+ " content: \"▾\";\n",
983
+ "}\n",
984
+ "\n",
985
+ "/* Pipeline/ColumnTransformer-specific style */\n",
986
+ "\n",
987
+ "#sk-container-id-2 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
988
+ " color: var(--sklearn-color-text);\n",
989
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
990
+ "}\n",
991
+ "\n",
992
+ "#sk-container-id-2 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
993
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
994
+ "}\n",
995
+ "\n",
996
+ "/* Estimator-specific style */\n",
997
+ "\n",
998
+ "/* Colorize estimator box */\n",
999
+ "#sk-container-id-2 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1000
+ " /* unfitted */\n",
1001
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1002
+ "}\n",
1003
+ "\n",
1004
+ "#sk-container-id-2 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1005
+ " /* fitted */\n",
1006
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1007
+ "}\n",
1008
+ "\n",
1009
+ "#sk-container-id-2 div.sk-label label.sk-toggleable__label,\n",
1010
+ "#sk-container-id-2 div.sk-label label {\n",
1011
+ " /* The background is the default theme color */\n",
1012
+ " color: var(--sklearn-color-text-on-default-background);\n",
1013
+ "}\n",
1014
+ "\n",
1015
+ "/* On hover, darken the color of the background */\n",
1016
+ "#sk-container-id-2 div.sk-label:hover label.sk-toggleable__label {\n",
1017
+ " color: var(--sklearn-color-text);\n",
1018
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1019
+ "}\n",
1020
+ "\n",
1021
+ "/* Label box, darken color on hover, fitted */\n",
1022
+ "#sk-container-id-2 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
1023
+ " color: var(--sklearn-color-text);\n",
1024
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1025
+ "}\n",
1026
+ "\n",
1027
+ "/* Estimator label */\n",
1028
+ "\n",
1029
+ "#sk-container-id-2 div.sk-label label {\n",
1030
+ " font-family: monospace;\n",
1031
+ " font-weight: bold;\n",
1032
+ " display: inline-block;\n",
1033
+ " line-height: 1.2em;\n",
1034
+ "}\n",
1035
+ "\n",
1036
+ "#sk-container-id-2 div.sk-label-container {\n",
1037
+ " text-align: center;\n",
1038
+ "}\n",
1039
+ "\n",
1040
+ "/* Estimator-specific */\n",
1041
+ "#sk-container-id-2 div.sk-estimator {\n",
1042
+ " font-family: monospace;\n",
1043
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
1044
+ " border-radius: 0.25em;\n",
1045
+ " box-sizing: border-box;\n",
1046
+ " margin-bottom: 0.5em;\n",
1047
+ " /* unfitted */\n",
1048
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
1049
+ "}\n",
1050
+ "\n",
1051
+ "#sk-container-id-2 div.sk-estimator.fitted {\n",
1052
+ " /* fitted */\n",
1053
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
1054
+ "}\n",
1055
+ "\n",
1056
+ "/* on hover */\n",
1057
+ "#sk-container-id-2 div.sk-estimator:hover {\n",
1058
+ " /* unfitted */\n",
1059
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1060
+ "}\n",
1061
+ "\n",
1062
+ "#sk-container-id-2 div.sk-estimator.fitted:hover {\n",
1063
+ " /* fitted */\n",
1064
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1065
+ "}\n",
1066
+ "\n",
1067
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
1068
+ "\n",
1069
+ "/* Common style for \"i\" and \"?\" */\n",
1070
+ "\n",
1071
+ ".sk-estimator-doc-link,\n",
1072
+ "a:link.sk-estimator-doc-link,\n",
1073
+ "a:visited.sk-estimator-doc-link {\n",
1074
+ " float: right;\n",
1075
+ " font-size: smaller;\n",
1076
+ " line-height: 1em;\n",
1077
+ " font-family: monospace;\n",
1078
+ " background-color: var(--sklearn-color-background);\n",
1079
+ " border-radius: 1em;\n",
1080
+ " height: 1em;\n",
1081
+ " width: 1em;\n",
1082
+ " text-decoration: none !important;\n",
1083
+ " margin-left: 1ex;\n",
1084
+ " /* unfitted */\n",
1085
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
1086
+ " color: var(--sklearn-color-unfitted-level-1);\n",
1087
+ "}\n",
1088
+ "\n",
1089
+ ".sk-estimator-doc-link.fitted,\n",
1090
+ "a:link.sk-estimator-doc-link.fitted,\n",
1091
+ "a:visited.sk-estimator-doc-link.fitted {\n",
1092
+ " /* fitted */\n",
1093
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
1094
+ " color: var(--sklearn-color-fitted-level-1);\n",
1095
+ "}\n",
1096
+ "\n",
1097
+ "/* On hover */\n",
1098
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
1099
+ ".sk-estimator-doc-link:hover,\n",
1100
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
1101
+ ".sk-estimator-doc-link:hover {\n",
1102
+ " /* unfitted */\n",
1103
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
1104
+ " color: var(--sklearn-color-background);\n",
1105
+ " text-decoration: none;\n",
1106
+ "}\n",
1107
+ "\n",
1108
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
1109
+ ".sk-estimator-doc-link.fitted:hover,\n",
1110
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
1111
+ ".sk-estimator-doc-link.fitted:hover {\n",
1112
+ " /* fitted */\n",
1113
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
1114
+ " color: var(--sklearn-color-background);\n",
1115
+ " text-decoration: none;\n",
1116
+ "}\n",
1117
+ "\n",
1118
+ "/* Span, style for the box shown on hovering the info icon */\n",
1119
+ ".sk-estimator-doc-link span {\n",
1120
+ " display: none;\n",
1121
+ " z-index: 9999;\n",
1122
+ " position: relative;\n",
1123
+ " font-weight: normal;\n",
1124
+ " right: .2ex;\n",
1125
+ " padding: .5ex;\n",
1126
+ " margin: .5ex;\n",
1127
+ " width: min-content;\n",
1128
+ " min-width: 20ex;\n",
1129
+ " max-width: 50ex;\n",
1130
+ " color: var(--sklearn-color-text);\n",
1131
+ " box-shadow: 2pt 2pt 4pt #999;\n",
1132
+ " /* unfitted */\n",
1133
+ " background: var(--sklearn-color-unfitted-level-0);\n",
1134
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
1135
+ "}\n",
1136
+ "\n",
1137
+ ".sk-estimator-doc-link.fitted span {\n",
1138
+ " /* fitted */\n",
1139
+ " background: var(--sklearn-color-fitted-level-0);\n",
1140
+ " border: var(--sklearn-color-fitted-level-3);\n",
1141
+ "}\n",
1142
+ "\n",
1143
+ ".sk-estimator-doc-link:hover span {\n",
1144
+ " display: block;\n",
1145
+ "}\n",
1146
+ "\n",
1147
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
1148
+ "\n",
1149
+ "#sk-container-id-2 a.estimator_doc_link {\n",
1150
+ " float: right;\n",
1151
+ " font-size: 1rem;\n",
1152
+ " line-height: 1em;\n",
1153
+ " font-family: monospace;\n",
1154
+ " background-color: var(--sklearn-color-background);\n",
1155
+ " border-radius: 1rem;\n",
1156
+ " height: 1rem;\n",
1157
+ " width: 1rem;\n",
1158
+ " text-decoration: none;\n",
1159
+ " /* unfitted */\n",
1160
+ " color: var(--sklearn-color-unfitted-level-1);\n",
1161
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
1162
+ "}\n",
1163
+ "\n",
1164
+ "#sk-container-id-2 a.estimator_doc_link.fitted {\n",
1165
+ " /* fitted */\n",
1166
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
1167
+ " color: var(--sklearn-color-fitted-level-1);\n",
1168
+ "}\n",
1169
+ "\n",
1170
+ "/* On hover */\n",
1171
+ "#sk-container-id-2 a.estimator_doc_link:hover {\n",
1172
+ " /* unfitted */\n",
1173
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
1174
+ " color: var(--sklearn-color-background);\n",
1175
+ " text-decoration: none;\n",
1176
+ "}\n",
1177
+ "\n",
1178
+ "#sk-container-id-2 a.estimator_doc_link.fitted:hover {\n",
1179
+ " /* fitted */\n",
1180
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
1181
+ "}\n",
1182
+ "</style><div id=\"sk-container-id-2\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>RandomForestClassifier()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" checked><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier()</pre></div> </div></div></div></div>"
1183
+ ],
1184
+ "text/plain": [
1185
+ "RandomForestClassifier()"
1186
+ ]
1187
+ },
1188
+ "execution_count": 26,
1189
+ "metadata": {},
1190
+ "output_type": "execute_result"
1191
+ }
1192
+ ],
1193
+ "source": [
1194
+ "# random forest classifier\n",
1195
+ "from sklearn.ensemble import RandomForestClassifier\n",
1196
+ "\n",
1197
+ "# Initialize the model\n",
1198
+ "model = RandomForestClassifier()\n",
1199
+ "\n",
1200
+ "# Train the model\n",
1201
+ "model.fit(X_train_tfidf, y_train)\n"
1202
+ ]
1203
+ },
1204
+ {
1205
+ "cell_type": "code",
1206
+ "execution_count": 27,
1207
+ "metadata": {},
1208
+ "outputs": [
1209
+ {
1210
+ "name": "stdout",
1211
+ "output_type": "stream",
1212
+ "text": [
1213
+ "Accuracy: 0.688715953307393\n",
1214
+ " precision recall f1-score support\n",
1215
+ "\n",
1216
+ " 0 0.90 0.50 0.64 768\n",
1217
+ " 1 0.97 0.37 0.53 556\n",
1218
+ " 2 0.55 0.82 0.66 3081\n",
1219
+ " 3 0.79 0.95 0.86 3269\n",
1220
+ " 4 1.00 0.26 0.41 215\n",
1221
+ " 5 0.97 0.21 0.35 517\n",
1222
+ " 6 0.71 0.40 0.52 2131\n",
1223
+ "\n",
1224
+ " accuracy 0.69 10537\n",
1225
+ " macro avg 0.84 0.50 0.57 10537\n",
1226
+ "weighted avg 0.74 0.69 0.67 10537\n",
1227
+ "\n"
1228
+ ]
1229
+ }
1230
+ ],
1231
+ "source": [
1232
+ "from sklearn.metrics import classification_report, accuracy_score\n",
1233
+ "# making predictions\n",
1234
+ "y_pred = model.predict(X_test_tfidf)\n",
1235
+ "\n",
1236
+ "# checking the accuracy\n",
1237
+ "accuracy = accuracy_score(y_test, y_pred)\n",
1238
+ "print('Accuracy:', accuracy)\n",
1239
+ "\n",
1240
+ "# classification report\n",
1241
+ "report = classification_report(y_test, y_pred)\n",
1242
+ "print(report)"
1243
+ ]
1244
+ },
1245
+ {
1246
+ "cell_type": "code",
1247
+ "execution_count": 28,
1248
+ "metadata": {},
1249
+ "outputs": [],
1250
+ "source": [
1251
+ "# creating a pipeline\n",
1252
+ "from sklearn.base import BaseEstimator, TransformerMixin\n",
1253
+ "from sklearn.pipeline import Pipeline\n",
1254
+ "\n",
1255
+ "# Custom transformer for text preprocessing\n",
1256
+ "class TextPreprocessor(BaseEstimator, TransformerMixin):\n",
1257
+ " def __init__(self):\n",
1258
+ " self.stop_words = set(stopwords.words('english'))\n",
1259
+ " self.lemmatizer = WordNetLemmatizer()\n",
1260
+ " \n",
1261
+ " def preprocess_text(self, text):\n",
1262
+ " # Lowercase the text\n",
1263
+ " text = text.lower()\n",
1264
+ " \n",
1265
+ " # Remove punctuation\n",
1266
+ " text = re.sub(f'[{re.escape(string.punctuation)}]', '', text)\n",
1267
+ " \n",
1268
+ " # Remove numbers\n",
1269
+ " text = re.sub(r'\\d+', '', text)\n",
1270
+ " \n",
1271
+ " # Tokenize the text\n",
1272
+ " words = text.split()\n",
1273
+ " \n",
1274
+ " # Remove stopwords and apply lemmatization\n",
1275
+ " words = [self.lemmatizer.lemmatize(word) for word in words if word not in self.stop_words]\n",
1276
+ " \n",
1277
+ " # Join words back into a single string\n",
1278
+ " cleaned_text = ' '.join(words)\n",
1279
+ " \n",
1280
+ " return cleaned_text\n",
1281
+ " \n",
1282
+ " def fit(self, X, y=None):\n",
1283
+ " return self\n",
1284
+ " \n",
1285
+ " def transform(self, X, y=None):\n",
1286
+ " return [self.preprocess_text(text) for text in X]\n",
1287
+ " \n",
1288
+ " \n"
1289
+ ]
1290
+ },
1291
+ {
1292
+ "cell_type": "code",
1293
+ "execution_count": 29,
1294
+ "metadata": {},
1295
+ "outputs": [],
1296
+ "source": [
1297
+ "pipeline = Pipeline([\n",
1298
+ " ('preprocessor', TextPreprocessor()),\n",
1299
+ " ('vectorizer', TfidfVectorizer()),\n",
1300
+ " ('classifier', RandomForestClassifier())\n",
1301
+ "])"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "cell_type": "code",
1306
+ "execution_count": 31,
1307
+ "metadata": {},
1308
+ "outputs": [],
1309
+ "source": [
1310
+ "X = df_1['statement']\n",
1311
+ "y = df_2['status']\n",
1312
+ "\n",
1313
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)"
1314
+ ]
1315
+ },
1316
+ {
1317
+ "cell_type": "code",
1318
+ "execution_count": 32,
1319
+ "metadata": {},
1320
+ "outputs": [
1321
+ {
1322
+ "data": {
1323
+ "text/html": [
1324
+ "<style>#sk-container-id-3 {\n",
1325
+ " /* Definition of color scheme common for light and dark mode */\n",
1326
+ " --sklearn-color-text: black;\n",
1327
+ " --sklearn-color-line: gray;\n",
1328
+ " /* Definition of color scheme for unfitted estimators */\n",
1329
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
1330
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
1331
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
1332
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
1333
+ " /* Definition of color scheme for fitted estimators */\n",
1334
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
1335
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
1336
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
1337
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
1338
+ "\n",
1339
+ " /* Specific color for light theme */\n",
1340
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
1341
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
1342
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
1343
+ " --sklearn-color-icon: #696969;\n",
1344
+ "\n",
1345
+ " @media (prefers-color-scheme: dark) {\n",
1346
+ " /* Redefinition of color scheme for dark theme */\n",
1347
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
1348
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
1349
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
1350
+ " --sklearn-color-icon: #878787;\n",
1351
+ " }\n",
1352
+ "}\n",
1353
+ "\n",
1354
+ "#sk-container-id-3 {\n",
1355
+ " color: var(--sklearn-color-text);\n",
1356
+ "}\n",
1357
+ "\n",
1358
+ "#sk-container-id-3 pre {\n",
1359
+ " padding: 0;\n",
1360
+ "}\n",
1361
+ "\n",
1362
+ "#sk-container-id-3 input.sk-hidden--visually {\n",
1363
+ " border: 0;\n",
1364
+ " clip: rect(1px 1px 1px 1px);\n",
1365
+ " clip: rect(1px, 1px, 1px, 1px);\n",
1366
+ " height: 1px;\n",
1367
+ " margin: -1px;\n",
1368
+ " overflow: hidden;\n",
1369
+ " padding: 0;\n",
1370
+ " position: absolute;\n",
1371
+ " width: 1px;\n",
1372
+ "}\n",
1373
+ "\n",
1374
+ "#sk-container-id-3 div.sk-dashed-wrapped {\n",
1375
+ " border: 1px dashed var(--sklearn-color-line);\n",
1376
+ " margin: 0 0.4em 0.5em 0.4em;\n",
1377
+ " box-sizing: border-box;\n",
1378
+ " padding-bottom: 0.4em;\n",
1379
+ " background-color: var(--sklearn-color-background);\n",
1380
+ "}\n",
1381
+ "\n",
1382
+ "#sk-container-id-3 div.sk-container {\n",
1383
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
1384
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
1385
+ " so we also need the `!important` here to be able to override the\n",
1386
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
1387
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
1388
+ " display: inline-block !important;\n",
1389
+ " position: relative;\n",
1390
+ "}\n",
1391
+ "\n",
1392
+ "#sk-container-id-3 div.sk-text-repr-fallback {\n",
1393
+ " display: none;\n",
1394
+ "}\n",
1395
+ "\n",
1396
+ "div.sk-parallel-item,\n",
1397
+ "div.sk-serial,\n",
1398
+ "div.sk-item {\n",
1399
+ " /* draw centered vertical line to link estimators */\n",
1400
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
1401
+ " background-size: 2px 100%;\n",
1402
+ " background-repeat: no-repeat;\n",
1403
+ " background-position: center center;\n",
1404
+ "}\n",
1405
+ "\n",
1406
+ "/* Parallel-specific style estimator block */\n",
1407
+ "\n",
1408
+ "#sk-container-id-3 div.sk-parallel-item::after {\n",
1409
+ " content: \"\";\n",
1410
+ " width: 100%;\n",
1411
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
1412
+ " flex-grow: 1;\n",
1413
+ "}\n",
1414
+ "\n",
1415
+ "#sk-container-id-3 div.sk-parallel {\n",
1416
+ " display: flex;\n",
1417
+ " align-items: stretch;\n",
1418
+ " justify-content: center;\n",
1419
+ " background-color: var(--sklearn-color-background);\n",
1420
+ " position: relative;\n",
1421
+ "}\n",
1422
+ "\n",
1423
+ "#sk-container-id-3 div.sk-parallel-item {\n",
1424
+ " display: flex;\n",
1425
+ " flex-direction: column;\n",
1426
+ "}\n",
1427
+ "\n",
1428
+ "#sk-container-id-3 div.sk-parallel-item:first-child::after {\n",
1429
+ " align-self: flex-end;\n",
1430
+ " width: 50%;\n",
1431
+ "}\n",
1432
+ "\n",
1433
+ "#sk-container-id-3 div.sk-parallel-item:last-child::after {\n",
1434
+ " align-self: flex-start;\n",
1435
+ " width: 50%;\n",
1436
+ "}\n",
1437
+ "\n",
1438
+ "#sk-container-id-3 div.sk-parallel-item:only-child::after {\n",
1439
+ " width: 0;\n",
1440
+ "}\n",
1441
+ "\n",
1442
+ "/* Serial-specific style estimator block */\n",
1443
+ "\n",
1444
+ "#sk-container-id-3 div.sk-serial {\n",
1445
+ " display: flex;\n",
1446
+ " flex-direction: column;\n",
1447
+ " align-items: center;\n",
1448
+ " background-color: var(--sklearn-color-background);\n",
1449
+ " padding-right: 1em;\n",
1450
+ " padding-left: 1em;\n",
1451
+ "}\n",
1452
+ "\n",
1453
+ "\n",
1454
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
1455
+ "clickable and can be expanded/collapsed.\n",
1456
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
1457
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
1458
+ "*/\n",
1459
+ "\n",
1460
+ "/* Pipeline and ColumnTransformer style (default) */\n",
1461
+ "\n",
1462
+ "#sk-container-id-3 div.sk-toggleable {\n",
1463
+ " /* Default theme specific background. It is overwritten whether we have a\n",
1464
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
1465
+ " background-color: var(--sklearn-color-background);\n",
1466
+ "}\n",
1467
+ "\n",
1468
+ "/* Toggleable label */\n",
1469
+ "#sk-container-id-3 label.sk-toggleable__label {\n",
1470
+ " cursor: pointer;\n",
1471
+ " display: block;\n",
1472
+ " width: 100%;\n",
1473
+ " margin-bottom: 0;\n",
1474
+ " padding: 0.5em;\n",
1475
+ " box-sizing: border-box;\n",
1476
+ " text-align: center;\n",
1477
+ "}\n",
1478
+ "\n",
1479
+ "#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n",
1480
+ " /* Arrow on the left of the label */\n",
1481
+ " content: \"▸\";\n",
1482
+ " float: left;\n",
1483
+ " margin-right: 0.25em;\n",
1484
+ " color: var(--sklearn-color-icon);\n",
1485
+ "}\n",
1486
+ "\n",
1487
+ "#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n",
1488
+ " color: var(--sklearn-color-text);\n",
1489
+ "}\n",
1490
+ "\n",
1491
+ "/* Toggleable content - dropdown */\n",
1492
+ "\n",
1493
+ "#sk-container-id-3 div.sk-toggleable__content {\n",
1494
+ " max-height: 0;\n",
1495
+ " max-width: 0;\n",
1496
+ " overflow: hidden;\n",
1497
+ " text-align: left;\n",
1498
+ " /* unfitted */\n",
1499
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
1500
+ "}\n",
1501
+ "\n",
1502
+ "#sk-container-id-3 div.sk-toggleable__content.fitted {\n",
1503
+ " /* fitted */\n",
1504
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
1505
+ "}\n",
1506
+ "\n",
1507
+ "#sk-container-id-3 div.sk-toggleable__content pre {\n",
1508
+ " margin: 0.2em;\n",
1509
+ " border-radius: 0.25em;\n",
1510
+ " color: var(--sklearn-color-text);\n",
1511
+ " /* unfitted */\n",
1512
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
1513
+ "}\n",
1514
+ "\n",
1515
+ "#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n",
1516
+ " /* unfitted */\n",
1517
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
1518
+ "}\n",
1519
+ "\n",
1520
+ "#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
1521
+ " /* Expand drop-down */\n",
1522
+ " max-height: 200px;\n",
1523
+ " max-width: 100%;\n",
1524
+ " overflow: auto;\n",
1525
+ "}\n",
1526
+ "\n",
1527
+ "#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
1528
+ " content: \"▾\";\n",
1529
+ "}\n",
1530
+ "\n",
1531
+ "/* Pipeline/ColumnTransformer-specific style */\n",
1532
+ "\n",
1533
+ "#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1534
+ " color: var(--sklearn-color-text);\n",
1535
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1536
+ "}\n",
1537
+ "\n",
1538
+ "#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1539
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1540
+ "}\n",
1541
+ "\n",
1542
+ "/* Estimator-specific style */\n",
1543
+ "\n",
1544
+ "/* Colorize estimator box */\n",
1545
+ "#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1546
+ " /* unfitted */\n",
1547
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1548
+ "}\n",
1549
+ "\n",
1550
+ "#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
1551
+ " /* fitted */\n",
1552
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1553
+ "}\n",
1554
+ "\n",
1555
+ "#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n",
1556
+ "#sk-container-id-3 div.sk-label label {\n",
1557
+ " /* The background is the default theme color */\n",
1558
+ " color: var(--sklearn-color-text-on-default-background);\n",
1559
+ "}\n",
1560
+ "\n",
1561
+ "/* On hover, darken the color of the background */\n",
1562
+ "#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n",
1563
+ " color: var(--sklearn-color-text);\n",
1564
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1565
+ "}\n",
1566
+ "\n",
1567
+ "/* Label box, darken color on hover, fitted */\n",
1568
+ "#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
1569
+ " color: var(--sklearn-color-text);\n",
1570
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1571
+ "}\n",
1572
+ "\n",
1573
+ "/* Estimator label */\n",
1574
+ "\n",
1575
+ "#sk-container-id-3 div.sk-label label {\n",
1576
+ " font-family: monospace;\n",
1577
+ " font-weight: bold;\n",
1578
+ " display: inline-block;\n",
1579
+ " line-height: 1.2em;\n",
1580
+ "}\n",
1581
+ "\n",
1582
+ "#sk-container-id-3 div.sk-label-container {\n",
1583
+ " text-align: center;\n",
1584
+ "}\n",
1585
+ "\n",
1586
+ "/* Estimator-specific */\n",
1587
+ "#sk-container-id-3 div.sk-estimator {\n",
1588
+ " font-family: monospace;\n",
1589
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
1590
+ " border-radius: 0.25em;\n",
1591
+ " box-sizing: border-box;\n",
1592
+ " margin-bottom: 0.5em;\n",
1593
+ " /* unfitted */\n",
1594
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
1595
+ "}\n",
1596
+ "\n",
1597
+ "#sk-container-id-3 div.sk-estimator.fitted {\n",
1598
+ " /* fitted */\n",
1599
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
1600
+ "}\n",
1601
+ "\n",
1602
+ "/* on hover */\n",
1603
+ "#sk-container-id-3 div.sk-estimator:hover {\n",
1604
+ " /* unfitted */\n",
1605
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
1606
+ "}\n",
1607
+ "\n",
1608
+ "#sk-container-id-3 div.sk-estimator.fitted:hover {\n",
1609
+ " /* fitted */\n",
1610
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
1611
+ "}\n",
1612
+ "\n",
1613
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
1614
+ "\n",
1615
+ "/* Common style for \"i\" and \"?\" */\n",
1616
+ "\n",
1617
+ ".sk-estimator-doc-link,\n",
1618
+ "a:link.sk-estimator-doc-link,\n",
1619
+ "a:visited.sk-estimator-doc-link {\n",
1620
+ " float: right;\n",
1621
+ " font-size: smaller;\n",
1622
+ " line-height: 1em;\n",
1623
+ " font-family: monospace;\n",
1624
+ " background-color: var(--sklearn-color-background);\n",
1625
+ " border-radius: 1em;\n",
1626
+ " height: 1em;\n",
1627
+ " width: 1em;\n",
1628
+ " text-decoration: none !important;\n",
1629
+ " margin-left: 1ex;\n",
1630
+ " /* unfitted */\n",
1631
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
1632
+ " color: var(--sklearn-color-unfitted-level-1);\n",
1633
+ "}\n",
1634
+ "\n",
1635
+ ".sk-estimator-doc-link.fitted,\n",
1636
+ "a:link.sk-estimator-doc-link.fitted,\n",
1637
+ "a:visited.sk-estimator-doc-link.fitted {\n",
1638
+ " /* fitted */\n",
1639
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
1640
+ " color: var(--sklearn-color-fitted-level-1);\n",
1641
+ "}\n",
1642
+ "\n",
1643
+ "/* On hover */\n",
1644
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
1645
+ ".sk-estimator-doc-link:hover,\n",
1646
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
1647
+ ".sk-estimator-doc-link:hover {\n",
1648
+ " /* unfitted */\n",
1649
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
1650
+ " color: var(--sklearn-color-background);\n",
1651
+ " text-decoration: none;\n",
1652
+ "}\n",
1653
+ "\n",
1654
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
1655
+ ".sk-estimator-doc-link.fitted:hover,\n",
1656
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
1657
+ ".sk-estimator-doc-link.fitted:hover {\n",
1658
+ " /* fitted */\n",
1659
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
1660
+ " color: var(--sklearn-color-background);\n",
1661
+ " text-decoration: none;\n",
1662
+ "}\n",
1663
+ "\n",
1664
+ "/* Span, style for the box shown on hovering the info icon */\n",
1665
+ ".sk-estimator-doc-link span {\n",
1666
+ " display: none;\n",
1667
+ " z-index: 9999;\n",
1668
+ " position: relative;\n",
1669
+ " font-weight: normal;\n",
1670
+ " right: .2ex;\n",
1671
+ " padding: .5ex;\n",
1672
+ " margin: .5ex;\n",
1673
+ " width: min-content;\n",
1674
+ " min-width: 20ex;\n",
1675
+ " max-width: 50ex;\n",
1676
+ " color: var(--sklearn-color-text);\n",
1677
+ " box-shadow: 2pt 2pt 4pt #999;\n",
1678
+ " /* unfitted */\n",
1679
+ " background: var(--sklearn-color-unfitted-level-0);\n",
1680
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
1681
+ "}\n",
1682
+ "\n",
1683
+ ".sk-estimator-doc-link.fitted span {\n",
1684
+ " /* fitted */\n",
1685
+ " background: var(--sklearn-color-fitted-level-0);\n",
1686
+ " border: var(--sklearn-color-fitted-level-3);\n",
1687
+ "}\n",
1688
+ "\n",
1689
+ ".sk-estimator-doc-link:hover span {\n",
1690
+ " display: block;\n",
1691
+ "}\n",
1692
+ "\n",
1693
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
1694
+ "\n",
1695
+ "#sk-container-id-3 a.estimator_doc_link {\n",
1696
+ " float: right;\n",
1697
+ " font-size: 1rem;\n",
1698
+ " line-height: 1em;\n",
1699
+ " font-family: monospace;\n",
1700
+ " background-color: var(--sklearn-color-background);\n",
1701
+ " border-radius: 1rem;\n",
1702
+ " height: 1rem;\n",
1703
+ " width: 1rem;\n",
1704
+ " text-decoration: none;\n",
1705
+ " /* unfitted */\n",
1706
+ " color: var(--sklearn-color-unfitted-level-1);\n",
1707
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
1708
+ "}\n",
1709
+ "\n",
1710
+ "#sk-container-id-3 a.estimator_doc_link.fitted {\n",
1711
+ " /* fitted */\n",
1712
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
1713
+ " color: var(--sklearn-color-fitted-level-1);\n",
1714
+ "}\n",
1715
+ "\n",
1716
+ "/* On hover */\n",
1717
+ "#sk-container-id-3 a.estimator_doc_link:hover {\n",
1718
+ " /* unfitted */\n",
1719
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
1720
+ " color: var(--sklearn-color-background);\n",
1721
+ " text-decoration: none;\n",
1722
+ "}\n",
1723
+ "\n",
1724
+ "#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n",
1725
+ " /* fitted */\n",
1726
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
1727
+ "}\n",
1728
+ "</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>Pipeline(steps=[(&#x27;preprocessor&#x27;, TextPreprocessor()),\n",
1729
+ " (&#x27;vectorizer&#x27;, TfidfVectorizer()),\n",
1730
+ " (&#x27;classifier&#x27;, RandomForestClassifier())])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;Pipeline<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.pipeline.Pipeline.html\">?<span>Documentation for Pipeline</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[(&#x27;preprocessor&#x27;, TextPreprocessor()),\n",
1731
+ " (&#x27;vectorizer&#x27;, TfidfVectorizer()),\n",
1732
+ " (&#x27;classifier&#x27;, RandomForestClassifier())])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">TextPreprocessor</label><div class=\"sk-toggleable__content fitted\"><pre>TextPreprocessor()</pre></div> </div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;TfidfVectorizer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html\">?<span>Documentation for TfidfVectorizer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>TfidfVectorizer()</pre></div> </div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;RandomForestClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.ensemble.RandomForestClassifier.html\">?<span>Documentation for RandomForestClassifier</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>RandomForestClassifier()</pre></div> </div></div></div></div></div></div>"
1733
+ ],
1734
+ "text/plain": [
1735
+ "Pipeline(steps=[('preprocessor', TextPreprocessor()),\n",
1736
+ " ('vectorizer', TfidfVectorizer()),\n",
1737
+ " ('classifier', RandomForestClassifier())])"
1738
+ ]
1739
+ },
1740
+ "execution_count": 32,
1741
+ "metadata": {},
1742
+ "output_type": "execute_result"
1743
+ }
1744
+ ],
1745
+ "source": [
1746
+ "# Train the model\n",
1747
+ "pipeline.fit(X_train, y_train)"
1748
+ ]
1749
+ },
1750
+ {
1751
+ "cell_type": "code",
1752
+ "execution_count": 33,
1753
+ "metadata": {},
1754
+ "outputs": [],
1755
+ "source": [
1756
+ "# Make predictions\n",
1757
+ "y_pred = pipeline.predict(X_test)"
1758
+ ]
1759
+ },
1760
+ {
1761
+ "cell_type": "code",
1762
+ "execution_count": 34,
1763
+ "metadata": {},
1764
+ "outputs": [
1765
+ {
1766
+ "name": "stdout",
1767
+ "output_type": "stream",
1768
+ "text": [
1769
+ "Accuracy: 0.6797950080668121\n",
1770
+ "Classification Report:\n",
1771
+ " precision recall f1-score support\n",
1772
+ "\n",
1773
+ " 0 0.89 0.49 0.63 768\n",
1774
+ " 1 0.98 0.36 0.52 556\n",
1775
+ " 2 0.54 0.82 0.65 3081\n",
1776
+ " 3 0.79 0.95 0.86 3269\n",
1777
+ " 4 1.00 0.26 0.41 215\n",
1778
+ " 5 0.97 0.21 0.34 517\n",
1779
+ " 6 0.69 0.38 0.49 2131\n",
1780
+ "\n",
1781
+ " accuracy 0.68 10537\n",
1782
+ " macro avg 0.84 0.49 0.56 10537\n",
1783
+ "weighted avg 0.73 0.68 0.66 10537\n",
1784
+ "\n"
1785
+ ]
1786
+ }
1787
+ ],
1788
+ "source": [
1789
+ "# Evaluate the model\n",
1790
+ "accuracy = accuracy_score(y_test, y_pred)\n",
1791
+ "report = classification_report(y_test, y_pred)\n",
1792
+ "\n",
1793
+ "print(f'Accuracy: {accuracy}')\n",
1794
+ "print('Classification Report:')\n",
1795
+ "print(report)"
1796
+ ]
1797
+ },
1798
+ {
1799
+ "cell_type": "code",
1800
+ "execution_count": null,
1801
+ "metadata": {},
1802
+ "outputs": [],
1803
+ "source": []
1804
+ },
1805
+ {
1806
+ "cell_type": "code",
1807
+ "execution_count": null,
1808
+ "metadata": {},
1809
+ "outputs": [],
1810
+ "source": []
1811
+ },
1812
+ {
1813
+ "cell_type": "code",
1814
+ "execution_count": 10,
1815
+ "metadata": {},
1816
+ "outputs": [
1817
+ {
1818
+ "name": "stdout",
1819
+ "output_type": "stream",
1820
+ "text": [
1821
+ "{'text': 'A lot of times if I am feeling sad, I immediately think of how others will respond to it. Or I am looking for comfort.. my father is a homophobic, racist, sexist piece of shit and my mother takes care of everything in the house. I hate my dad, when he started saying things like \"there is only two genders\" and \"you are looking for attention\" and making things seem like I was in the wrong no matter how much I was right, I realized how much of a shitbag he was and really felt desperate. I felt desperate for love and so I am confusing that with wanting attention.. am I in the wrong for doing this? Am I depressed or wanting attention?', 'prediction': 'Depression'}\n"
1822
+ ]
1823
+ }
1824
+ ],
1825
+ "source": [
1826
+ "import requests\n",
1827
+ "text = 'A lot of times if I am feeling sad, I immediately think of how others will respond to it. Or I am looking for comfort.. my father is a homophobic, racist, sexist piece of shit and my mother takes care of everything in the house. I hate my dad, when he started saying things like \"there is only two genders\" and \"you are looking for attention\" and making things seem like I was in the wrong no matter how much I was right, I realized how much of a shitbag he was and really felt desperate. I felt desperate for love and so I am confusing that with wanting attention.. am I in the wrong for doing this? Am I depressed or wanting attention?'\n",
1828
+ "url = \"http://127.0.0.1:8000/predict_sentiment\"\n",
1829
+ "data = {\"text\": text}\n",
1830
+ "response = requests.post(url, json=data)\n",
1831
+ "\n",
1832
+ "print(response.json())\n"
1833
+ ]
1834
+ },
1835
+ {
1836
+ "cell_type": "code",
1837
+ "execution_count": null,
1838
+ "metadata": {},
1839
+ "outputs": [],
1840
+ "source": []
1841
+ },
1842
+ {
1843
+ "cell_type": "code",
1844
+ "execution_count": null,
1845
+ "metadata": {},
1846
+ "outputs": [],
1847
+ "source": []
1848
+ }
1849
+ ],
1850
+ "metadata": {
1851
+ "kernelspec": {
1852
+ "display_name": "Python 3",
1853
+ "language": "python",
1854
+ "name": "python3"
1855
+ },
1856
+ "language_info": {
1857
+ "codemirror_mode": {
1858
+ "name": "ipython",
1859
+ "version": 3
1860
+ },
1861
+ "file_extension": ".py",
1862
+ "mimetype": "text/x-python",
1863
+ "name": "python",
1864
+ "nbconvert_exporter": "python",
1865
+ "pygments_lexer": "ipython3",
1866
+ "version": "3.10.14"
1867
+ }
1868
+ },
1869
+ "nbformat": 4,
1870
+ "nbformat_minor": 2
1871
+ }
fastapi_app/__init__.py ADDED
File without changes
fastapi_app/main.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
+ from fastapi.responses import HTMLResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from pydantic import BaseModel
5
+ import uvicorn
6
+ import os, sys
7
+
8
+ # Add the root directory to sys.path
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
10
+ from model_pipeline.model_predict import load_model, predict as initial_predict
11
+ from llama_pipeline.llama_predict import predict as llama_predict
12
+ from db_connection import insert_db
13
+ from logging_config.logger_config import get_logger
14
+
15
+ # Initialize the FastAPI app
16
+ app = FastAPI()
17
+
18
+ # Initialize the logger
19
+ logger = get_logger(__name__)
20
+
21
+ # Load the latest model at startup
22
+ model = load_model()
23
+
24
+ # Mount the static files directory
25
+ app.mount("/static", StaticFiles(directory="fastapi_app/static"), name="static")
26
+
27
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve the static landing page.

    Reads fastapi_app/static/index.html from disk on each request and
    returns it as an HTML response.
    """
    # Explicit encoding avoids platform-dependent default decoding of the file.
    with open("fastapi_app/static/index.html", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content, status_code=200)
32
+
33
@app.get("/health")
def health_check():
    """Liveness endpoint: log the access and report that the service is up."""
    logger.info("Health check endpoint accessed.")
    return {"status": "ok"}
37
+
38
class TextInput(BaseModel):
    """Request body for /predict_sentiment: the raw text to classify."""
    text: str

class PredictionInput(BaseModel):
    """Request body for /submit_interaction: one prediction round-trip plus feedback."""
    text: str                 # original input text
    initial_prediction: str   # label produced by the local model pipeline
    llama_category: str       # category returned by the LLaMA pipeline
    llama_explanation: str    # LLaMA's free-text justification
    user_rating: int          # star rating from the UI (1-5; not validated here)
47
+
48
@app.post("/predict_sentiment")
def predict_sentiment(input_data: TextInput):
    """Run both classifiers on the submitted text and return their results.

    Combines the local model's label with the LLaMA category/explanation
    in a single JSON response.
    """
    logger.info(f"Prediction request received with text: {input_data.text}")

    # Initial model prediction
    initial_prediction = initial_predict(input_data.text, model=model)

    # LLaMA 3 prediction
    llama_prediction = llama_predict(input_data.text)

    # BUG FIX: llama_predict returns an error *string* when the upstream call
    # fails; subscripting that string with 'Category' raised TypeError and
    # turned a degraded response into a 500. Degrade gracefully instead.
    if isinstance(llama_prediction, dict):
        llama_category = llama_prediction.get('Category', 'Unknown')
        llama_explanation = llama_prediction.get('Explanation', '')
    else:
        llama_category = 'Unavailable'
        llama_explanation = str(llama_prediction)

    # Prepare response
    response = {
        "text": input_data.text,
        "initial_prediction": initial_prediction,
        "llama_category": llama_category,
        "llama_explanation": llama_explanation
    }

    logger.info(f"Prediction response: {response}")
    return response
68
+
69
@app.post("/submit_interaction")
def submit_interaction(data: PredictionInput):
    """Persist one prediction round-trip plus the user's rating.

    Logs every field, maps the payload onto the schema expected by
    insert_db, and returns the DB layer's response.
    """
    logger.info(f"Received interaction data: {data}")
    logger.info(f"Received text: {data.text}")
    logger.info(f"Received initial_prediction: {data.initial_prediction}")
    logger.info(f"Received llama_category: {data.llama_category}")
    logger.info(f"Received llama_explanation: {data.llama_explanation}")
    logger.info(f"Received user_rating: {data.user_rating}")

    interaction_data = {
        "Input_text": data.text,
        "Model_prediction": data.initial_prediction,
        "Llama_3_Prediction": data.llama_category,
        "Llama_3_Explanation": data.llama_explanation,
        # NOTE(review): this key contains a space, unlike the others —
        # presumably it matches a DB column name; confirm against insert_db.
        "User Rating": data.user_rating,
    }

    response = insert_db(interaction_data)
    logger.info(f"Database response: {response}")
    return {"status": "success", "response": response}
89
+
90
if __name__ == "__main__":
    # Dev convenience entry point; the Docker image launches the app via
    # entrypoint.sh instead (see Dockerfile CMD).
    uvicorn.run(app, host="0.0.0.0", port=8000)
fastapi_app/static/index.html ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<!-- Landing page for the sentiment analysis demo.
     Backend interaction (predictSentiment / submitInteraction) lives in
     /static/script.js; styling in /static/style.css. -->
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Sentiment Analysis for Mental Health</title>
    <link rel="stylesheet" href="/static/style.css">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css">
</head>
<body>
    <div class="container animate__animated animate__fadeIn">
        <h1>Sentiment Analysis for Mental Health</h1>
        <p>This application uses a machine learning model to analyze the sentiment of text data related to mental health. It helps in understanding the sentiment expressed in user-generated content, such as social media posts or survey responses.</p>
        <p>Enter a sentence below to predict its sentiment as Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, or Personality Disorder.</p>

        <textarea id="textInput" rows="4" placeholder="Enter your text here..." class="animate__animated animate__fadeInLeft"></textarea>
        <button onclick="predictSentiment()" class="animate__animated animate__fadeInRight">Predict Sentiment</button>

        <!-- Filled in by predictSentiment() after the backend responds -->
        <div id="results" class="animate__animated animate__fadeInUp">
            <h2>Results:</h2>
            <p><strong>Initial Model Prediction:</strong> <span id="initialPrediction"></span></p>
            <p><strong>LLaMA 3 Category:</strong> <span id="llamaCategory"></span></p>
            <p><strong>LLaMA 3 Explanation:</strong> <span id="llamaExplanation"></span></p>

            <h2>Rate the Accuracy of the Prediction:</h2>
            <div class="rating animate__animated animate__fadeIn">
                <i class="fas fa-star" onclick="rate(1)"></i>
                <i class="fas fa-star" onclick="rate(2)"></i>
                <i class="fas fa-star" onclick="rate(3)"></i>
                <i class="fas fa-star" onclick="rate(4)"></i>
                <i class="fas fa-star" onclick="rate(5)"></i>
            </div>
            <!-- Hidden field read by submitInteraction(); set by rate() -->
            <input type="hidden" id="userRating" value="0">
        </div>

        <button onclick="submitInteraction()" class="animate__animated animate__fadeIn">Submit Rating</button>
    </div>
    <script src="/static/script.js"></script>
</body>
</html>
fastapi_app/static/script.js ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Send the user's text to the backend and render both predictions.
async function predictSentiment() {
    const textInput = document.getElementById("textInput").value;
    try {
        const response = await fetch("/predict_sentiment", {
            method: "POST",
            headers: {
                "Content-Type": "application/json"
            },
            body: JSON.stringify({ text: textInput })
        });
        // BUG FIX: without this check a server error rendered "undefined"
        // into the result spans instead of surfacing the failure.
        if (!response.ok) {
            throw new Error(`Request failed with status ${response.status}`);
        }
        const result = await response.json();
        document.getElementById("initialPrediction").innerText = result.initial_prediction;
        document.getElementById("llamaCategory").innerText = result.llama_category;
        document.getElementById("llamaExplanation").innerText = result.llama_explanation;
    } catch (err) {
        console.error(err);
        alert("Prediction failed. Please try again.");
    }
}
15
+
16
// Store the chosen star count and highlight every star up to that count.
function rate(rating) {
    document.getElementById("userRating").value = rating;
    const stars = document.querySelectorAll(".rating .fa-star");
    let position = 0;
    for (const star of stars) {
        if (position < rating) {
            star.classList.add("selected");
        } else {
            star.classList.remove("selected");
        }
        position += 1;
    }
}
23
+
24
// Collect the displayed prediction round-trip plus the user's rating and
// POST it to the backend for persistence.
async function submitInteraction() {
    const data = {
        text: document.getElementById("textInput").value,
        initial_prediction: document.getElementById("initialPrediction").innerText,
        llama_category: document.getElementById("llamaCategory").innerText,
        llama_explanation: document.getElementById("llamaExplanation").innerText,
        // Radix 10 guards against legacy octal parsing of leading zeros.
        user_rating: parseInt(document.getElementById("userRating").value, 10),
    };

    // display the data in the console
    console.log(data);

    try {
        const response = await fetch("/submit_interaction", {
            method: "POST",
            headers: {
                "Content-Type": "application/json"
            },
            body: JSON.stringify(data)
        });
        // BUG FIX: the old code thanked the user even when the request
        // failed, and bound the parsed body to an unused variable.
        if (!response.ok) {
            throw new Error(`Request failed with status ${response.status}`);
        }
        await response.json();
        alert("Thank you for your feedback!");
    } catch (err) {
        console.error(err);
        alert("Submitting your rating failed. Please try again.");
    }
}
53
+
fastapi_app/static/style.css ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: Arial, sans-serif;
3
+ background: linear-gradient(to right, #6a11cb, #2575fc);
4
+ display: flex;
5
+ justify-content: center;
6
+ align-items: center;
7
+ height: 100vh;
8
+ margin: 0;
9
+ color: #fff;
10
+ text-align: center;
11
+ }
12
+
13
+ .container {
14
+ background-color: rgba(255, 255, 255, 0.1);
15
+ padding: 20px;
16
+ border-radius: 10px;
17
+ box-shadow: 0 0 20px rgba(0, 0, 0, 0.2);
18
+ width: 80%;
19
+ max-width: 600px;
20
+ }
21
+
22
+ h1 {
23
+ color: #fff;
24
+ }
25
+
26
+ textarea {
27
+ width: 100%;
28
+ padding: 10px;
29
+ margin: 10px 0;
30
+ border-radius: 8px;
31
+ border: none;
32
+ outline: none;
33
+ background-color: rgba(255, 255, 255, 0.2);
34
+ color: #fff;
35
+ }
36
+
37
+ button {
38
+ padding: 10px 20px;
39
+ background-color: #28a745;
40
+ color: white;
41
+ border: none;
42
+ border-radius: 8px;
43
+ cursor: pointer;
44
+ transition: background-color 0.3s ease;
45
+ margin-top: 20px;
46
+ }
47
+
48
+ button:hover {
49
+ background-color: #218838;
50
+ }
51
+
52
+ #results {
53
+ margin-top: 20px;
54
+ text-align: left;
55
+ color: #fff;
56
+ }
57
+
58
+ .rating {
59
+ display: flex;
60
+ justify-content: center;
61
+ align-items: center;
62
+ font-size: 2em;
63
+ }
64
+
65
+ .rating .fa-star {
66
+ cursor: pointer;
67
+ color: #ccc;
68
+ transition: color 0.3s;
69
+ }
70
+
71
+ .rating .fa-star:hover,
72
+ .rating .fa-star:hover ~ .fa-star {
73
+ color: #ffd700;
74
+ }
75
+
76
+ .rating .fa-star.selected {
77
+ color: #ffd700;
78
+ }
image.png ADDED
llama_pipeline/__pycache__/llama_predict.cpython-310.pyc ADDED
Binary file (3.93 kB). View file
 
llama_pipeline/llama_predict.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os, sys
3
+ from langchain_groq import ChatGroq
4
+ from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
5
+ from langchain_core.prompts.prompt import PromptTemplate
6
+
7
+ # Add the root directory to sys.path
8
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9
+ from logging_config.logger_config import get_logger
10
+
11
+ # Get the logger
12
+ logger = get_logger(__name__)
13
+
14
# environment variables
load_dotenv()
groq_api_key=os.getenv('GROQ_API_KEY')  # required; ChatGroq calls fail without it

# initialize the ChatGroq object
# Module-level singleton: one client instance reused by every prediction call.
llm=ChatGroq(groq_api_key=groq_api_key,
             model_name="Llama3-8b-8192")
21
+
22
# Sentiment Classification
def sentiment_analyzer(input_text: str) -> dict:
    """Classify *input_text* into one of seven mental-health categories.

    Sends a few-shot prompt to the Groq-hosted LLaMA 3 model and parses the
    reply with JsonOutputParser, so the return value is a dict (expected keys:
    'Category' and 'Explanation', per the prompt's instruction). Raises if the
    API call fails or the model's output is not valid JSON.
    """
    template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a highly specialized AI trained in clinical psychology and mental health assessment. Your task is to analyze textual input and categorize it into one of the following mental health conditions:
    - Normal
    - Depression
    - Suicidal
    - Anxiety
    - Stress
    - Bi-Polar
    - Personality Disorder

    Your analysis should be based on clinical symptoms and diagnostic criteria commonly used in mental health practice. Here are some detailed examples:

    Example 1:
    Text: "I feel an overwhelming sense of sadness and hopelessness. I have lost interest in activities I once enjoyed and find it hard to get out of bed."
    Category: Depression

    Example 2:
    Text: "I constantly worry about various aspects of my life. My heart races, and I experience physical symptoms like sweating and trembling even when there is no apparent danger."
    Category: Anxiety

    Example 3:
    Text: "I have thoughts about ending my life. I feel that there is no other way to escape my pain, and I often think about how I might end it."
    Category: Suicidal

    Example 4:
    Text: "I feel extremely stressed and overwhelmed by my responsibilities. I find it difficult to relax, and I often experience headaches and tension."
    Category: Stress

    Example 5:
    Text: "I go through periods of extreme happiness and high energy, followed by episodes of deep depression and low energy. These mood swings affect my daily functioning."
    Category: Bi-Polar

    Example 6:
    Text: "I have trouble maintaining stable relationships and often experience intense emotional reactions. My self-image frequently changes, and I engage in impulsive behaviors."
    Category: Personality Disorder

    Example 7:
    Text: "I feel generally content and am able to manage my daily activities without significant distress or impairment."
    Category: Normal


    Return as out the Category and a brief explanation of your decision in a json format.

    Now, analyze the following text and determine the most appropriate category from the list above:
    <|eot_id|><|start_header_id|>user<|end_header_id|>
    Human: {input_text}
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    AI Assistant:"""

    sentiment_prompt = PromptTemplate(input_variables=["input_text"], template=template)
    # LCEL chain: prompt -> Groq LLM -> JSON parsing of the model's reply.
    initiator_router = sentiment_prompt | llm | JsonOutputParser()
    output = initiator_router.invoke({"input_text":input_text})
    return output
77
+
78
+
79
# making predictions
def predict(text: str) -> dict:
    """Classify *text* with the LLaMA sentiment analyzer.

    Returns the parsed JSON dict from sentiment_analyzer (keys 'Category'
    and 'Explanation'). On failure it returns a dict with the same keys so
    callers that subscript the result (e.g. the /predict_sentiment endpoint)
    keep working instead of crashing on a bare error string.
    """
    try:
        logger.info("Making prediction...")
        prediction = sentiment_analyzer(text)
        logger.info(f"Prediction: {prediction}")
        return prediction
    except Exception as e:
        logger.error(f"An error occurred while making the prediction: {e}")
        # BUG FIX: the previous code returned a plain string here, which made
        # downstream `result['Category']` accesses raise TypeError.
        return {
            "Category": "Error",
            "Explanation": "The prediction could not be made due to an error. Please try again later.",
        }
89
+
90
if __name__ == "__main__":
    # Manual smoke test: classify one sample sentence and print the result.
    # Example text input
    example_text = "I feel incredibly anxious about everything and can't stop worrying"

    # Make a prediction
    prediction = predict(example_text)
    print(prediction)
logging_config/__init__.py ADDED
File without changes
logging_config/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
logging_config/__pycache__/logger_config.cpython-310.pyc ADDED
Binary file (913 Bytes). View file
 
logging_config/logger_config.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
# Ensure the log directory exists
log_directory = 'logs'
os.makedirs(log_directory, exist_ok=True)

# Define the logging configuration
# Root-logger config: INFO+ records from any logger propagate into logs/app.log.
logging.basicConfig(
    filename=os.path.join(log_directory, 'app.log'),
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
15
+
16
# Get a custom logger
def get_logger(name):
    """Return a logger named *name* with DEBUG-level file and console handlers.

    Handlers are attached only once per logger name, so repeated calls with
    the same name are safe (no duplicate log lines).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # BUG FIX: hasHandlers() also reports handlers inherited from the root
    # logger (configured via basicConfig at import time), so it was always
    # True and the dedicated file/console handlers were never attached.
    # Inspect this logger's own handler list instead.
    if not logger.handlers:
        log_path = os.path.join('logs', 'app.log')
        # Make the function self-sufficient if called before the module-level
        # makedirs has run in this working directory.
        os.makedirs(os.path.dirname(log_path), exist_ok=True)

        # Create a file handler
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.DEBUG)

        # Create a console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)

        # Create a logging format
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)

        # Add the handlers to the logger
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger
logs/app.log ADDED
The diff for this file is too large to render. See raw diff
 
model_pipeline/__init__.py ADDED
File without changes
model_pipeline/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
model_pipeline/__pycache__/model_predict.cpython-310.pyc ADDED
Binary file (2.98 kB). View file
 
model_pipeline/model_predict.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import string
5
+ import joblib
6
+ import pandas as pd
7
+ import numpy as np
8
+ from nltk.corpus import stopwords
9
+ from nltk.stem import WordNetLemmatizer
10
+ import nltk
11
+ from glob import glob
12
+
13
+ # Add the root directory to sys.path
14
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
15
+ from logging_config.logger_config import get_logger
16
+
17
+ # Download necessary NLTK data files
18
+ nltk.download('stopwords')
19
+ nltk.download('wordnet')
20
+
21
+ # Get the logger
22
+ logger = get_logger(__name__)
23
+
24
+ # Custom Preprocessor Class
25
class TextPreprocessor:
    """Cleans a raw string for the sentiment model, logging every stage.

    Steps: lowercase, strip punctuation, strip digits, tokenize on
    whitespace, drop English stopwords, lemmatize, and re-join.
    """

    def __init__(self):
        # Shared resources: English stopword set and a WordNet lemmatizer.
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()
        logger.info("TextPreprocessor initialized.")

    def preprocess_text(self, text):
        """Return the cleaned, space-joined token string for *text*."""
        logger.info(f"Original text: {text}")

        lowered = text.lower()
        logger.info(f"Lowercased text: {lowered}")

        no_punct = re.sub(f'[{re.escape(string.punctuation)}]', '', lowered)
        logger.info(f"Text after punctuation removal: {no_punct}")

        no_digits = re.sub(r'\d+', '', no_punct)
        logger.info(f"Text after number removal: {no_digits}")

        tokens = no_digits.split()
        logger.info(f"Tokenized text: {tokens}")

        kept = [self.lemmatizer.lemmatize(tok) for tok in tokens if tok not in self.stop_words]
        logger.info(f"Text after stopword removal and lemmatization: {kept}")

        cleaned = ' '.join(kept)
        logger.info(f"Cleaned text: {cleaned}")

        return cleaned
58
+
59
def get_latest_model_path(models_dir='./models'):
    """Return the path of the most recently created ``model_v*.joblib`` in *models_dir*.

    Raises FileNotFoundError when no versioned model file exists.
    """
    candidates = glob(os.path.join(models_dir, 'model_v*.joblib'))
    if not candidates:
        logger.error("No model files found in the models directory.")
        raise FileNotFoundError("No model files found in the models directory.")

    # "Latest" is decided by filesystem creation time.
    newest = max(candidates, key=os.path.getctime)
    logger.info(f"Latest model file found: {newest}")
    return newest
68
+
69
def load_model():
    """Deserialize and return the newest versioned model pipeline from ./models."""
    path = get_latest_model_path()
    logger.info(f"Loading model from {path}")
    model = joblib.load(path)
    return model
73
+
74
def predict(text, model):
    """Classify *text* with *model* and return the single predicted label.

    The input is cleaned with TextPreprocessor before being handed to the
    model, which must expose a scikit-learn style ``predict`` method.
    """
    # Fresh preprocessor per call; it holds no per-text state.
    preprocessor = TextPreprocessor()

    logger.info("Preprocessing input text...")
    cleaned = preprocessor.preprocess_text(text)

    # The model expects an iterable of documents, hence the one-element list.
    logger.info("Making prediction...")
    labels = model.predict([cleaned])

    logger.info(f"Prediction: {labels}")
    return labels[0]
88
+
89
if __name__ == "__main__":
    # Demo entry point: classify a sample sentence with the newest saved model.
    sample = "I love programming in Python."

    pipeline = load_model()
    result = predict(sample, pipeline)

    print(f"Prediction: {result}")
model_pipeline/model_trainer.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import pandas as pd
4
+ import joblib
5
+ from datetime import datetime
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.linear_model import LogisticRegression
9
+ from sklearn.pipeline import Pipeline
10
+ from sklearn.metrics import classification_report, accuracy_score
11
+
12
+ # Add the root directory to sys.path
13
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
14
+
15
+ from logging_config.logger_config import get_logger
16
+
17
+ # Get the logger
18
+ logger = get_logger(__name__)
19
+
20
def load_data(file_path):
    """Read the training CSV at *file_path* into a pandas DataFrame."""
    logger.info(f"Loading data from {file_path}")
    frame = pd.read_csv(file_path)
    return frame
23
+
24
def train_model(data):
    """Train a TF-IDF + logistic-regression sentiment classifier.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a 'cleaned_statement' text column and a 'status'
        target column.

    Returns
    -------
    tuple
        (fitted sklearn Pipeline, accuracy on the held-out split,
        classification report string).
    """
    logger.info("Starting model training...")

    # Rows with missing values cannot be vectorized; drop them up front.
    # FIX: this is expected data noise that the function handles itself,
    # so it is logged as a warning rather than an error.
    if data.isnull().sum().sum() > 0:
        logger.warning("Missing values found in the dataset.")
        data.dropna(inplace=True)
        logger.info("Missing values dropped.")
    logger.info(f"Data shape: {data.shape}")

    # Split data into features and labels.
    # FIX: the old comment claimed 'sentiment' was the target; it is 'status'.
    X = data['cleaned_statement']
    y = data['status']

    # Hold out 20% of the rows for evaluation; fixed seed for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # TF-IDF features feeding a logistic-regression classifier.
    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer()),
        ('clf', LogisticRegression())
    ])

    # Train the pipeline end to end.
    pipeline.fit(X_train, y_train)
    logger.info("Model training completed.")

    # Evaluate on the held-out split.
    y_pred = pipeline.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    report = classification_report(y_test, y_pred)

    logger.info(f"Accuracy: {accuracy}")
    logger.info(f"Classification Report:\n{report}")

    return pipeline, accuracy, report
63
+
64
def save_model(pipeline, version):
    """Persist *pipeline* to models/model_v<version>.joblib, creating the directory if needed."""
    os.makedirs('./models', exist_ok=True)

    # The version tag in the filename lets multiple model snapshots coexist.
    model_path = os.path.join('models', f'model_v{version}.joblib')
    joblib.dump(pipeline, model_path)
    logger.info(f"Model saved as {model_path}")
73
+
74
if __name__ == "__main__":
    # Train on the cleaned dataset and persist a timestamp-versioned model.
    cleaned_data_path = os.path.join('./data', 'cleaned_data.csv')
    data = load_data(cleaned_data_path)

    pipeline, accuracy, report = train_model(data)

    # Version stamp: current date-time down to the second.
    version = datetime.now().strftime("%Y%m%d%H%M%S")
    save_model(pipeline, version)

    # Summarize the run on stdout.
    print(f"Model version: {version}")
    print(f"Accuracy: {accuracy}")
    print(f"Classification Report:\n{report}")
models/model_v20240717014315.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ac9de0ec670007da30e6e8ef8433064f1f347e1e94ef861d4fdaa871cd310d5
3
+ size 5326113
new_experiement.ipynb ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from dotenv import load_dotenv\n",
10
+ "import os\n",
11
+ "from langchain_groq import ChatGroq\n",
12
+ "from langchain_core.output_parsers import StrOutputParser\n",
13
+ "from langchain_core.prompts.prompt import PromptTemplate"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 3,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "groq_api_key=os.getenv('GROQ_API_KEY')"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 4,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "llm=ChatGroq(groq_api_key=groq_api_key,\n",
32
+ " model_name=\"Llama3-8b-8192\")"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 7,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "def sentiment_analyzer(input_text: str) -> list:\n",
42
+ " template = \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
43
+ " You are a highly specialized AI trained in clinical psychology and mental health assessment. Your task is to analyze textual input and categorize it into one of the following mental health conditions:\n",
44
+ " - Normal\n",
45
+ " - Depression\n",
46
+ " - Suicidal\n",
47
+ " - Anxiety\n",
48
+ " - Stress\n",
49
+ " - Bi-Polar\n",
50
+ " - Personality Disorder\n",
51
+ "\n",
52
+ " Your analysis should be based on clinical symptoms and diagnostic criteria commonly used in mental health practice. Here are some detailed examples:\n",
53
+ "\n",
54
+ " Example 1:\n",
55
+ " Text: \"I feel an overwhelming sense of sadness and hopelessness. I have lost interest in activities I once enjoyed and find it hard to get out of bed.\"\n",
56
+ " Category: Depression\n",
57
+ "\n",
58
+ " Example 2:\n",
59
+ " Text: \"I constantly worry about various aspects of my life. My heart races, and I experience physical symptoms like sweating and trembling even when there is no apparent danger.\"\n",
60
+ " Category: Anxiety\n",
61
+ "\n",
62
+ " Example 3:\n",
63
+ " Text: \"I have thoughts about ending my life. I feel that there is no other way to escape my pain, and I often think about how I might end it.\"\n",
64
+ " Category: Suicidal\n",
65
+ "\n",
66
+ " Example 4:\n",
67
+ " Text: \"I feel extremely stressed and overwhelmed by my responsibilities. I find it difficult to relax, and I often experience headaches and tension.\"\n",
68
+ " Category: Stress\n",
69
+ "\n",
70
+ " Example 5:\n",
71
+ " Text: \"I go through periods of extreme happiness and high energy, followed by episodes of deep depression and low energy. These mood swings affect my daily functioning.\"\n",
72
+ " Category: Bi-Polar\n",
73
+ "\n",
74
+ " Example 6:\n",
75
+ " Text: \"I have trouble maintaining stable relationships and often experience intense emotional reactions. My self-image frequently changes, and I engage in impulsive behaviors.\"\n",
76
+ " Category: Personality Disorder\n",
77
+ "\n",
78
+ " Example 7:\n",
79
+ " Text: \"I feel generally content and am able to manage my daily activities without significant distress or impairment.\"\n",
80
+ " Category: Normal\n",
81
+ "\n",
82
+ " Now, analyze the following text and determine the most appropriate category from the list above, and return the Category and a brief explanation of your decision:\n",
83
+ " <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
84
+ " Human: {input_text}\n",
85
+ " <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
86
+ " AI Assistant:\"\"\"\n",
87
+ "\n",
88
+ " sentiment_prompt = PromptTemplate(input_variables=[\"input_text\"], template=template)\n",
89
+ " initiator_router = sentiment_prompt | llm | StrOutputParser()\n",
90
+ " output = initiator_router.invoke({\"input_text\":input_text})\n",
91
+ " return output\n"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 8,
97
+ "metadata": {},
98
+ "outputs": [
99
+ {
100
+ "name": "stdout",
101
+ "output_type": "stream",
102
+ "text": [
103
+ "Category: Anxiety\n",
104
+ "\n",
105
+ "Explanation: The text indicates that the person is experiencing excessive and persistent worry, which is a hallmark symptom of anxiety disorder. The phrase \"can't stop worrying\" suggests that the individual is unable to control their worries, which is a common feature of anxiety disorders. Additionally, the phrase \"anxious about everything\" implies that the person is experiencing a pervasive and excessive anxiety that is interfering with their daily life. While anxiety can be a normal response to stressful situations, the severity and pervasiveness described in the text suggest that it may be a clinical concern.\n"
106
+ ]
107
+ }
108
+ ],
109
+ "source": [
110
+ "sentiment = sentiment_analyzer(\"I feel incredibly anxious about everything and can't stop worrying\")\n",
111
+ "print(sentiment)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": 12,
117
+ "metadata": {},
118
+ "outputs": [
119
+ {
120
+ "name": "stdout",
121
+ "output_type": "stream",
122
+ "text": [
123
+ "Category: Anxiety\n",
124
+ "\n",
125
+ "My assessment is based on the following symptoms mentioned in the text:\n",
126
+ "\n",
127
+ "* \"Constantly worried\" - This suggests that the individual is experiencing excessive and persistent worry, which is a hallmark symptom of anxiety.\n",
128
+ "* \"Can't seem to find any peace\" - This implies a sense of perpetual unease and inability to relax, which is also characteristic of anxiety.\n",
129
+ "* \"Disturbed sleep\" - Sleep disturbances are a common symptom of anxiety, often caused by racing thoughts and difficulty relaxing.\n",
130
+ "* \"Overwhelmed by even the smallest tasks\" - This suggests that the individual is experiencing feelings of excessive anxiety and difficulty coping with everyday activities, which is another common symptom of anxiety.\n",
131
+ "\n",
132
+ "Overall, the text suggests that the individual is experiencing symptoms that are consistent with an anxiety disorder, such as generalized anxiety disorder or anxiety disorder not otherwise specified.\n"
133
+ ]
134
+ }
135
+ ],
136
+ "source": [
137
+ "sentiment = sentiment_analyzer(\"I feel like everything is falling apart around me. I'm constantly worried and can't seem to find any peace. My sleep is disturbed, and I often feel overwhelmed by even the smallest tasks.\")\n",
138
+ "print(sentiment)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "## Connecting to database"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 13,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "import os\n",
155
+ "from supabase import create_client, Client\n",
156
+ "\n",
157
+ "url: str = os.environ.get(\"SUPABASE_PROJECT_URL\")\n",
158
+ "key: str = os.environ.get(\"SUPABASE_API_KEY\")\n",
159
+ "supabase: Client = create_client(url, key)"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 17,
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "# fetch all rows from the Interaction History table\n",
169
+ "data = supabase.table(\"Interaction History\").select(\"*\").execute()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 18,
175
+ "metadata": {},
176
+ "outputs": [
177
+ {
178
+ "data": {
179
+ "text/plain": [
180
+ "APIResponse[~_ReturnT](data=[], count=None)"
181
+ ]
182
+ },
183
+ "execution_count": 18,
184
+ "metadata": {},
185
+ "output_type": "execute_result"
186
+ }
187
+ ],
188
+ "source": [
189
+ "data"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 20,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "new_row = {\n",
199
+ " \"Input_text\" : \"I feel incredibly anxious about everything and can't stop worrying\",\n",
200
+ " \"Model_prediction\" : \"Anxiety\",\n",
201
+ " \"Llama_3_Prediction\" : \"Anxiety\",\n",
202
+ " \"Llama_3_Explanation\" : \"Anxiety\",\n",
203
+ " \"User Rating\" : 5,\n",
204
+ "}\n",
205
+ "\n",
206
+ "data = supabase.table(\"Interaction History\").insert(new_row).execute()\n",
207
+ "\n",
208
+ "# Assert we pulled real data.\n",
209
+ "assert len(data.data) > 0"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 21,
215
+ "metadata": {},
216
+ "outputs": [
217
+ {
218
+ "data": {
219
+ "text/plain": [
220
+ "APIResponse[~_ReturnT](data=[{'id': 2, 'Input_text': \"I feel incredibly anxious about everything and can't stop worrying\", 'Model_prediction': 'Anxiety', 'Llama_3_Prediction': 'Anxiety', 'User Rating': 5, 'Llama_3_Explanation': 'Anxiety'}], count=None)"
221
+ ]
222
+ },
223
+ "execution_count": 21,
224
+ "metadata": {},
225
+ "output_type": "execute_result"
226
+ }
227
+ ],
228
+ "source": [
229
+ "data"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": []
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": []
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": []
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {},
257
+ "outputs": [],
258
+ "source": []
259
+ }
260
+ ],
261
+ "metadata": {
262
+ "kernelspec": {
263
+ "display_name": "Python 3",
264
+ "language": "python",
265
+ "name": "python3"
266
+ },
267
+ "language_info": {
268
+ "codemirror_mode": {
269
+ "name": "ipython",
270
+ "version": 3
271
+ },
272
+ "file_extension": ".py",
273
+ "mimetype": "text/x-python",
274
+ "name": "python",
275
+ "nbconvert_exporter": "python",
276
+ "pygments_lexer": "ipython3",
277
+ "version": "3.10.14"
278
+ }
279
+ },
280
+ "nbformat": 4,
281
+ "nbformat_minor": 2
282
+ }
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.1
2
+ scikit-learn==1.4.2
3
+ pandas==2.2.2
4
+ uvicorn==0.30.1
5
+ notebook==7.2.1
6
+ nltk==3.8.1
7
+ langchain_community==0.2.7
8
+ langchain==0.2.9
9
+ langchain_groq==0.1.6
10
+ langchain_core==0.2.21
11
+ llama-parse==0.4.9
12
+ python-dotenv==1.0.1
13
+ groq==0.9.0
14
+ supabase==2.5.3
todo.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ 1. connect a database
2
+ 2. save the data to a database
3
+ 3. add llm api_prediction to the project using Groq and llama 3 8b
4
+ 4. Update the output into two outputs, one for the model prediction, the other for llm prediction.
utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import re
4
+ import string
5
+ import nltk
6
+ from nltk.corpus import stopwords
7
+ from nltk.stem import PorterStemmer, WordNetLemmatizer
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ from sklearn.ensemble import RandomForestClassifier
10
+ from sklearn.base import BaseEstimator, TransformerMixin
11
+ from sklearn.pipeline import Pipeline
12
+
13
+ # Download necessary NLTK data files
14
+ nltk.download('stopwords')
15
+ nltk.download('wordnet')
16
+
17
+ # Custom transformer for text preprocessing
18
class TextPreprocessor(BaseEstimator, TransformerMixin):
    """Sklearn-compatible transformer that normalizes raw text.

    Lowercases, strips punctuation and digits, removes English stopwords,
    and lemmatizes the remaining tokens. Stateless: ``fit`` learns nothing.
    """

    def __init__(self):
        # English stopword set and WordNet lemmatizer shared by all calls.
        self.stop_words = set(stopwords.words('english'))
        self.lemmatizer = WordNetLemmatizer()

    def preprocess_text(self, text):
        """Return the cleaned, space-joined token string for one document."""
        lowered = text.lower()
        no_punct = re.sub(f'[{re.escape(string.punctuation)}]', '', lowered)
        no_digits = re.sub(r'\d+', '', no_punct)
        tokens = no_digits.split()
        kept = [self.lemmatizer.lemmatize(tok) for tok in tokens if tok not in self.stop_words]
        return ' '.join(kept)

    def fit(self, X, y=None):
        # Nothing to learn; present for the sklearn transformer contract.
        return self

    def transform(self, X, y=None):
        # Clean every document in the input collection.
        return [self.preprocess_text(doc) for doc in X]
+
50
+
51
+ # Model pipeline
52
+ pipeline = Pipeline([
53
+ ('preprocessor', TextPreprocessor()),
54
+ ('vectorizer', TfidfVectorizer()),
55
+ ('classifier', RandomForestClassifier())
56
+ ])
57
+