|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from transformers import BertTokenizer, BertModel |
|
from sklearn.preprocessing import StandardScaler, LabelEncoder |
|
from sklearn.ensemble import IsolationForest |
|
import gradio as gr |
|
import warnings |
|
# Silence every warning category for the whole process (library deprecation
# notices included) so the demo console stays readable.
warnings.simplefilter('ignore')
|
|
|
class FraudDetectionTester:
    """Score transactions with a saved BERT + tabular fraud-detection model.

    The checkpoint at ``model_path`` is expected to be a dict holding
    ``'model_state_dict'``, ``'scaler'``, ``'label_encoder'`` and
    ``'isolation_forest'`` (see ``load_model``).
    """

    # Numeric model inputs, in the column order handed to the scaler and the
    # network. Both consume columns positionally, so this order presumably has
    # to match the one used at training time -- do not reorder.
    NUMERICAL_FEATURES = (
        'amount_log', 'hour', 'day_of_week', 'days_since_last_transaction',
        'transaction_count_1h', 'transaction_count_24h', 'avg_amount_1h',
        'location_risk_score', 'account_age_days', 'merchant_category_encoded',
        'is_weekend', 'is_night', 'high_frequency', 'amount_deviation',
    )

    def __init__(self, model_path='fraud_detection_model.pth'):
        """Pick a device, load the BERT tokenizer and the checkpoint at *model_path*.

        Raises:
            FileNotFoundError: if *model_path* does not exist.
            Exception: any other failure from ``load_model`` is re-raised.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model_path = model_path

        # All four are populated from the checkpoint by load_model().
        self.model = None
        self.scaler = None
        self.label_encoder = None
        self.isolation_forest = None

        self.load_model()

    def create_bert_fraud_model(self, numerical_features_dim):
        """Build an untrained ``BERTFraudDetector`` network.

        Attribute names and layer order must stay in sync with the saved
        ``model_state_dict``, because ``load_model`` loads it with the default
        strict key matching.

        Args:
            numerical_features_dim: width of the numerical feature vector.

        Returns:
            A freshly constructed ``nn.Module`` on the CPU.
        """

        class BERTFraudDetector(nn.Module):
            def __init__(self, bert_model_name, numerical_features_dim, dropout_rate=0.3):
                super().__init__()

                self.bert = BertModel.from_pretrained(bert_model_name)

                # Freeze BERT except for its last two encoder layers.
                for param in self.bert.parameters():
                    param.requires_grad = False
                for param in self.bert.encoder.layer[-2:].parameters():
                    param.requires_grad = True

                # Project both modalities into a shared 256-wide space.
                self.text_projection = nn.Linear(self.bert.config.hidden_size, 256)
                self.numerical_projection = nn.Linear(numerical_features_dim, 256)

                # Anomaly head on the projected numeric features. It ends in a
                # plain Linear layer, so its output is unbounded.
                self.anomaly_detector = nn.Sequential(
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                )

                # Input: 256 text + 256 numeric + 1 anomaly score = 513 features.
                self.classifier = nn.Sequential(
                    nn.Linear(512 + 1, 256),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(128, 64),
                    nn.ReLU(),
                    nn.Linear(64, 1),
                    nn.Sigmoid(),
                )

            def forward(self, input_ids, attention_mask, numerical_features):
                """Return ``(fraud_probability, anomaly_score)``, both squeezed."""
                bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
                text_features = self.text_projection(bert_output.pooler_output)

                numerical_features = self.numerical_projection(numerical_features)
                anomaly_score = self.anomaly_detector(numerical_features)

                combined_features = torch.cat([text_features, numerical_features, anomaly_score], dim=1)
                fraud_probability = self.classifier(combined_features)

                # squeeze() turns a batch of one into 0-d tensors.
                return fraud_probability.squeeze(), anomaly_score.squeeze()

        return BERTFraudDetector('bert-base-uncased', numerical_features_dim)

    def load_model(self):
        """Load preprocessing objects and network weights from ``self.model_path``.

        Sets ``self.scaler``, ``self.label_encoder``, ``self.isolation_forest``
        and ``self.model`` (moved to ``self.device`` and put in eval mode).

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
            Exception: any other load failure is printed and re-raised.
        """
        try:
            print(f"Loading model from {self.model_path}...")

            # add_safe_globals exists only in newer torch releases (2.4+) and
            # only affects weights_only=True loads; guard it so older torch
            # versions do not fail here with AttributeError.
            add_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
            if add_safe_globals is not None:
                add_safe_globals([StandardScaler, LabelEncoder, IsolationForest])

            # SECURITY: weights_only=False unpickles arbitrary objects (needed
            # for the sklearn estimators stored in the checkpoint), which can
            # execute code. Only load checkpoints from a trusted source.
            checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)

            self.scaler = checkpoint['scaler']
            self.label_encoder = checkpoint['label_encoder']
            self.isolation_forest = checkpoint['isolation_forest']

            # Input width comes from the same feature list preprocessing uses,
            # so the two cannot drift apart.
            self.model = self.create_bert_fraud_model(len(self.NUMERICAL_FEATURES))
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.to(self.device)
            self.model.eval()

            print("✅ Model loaded successfully!")

        except FileNotFoundError:
            print(f"❌ Error: Model file '{self.model_path}' not found!")
            print("Make sure you have trained and saved the model first.")
            raise
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def tokenize_descriptions(self, descriptions, max_length=128):
        """Tokenize transaction descriptions for BERT.

        Args:
            descriptions: a single string, a list, or any iterable/array-like
                (anything with ``tolist()``, e.g. a pandas Series).
            max_length: truncation length in tokens.

        Returns:
            ``(input_ids, attention_mask)`` PyTorch tensors, padded to the
            longest description in the batch.
        """
        if hasattr(descriptions, 'tolist'):
            descriptions = descriptions.tolist()
        elif isinstance(descriptions, str):
            descriptions = [descriptions]
        elif not isinstance(descriptions, list):
            descriptions = list(descriptions)

        # The tokenizer expects text; coerce anything else (numbers, None) to str.
        descriptions = [str(desc) for desc in descriptions]

        encoded = self.tokenizer(
            descriptions,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'], encoded['attention_mask']

    def preprocess_single_transaction(self, transaction):
        """Build model inputs for one transaction (a dict) or several rows.

        Args:
            transaction: a dict for one transaction, or anything
                ``pd.DataFrame`` accepts (e.g. a list of dicts) for several.

        Returns:
            ``(df, X_numerical)``: the frame with engineered columns plus
            ``processed_description``, and the ``NUMERICAL_FEATURES`` matrix
            returned by ``self.scaler.transform``.

        Raises:
            KeyError: if a required input column is missing.
        """
        df = pd.DataFrame([transaction]) if isinstance(transaction, dict) else pd.DataFrame(transaction)

        # Engineered features.
        df['amount_log'] = np.log1p(df['amount'])
        df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
        df['is_night'] = ((df['hour'] >= 22) | (df['hour'] <= 6)).astype(int)
        df['high_frequency'] = (df['transaction_count_1h'] > 3).astype(int)
        # The +1 keeps the denominator non-zero when avg_amount_1h is 0.
        df['amount_deviation'] = (df['amount'] - df['avg_amount_1h']).abs() / (df['avg_amount_1h'] + 1)

        # Encode row by row so an unseen category zeroes only its own row
        # rather than every row in the batch.
        df['merchant_category_encoded'] = df['merchant_category'].map(self._encode_category)

        X_numerical = self.scaler.transform(df[list(self.NUMERICAL_FEATURES)])

        # Lowercase and strip punctuation before tokenization.
        df['processed_description'] = (
            df['description'].astype(str).str.lower().str.replace(r'[^\w\s]', '', regex=True)
        )

        return df, X_numerical

    def _encode_category(self, category):
        """Label-encode one merchant category; categories the encoder rejects map to 0."""
        try:
            return int(self.label_encoder.transform([category])[0])
        except ValueError:
            return 0

    def predict_fraud(self, transaction):
        """Score a transaction and return a result dict.

        If *transaction* describes several rows, all of them are run through
        the model, but only the first row's scores are returned.

        Returns:
            dict with ``fraud_probability`` (blended score clamped to [0, 1]),
            ``is_fraud_predicted`` (score > 0.5), ``risk_level``,
            ``anomaly_score``, ``bert_score`` and ``isolation_score``; or
            ``{'error': message}`` if any step raises.
        """
        try:
            df, X_numerical = self.preprocess_single_transaction(transaction)
            input_ids, attention_masks = self.tokenize_descriptions(df['processed_description'].tolist())

            with torch.no_grad():
                batch_num = torch.as_tensor(X_numerical, dtype=torch.float32).to(self.device)
                fraud_prob, anomaly_score = self.model(
                    input_ids.to(self.device), attention_masks.to(self.device), batch_num
                )

            isolation_value = float(self.isolation_forest.decision_function(X_numerical)[0])

            # forward() squeezes its outputs (0-d for a batch of one); flatten
            # so the first element can be read the same way in every case.
            bert_value = float(torch.as_tensor(fraud_prob).reshape(-1)[0])
            anomaly_value = float(torch.as_tensor(anomaly_score).reshape(-1)[0])

            # Weighted blend: network probability, inverted isolation-forest
            # decision value, and raw anomaly-head output.
            raw_score = (0.6 * bert_value
                         + 0.3 * (1 - (isolation_value + 0.5))
                         + 0.1 * anomaly_value)
            # The anomaly head has no output activation, so the blend can leave
            # [0, 1]. Clamp it so it is a valid probability; the 0.5 decision
            # and the risk bands are unaffected by the clamp.
            combined_score = min(max(raw_score, 0.0), 1.0)

            return {
                'fraud_probability': float(combined_score),
                'is_fraud_predicted': bool(combined_score > 0.5),
                'risk_level': self.get_risk_level(combined_score),
                'anomaly_score': anomaly_value,
                'bert_score': bert_value,
                'isolation_score': isolation_value,
            }

        except Exception as e:
            # UI boundary: report the failure to the caller instead of raising.
            return {'error': str(e)}

    def get_risk_level(self, score):
        """Map a fraud score to a risk band.

        Lower bounds are exclusive: > 0.8 CRITICAL, > 0.6 HIGH, > 0.4 MEDIUM,
        > 0.2 LOW, otherwise MINIMAL.
        """
        if score > 0.8:
            return 'CRITICAL'
        elif score > 0.6:
            return 'HIGH'
        elif score > 0.4:
            return 'MEDIUM'
        elif score > 0.2:
            return 'LOW'
        else:
            return 'MINIMAL'
|
|
|
|
|
# Load the detector once at import time; the Gradio callback checks
# `model_loaded` before touching `fraud_detector`.
print("Initializing fraud detection model...")
try:
    fraud_detector = FraudDetectionTester('fraud_detection_model.pth')
except Exception as load_error:
    print(f"Failed to load model: {load_error}")
    model_loaded = False
else:
    model_loaded = True
|
|
|
def predict_transaction_fraud(
    transaction_id,
    amount,
    merchant_category,
    description,
    hour,
    day_of_week,
    days_since_last_transaction,
    transaction_count_1h,
    transaction_count_24h,
    avg_amount_1h,
    location_risk_score,
    account_age_days,
):
    """Gradio callback: score one transaction and format the six output fields.

    Returns:
        A 6-tuple ``(prediction, probability, risk_level, risk_meter,
        detailed_scores, summary)``. If the model is not loaded or scoring
        fails, the first element holds the error message and the other five
        are empty strings.
    """
    if not model_loaded:
        return "❌ Model not loaded. Please ensure 'fraud_detection_model.pth' is available.", "", "", "", "", ""

    transaction = {
        'transaction_id': transaction_id,
        'amount': amount,
        'merchant_category': merchant_category,
        'description': description,
        'hour': hour,
        'day_of_week': day_of_week,
        'days_since_last_transaction': days_since_last_transaction,
        'transaction_count_1h': transaction_count_1h,
        'transaction_count_24h': transaction_count_24h,
        'avg_amount_1h': avg_amount_1h,
        'location_risk_score': location_risk_score,
        'account_age_days': account_age_days,
    }

    result = fraud_detector.predict_fraud(transaction)
    if 'error' in result:
        return f"❌ Error: {result['error']}", "", "", "", "", ""

    fraud_prob = result['fraud_probability']
    prediction = "🚨 FRAUD DETECTED" if result['is_fraud_predicted'] else "✅ LEGITIMATE"
    risk_level = result['risk_level']

    # Clamp the filled width so the bar is always exactly `bar_width` cells,
    # even if the score ever falls outside [0, 1].
    bar_width = 20
    filled = min(max(int(fraud_prob * bar_width), 0), bar_width)
    risk_bar = "█" * filled + "░" * (bar_width - filled)
    risk_meter = f"[{risk_bar}] {fraud_prob*100:.1f}%"

    detailed_scores = "\n".join([
        f"🤖 BERT Score: {result['bert_score']:.4f}",
        f"Isolation Score: {result['isolation_score']:.4f}",
        f"Anomaly Score: {result['anomaly_score']:.4f}",
    ])

    summary = "\n".join([
        f"💰 Amount: ${amount:.2f}",
        f"🏪 Category: {merchant_category}",
        f"Description: {description}",
        f"🎯 Fraud Probability: {fraud_prob:.4f} ({fraud_prob*100:.2f}%)",
        f"Risk Level: {risk_level}",
    ])

    return prediction, f"{fraud_prob:.4f}", risk_level, risk_meter, detailed_scores, summary
|
|
|
def load_sample_transaction(sample_type):
    """Return the twelve form values for the preset named *sample_type*.

    Values come back as a tuple in form-field order. An unrecognised preset
    name yields a list of twelve empty strings instead.
    """
    field_order = (
        'transaction_id', 'amount', 'merchant_category', 'description',
        'hour', 'day_of_week', 'days_since_last_transaction',
        'transaction_count_1h', 'transaction_count_24h', 'avg_amount_1h',
        'location_risk_score', 'account_age_days',
    )

    presets = {
        "Normal Grocery Purchase": {
            'transaction_id': 'NORMAL_001',
            'amount': 45.67,
            'merchant_category': 'grocery',
            'description': 'WALMART SUPERCENTER CA 1234',
            'hour': 14,
            'day_of_week': 2,
            'days_since_last_transaction': 1.0,
            'transaction_count_1h': 1,
            'transaction_count_24h': 3,
            'avg_amount_1h': 50.0,
            'location_risk_score': 0.1,
            'account_age_days': 730,
        },
        "Suspicious High Amount": {
            'transaction_id': 'SUSPICIOUS_001',
            'amount': 2999.99,
            'merchant_category': 'online',
            'description': 'SUSPICIOUS ELECTRONICS STORE XX 9999',
            'hour': 3,
            'day_of_week': 6,
            'days_since_last_transaction': 60.0,
            'transaction_count_1h': 12,
            'transaction_count_24h': 25,
            'avg_amount_1h': 150.0,
            'location_risk_score': 0.95,
            'account_age_days': 15,
        },
        "Coffee Shop Purchase": {
            'transaction_id': 'COFFEE_001',
            'amount': 8.50,
            'merchant_category': 'restaurant',
            'description': 'STARBUCKS COFFEE NY 5678',
            'hour': 8,
            'day_of_week': 1,
            'days_since_last_transaction': 0.5,
            'transaction_count_1h': 1,
            'transaction_count_24h': 4,
            'avg_amount_1h': 8.50,
            'location_risk_score': 0.2,
            'account_age_days': 1095,
        },
        "Foreign ATM Withdrawal": {
            'transaction_id': 'ATM_001',
            'amount': 500.00,
            'merchant_category': 'atm',
            'description': 'ATM WITHDRAWAL FOREIGN COUNTRY 0000',
            'hour': 23,
            'day_of_week': 0,
            'days_since_last_transaction': 0.1,
            'transaction_count_1h': 5,
            'transaction_count_24h': 8,
            'avg_amount_1h': 200.0,
            'location_risk_score': 0.8,
            'account_age_days': 365,
        },
    }

    preset = presets.get(sample_type)
    if preset is None:
        return [""] * 12
    return tuple(preset[field] for field in field_order)
|
|
|
|
|
with gr.Blocks(title="🚨 Fraud Detection System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🚨 Advanced Fraud Detection System
### Powered by BERT and Machine Learning

This system analyzes financial transactions using advanced AI to detect potential fraud.
Enter transaction details below or use sample transactions to test the system.
""")

    with gr.Row():
        # Left column: transaction inputs.
        with gr.Column(scale=2):
            gr.Markdown("## Transaction Details")

            with gr.Row():
                sample_dropdown = gr.Dropdown(
                    choices=["Normal Grocery Purchase", "Suspicious High Amount",
                             "Coffee Shop Purchase", "Foreign ATM Withdrawal"],
                    label="🎯 Load Sample Transaction",
                    value="Normal Grocery Purchase",
                )
                load_sample_btn = gr.Button("📥 Load Sample", variant="secondary")

            with gr.Row():
                transaction_id = gr.Textbox(label="Transaction ID", value="TEST_001")
                amount = gr.Number(label="💰 Amount ($)", value=45.67, minimum=0)

            with gr.Row():
                merchant_category = gr.Dropdown(
                    choices=["grocery", "restaurant", "gas_station", "retail",
                             "online", "atm", "pharmacy", "entertainment"],
                    label="🏪 Merchant Category",
                    value="grocery",
                )
                description = gr.Textbox(label="Transaction Description", value="WALMART SUPERCENTER CA 1234")

            with gr.Row():
                hour = gr.Slider(label="Hour of Day", minimum=0, maximum=23, value=14, step=1)
                day_of_week = gr.Slider(label="📅 Day of Week (0=Mon, 6=Sun)", minimum=0, maximum=6, value=2, step=1)

            with gr.Row():
                days_since_last = gr.Number(label="Days Since Last Transaction", value=1.0, minimum=0)
                transaction_count_1h = gr.Number(label="🔢 Transactions (1h)", value=1, minimum=0)

            with gr.Row():
                transaction_count_24h = gr.Number(label="🔢 Transactions (24h)", value=3, minimum=0)
                avg_amount_1h = gr.Number(label="💵 Avg Amount (1h)", value=50.0, minimum=0)

            with gr.Row():
                location_risk_score = gr.Slider(label="Location Risk Score", minimum=0, maximum=1, value=0.1, step=0.01)
                account_age_days = gr.Number(label="👤 Account Age (days)", value=730, minimum=0)

            predict_btn = gr.Button("Analyze Transaction", variant="primary", size="lg")

        # Right column: read-only results.
        with gr.Column(scale=1):
            gr.Markdown("## Fraud Analysis Results")

            prediction_output = gr.Textbox(label="🎯 Prediction", interactive=False)
            fraud_prob_output = gr.Textbox(label="Fraud Probability", interactive=False)
            risk_level_output = gr.Textbox(label="⚠️ Risk Level", interactive=False)
            risk_meter_output = gr.Textbox(label="Risk Meter", interactive=False)
            detailed_scores_output = gr.Textbox(label="Detailed Scores", interactive=False, lines=4)
            summary_output = gr.Textbox(label="Summary", interactive=False, lines=6)

    # One shared list: its order must match both the parameter order of
    # predict_transaction_fraud and the value order load_sample_transaction returns.
    form_fields = [
        transaction_id, amount, merchant_category, description, hour, day_of_week,
        days_since_last, transaction_count_1h, transaction_count_24h, avg_amount_1h,
        location_risk_score, account_age_days,
    ]

    predict_btn.click(
        fn=predict_transaction_fraud,
        inputs=form_fields,
        outputs=[
            prediction_output, fraud_prob_output, risk_level_output,
            risk_meter_output, detailed_scores_output, summary_output,
        ],
    )

    load_sample_btn.click(
        fn=load_sample_transaction,
        inputs=[sample_dropdown],
        outputs=form_fields,
    )

    gr.Markdown("""
---
### How to Use:
1. **Load Sample**: Choose a predefined sample transaction to quickly test the system
2. **Enter Details**: Fill in transaction information manually or modify loaded samples
3. **Analyze**: Click "Analyze Transaction" to get fraud detection results

### 🎯 Understanding Results:
- **Fraud Probability**: Higher values indicate higher fraud risk (0-1 scale)
- **Risk Levels**: MINIMAL → LOW → MEDIUM → HIGH → CRITICAL
- **Risk Meter**: Visual representation of fraud probability
- **Detailed Scores**: Individual model component scores

### ⚠️ Model Requirements:
Ensure `fraud_detection_model.pth` is available in the working directory the app is launched from
(the path is resolved relative to the current working directory, not this script's location).
""")
|
|
|
|
|
if __name__ == "__main__":
    # NOTE: 0.0.0.0 binds every network interface, so the app is reachable
    # from other machines on the network; share=False means no public
    # gradio.live link is created.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
    }
    demo.launch(**launch_options)