Markndrei committed on
Commit
f460ec4
·
verified ·
1 Parent(s): 256cd18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -85
app.py CHANGED
@@ -1,85 +1,85 @@
1
- import pandas as pd
2
- import numpy as np
3
- from sklearn.ensemble import RandomForestClassifier
4
- from sklearn.model_selection import train_test_split
5
- from sklearn.metrics import accuracy_score, classification_report
6
- import streamlit as st
7
- import altair as alt
8
-
9
- try:
10
- # Load the data
11
- df = pd.read_csv("fraud_data.csv")
12
-
13
- # Prepare the data for the model
14
- X = df[['TransactionAmount', 'CustomerAge', 'TransactionFrequency']]
15
- y = df['IsFraud']
16
-
17
- except FileNotFoundError:
18
- st.write("Error: Data file not found.")
19
- st.stop()
20
-
21
- except Exception as e:
22
- st.write(f"An error occurred: {e}")
23
- st.stop()
24
-
25
- # Split the data into training and testing sets
26
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
27
-
28
- # Create and train a Random Forest Classifier model
29
- model = RandomForestClassifier(n_estimators=100, random_state=42)
30
- model.fit(X_train, y_train)
31
-
32
- # Make predictions on the testing set
33
- y_pred = model.predict(X_test)
34
-
35
- # Evaluate the model's performance
36
- accuracy = accuracy_score(y_test, y_pred)
37
- report = classification_report(y_test, y_pred, output_dict=True)
38
-
39
- # Create a Streamlit app
40
- st.title("Fraud Detection System")
41
-
42
- # Create tabs
43
- tab1, tab2, tab3 = st.tabs(["Data Visualization", "Model Performance", "Fraud Prediction"])
44
-
45
- # Tab 1: Data Visualization
46
- with tab1:
47
- st.write("### Fraud Data")
48
- st.write(df)
49
-
50
- # Scatter plot
51
- st.write("### Scatter Plot of Features")
52
- for col in ['TransactionAmount', 'CustomerAge', 'TransactionFrequency']:
53
- st.write(f"**{col} vs Fraudulent Transactions**")
54
- st.altair_chart(
55
- alt.Chart(df).mark_circle().encode(
56
- x=col,
57
- y='IsFraud',
58
- tooltip=[col, 'IsFraud']
59
- ).interactive(),
60
- use_container_width=True
61
- )
62
-
63
- # Tab 2: Model Performance
64
- with tab2:
65
- st.write("### Model Performance")
66
- st.write(f"Accuracy: {accuracy:.2f}")
67
- st.write("Classification Report:")
68
- st.json(report)
69
-
70
- # Tab 3: Fraud Prediction
71
- with tab3:
72
- st.write("### Predict Fraudulent Transactions")
73
- amount_input = st.number_input("Transaction Amount", min_value=1.0, value=100.0, step=1.0)
74
- age_input = st.number_input("Customer Age", min_value=18, value=30, step=1)
75
- frequency_input = st.slider("Transaction Frequency (past month)", min_value=1, max_value=100, value=5, step=1)
76
-
77
- if st.button("Predict"):
78
- # Create input array for prediction
79
- input_data = [[amount_input, age_input, frequency_input]]
80
-
81
- # Make prediction
82
- prediction = model.predict(input_data)[0]
83
- result = "Fraudulent" if prediction == 1 else "Legitimate"
84
-
85
- st.write(f"### Prediction: {result}")
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.ensemble import RandomForestClassifier
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.metrics import accuracy_score, classification_report
7
+ from datasets import load_dataset
8
+
9
+ # Load dataset from Hugging Face
10
+ dataset = load_dataset("Nooha/cc_fraud_detection_dataset", split="train")
11
+ df = pd.DataFrame(dataset)
12
+
13
+ # Select relevant features and target variable
14
+ X = df[['Amount', 'Time', 'V1', 'V2', 'V3']]
15
+ y = df['Class']
16
+
17
+ # Split dataset into training and testing sets
18
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
19
+
20
+ # Train a RandomForestClassifier model
21
+ model = RandomForestClassifier(n_estimators=100, random_state=42)
22
+ model.fit(X_train, y_train)
23
+
24
+ y_pred = model.predict(X_test)
25
+
26
+ # Model Performance Metrics
27
+ accuracy = accuracy_score(y_test, y_pred)
28
+ class_report_df = pd.DataFrame(classification_report(y_test, y_pred, output_dict=True)).transpose()
29
+
30
+ # Application Title
31
+ st.title('💳 Credit Card Fraud Detection System')
32
+
33
+ st.markdown(
34
+ """
35
+ ## 📖 Introduction
36
+ Welcome to the **Credit Card Fraud Detection System**! This tool analyzes credit card transactions to detect fraudulent activity using a **Random Forest model**.
37
+ """
38
+ )
39
+
40
+ # Tab Structure
41
+ tab1, tab2, tab3 = st.tabs(['📊 Dataset Preview', '📈 Model Performance', '🔍 Fraud Prediction'])
42
+
43
+ # Dataset Preview
44
+ with tab1:
45
+ st.markdown(
46
+ """
47
+ ## 📊 Dataset Preview
48
+ Below is a sample of the credit card transaction dataset used for fraud detection.
49
+ """
50
+ )
51
+ st.dataframe(df.head())
52
+
53
+ # Model Performance
54
+ with tab2:
55
+ st.markdown(
56
+ """
57
+ ## 📈 Model Performance
58
+ - **Accuracy:** Measures overall model performance.
59
+ - **Classification Report:** Precision, recall, and F1-score breakdown.
60
+ """
61
+ )
62
+
63
+ st.write(f"**📌 Model Accuracy:** {accuracy:.2%}")
64
+
65
+ st.markdown("### 📋 Classification Report")
66
+ st.dataframe(class_report_df)
67
+
68
+ # Fraud Prediction
69
+ with tab3:
70
+ st.markdown("""
71
+ ## 🔍 Fraud Prediction
72
+ Enter transaction details below to predict if it's fraudulent.
73
+ """)
74
+
75
+ amount_input = st.number_input("💵 Transaction Amount", min_value=0.0, value=100.0, step=1.0)
76
+ time_input = st.number_input("⏳ Transaction Time", min_value=0.0, value=50000.0, step=1000.0)
77
+ v1_input = st.number_input("🔒 Feature V1", value=0.0, step=0.1)
78
+ v2_input = st.number_input("🔒 Feature V2", value=0.0, step=0.1)
79
+ v3_input = st.number_input("🔒 Feature V3", value=0.0, step=0.1)
80
+
81
+ if st.button("🔎 Predict Fraud"):
82
+ input_data = np.array([[amount_input, time_input, v1_input, v2_input, v3_input]])
83
+ prediction = model.predict(input_data)[0]
84
+ result = "🚨 Fraudulent" if prediction == 1 else "✅ Legitimate"
85
+ st.success(f"### 🎯 Prediction: **{result}**")