CoderCowMoo committed on
Commit 09c1340 · verified · 1 Parent(s): b5fbd04

Upload 3 files

Add model .pth, training files.

Files changed (3)
  1. phishing_mlp_model.pth +3 -0
  2. training.py +830 -0
  3. training_results.png +0 -0
phishing_mlp_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7dc2bd19101c1eb353d5e9bbf22bf2c76c457a998799e194e409a372ea421353
+ size 9348
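
This .pth file is a Git LFS pointer; the 9.3 KB checkpoint it references stores only the model's state_dict (saved in main() of training.py via torch.save). A minimal loading sketch, assuming the PhishingMLP class and the 22-feature input defined in training.py below:

    import torch
    from training import PhishingMLP  # class defined in training.py below

    model = PhishingMLP(input_size=22)  # 22 extracted URL features
    model.load_state_dict(torch.load("phishing_mlp_model.pth", map_location="cpu"))
    model.eval()  # switch to inference mode
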
training.py ADDED
@@ -0,0 +1,830 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader, random_split
+ import pandas as pd
+ import numpy as np
+ import json
+ import os
+ import re
+ import urllib.parse
+ import matplotlib.pyplot as plt
+ from collections import Counter
+ from sklearn.model_selection import train_test_split
+ from sklearn.preprocessing import StandardScaler
+ import tqdm
+
+ # --- Healthcare URL Detection Components ---
+
+ # Healthcare-related keywords for domain detection
+ HEALTHCARE_KEYWORDS = [
+     'health', 'medical', 'hospital', 'clinic', 'pharma', 'patient', 'care', 'med',
+     'doctor', 'physician', 'nurse', 'therapy', 'rehab', 'dental', 'cardio', 'neuro',
+     'oncology', 'pediatric', 'orthopedic', 'surgery', 'diagnostic', 'wellbeing',
+     'wellness', 'ehr', 'emr', 'mychart', 'medicare', 'medicaid', 'insurance'
+ ]
+
+ # Common healthcare institutions and systems
+ HEALTHCARE_INSTITUTIONS = [
+     'mayo', 'cleveland', 'hopkins', 'kaiser', 'mount sinai', 'cedars', 'baylor',
+     'nhs', 'quest', 'labcorp', 'cvs', 'walgreens', 'aetna', 'cigna', 'unitedhealthcare',
+     'bluecross', 'anthem', 'humana', 'va.gov', 'cdc', 'who', 'nih'
+ ]
+
+ # Healthcare TLDs and specific domains
+ HEALTHCARE_DOMAINS = ['.health', '.healthcare', '.medicine', '.hospital', '.clinic', 'mychart.']
+
+ # --- Feature Extraction Functions ---
+
+ def url_length(url):
+     """Return the length of the URL."""
+     return len(url)
+
+ def num_dots(url):
+     """Return the number of dots in the URL."""
+     return url.count('.')
+
+ def num_hyphens(url):
+     """Return the number of hyphens in the URL."""
+     return url.count('-')
+
+ def num_at(url):
+     """Return the number of @ symbols in the URL."""
+     return url.count('@')
+
+ def num_tilde(url):
+     """Return the number of ~ symbols in the URL."""
+     return url.count('~')
+
+ def num_underscore(url):
+     """Return the number of underscores in the URL."""
+     return url.count('_')
+
+ def num_percent(url):
+     """Return the number of percent symbols in the URL."""
+     return url.count('%')
+
+ def num_ampersand(url):
+     """Return the number of ampersands in the URL."""
+     return url.count('&')
+
+ def num_hash(url):
+     """Return the number of hash symbols in the URL."""
+     return url.count('#')
+
+ def has_https(url):
+     """Return 1 if the URL uses HTTPS, 0 otherwise."""
+     return int(url.startswith('https://'))
+
+ def has_ip_address(url):
+     """Check if the URL contains an IP address instead of a domain name."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         if re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', parsed_url.netloc):
+             return 1
+         # Check for IPv6
+         if re.match(r'^\[[0-9a-fA-F:]+\]$', parsed_url.netloc):
+             return 1
+         return 0
+     except Exception:
+         return 0
+
+ def get_hostname_length(url):
+     """Return the length of the hostname."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return len(parsed_url.netloc)
+     except Exception:
+         return 0
+
+ def get_path_length(url):
+     """Return the length of the path."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return len(parsed_url.path)
+     except Exception:
+         return 0
+
+ def get_path_level(url):
+     """Return the number of directories in the path."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return parsed_url.path.count('/')
+     except Exception:
+         return 0
+
+ def get_subdomain_level(url):
+     """Return the number of subdomains in the URL."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         hostname = parsed_url.netloc
+         if has_ip_address(url):
+             return 0  # IP addresses don't have subdomains
+
+         parts = hostname.split('.')
+         # Remove top-level and second-level domains
+         if len(parts) > 2:
+             return len(parts) - 2  # Count remaining parts as subdomain levels
+         else:
+             return 0  # No subdomains
+     except Exception:
+         return 0
+
+ def has_double_slash_in_path(url):
+     """Check if the path contains a double slash."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return int('//' in parsed_url.path)
+     except Exception:
+         return 0
+
+ def get_tld(url):
+     """Extract the top-level domain from a URL."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         hostname = parsed_url.netloc.lower()
+         parts = hostname.split('.')
+         if len(parts) > 1:
+             return parts[-1]
+         return ''
+     except Exception:
+         return ''
+
+ def count_digits(url):
+     """Count the number of digits in the URL."""
+     return sum(c.isdigit() for c in url)
+
+ def digit_ratio(url):
+     """Calculate the ratio of digits to the total URL length."""
+     if len(url) == 0:
+         return 0
+     return count_digits(url) / len(url)
+
+ def count_letters(url):
+     """Count the number of letters in the URL."""
+     return sum(c.isalpha() for c in url)
+
+ def letter_ratio(url):
+     """Calculate the ratio of letters to the total URL length."""
+     if len(url) == 0:
+         return 0
+     return count_letters(url) / len(url)
+
+ def count_special_chars(url):
+     """Count the number of special characters in the URL."""
+     return sum(not c.isalnum() and not c.isspace() for c in url)
+
+ def special_char_ratio(url):
+     """Calculate the ratio of special characters to the total URL length."""
+     if len(url) == 0:
+         return 0
+     return count_special_chars(url) / len(url)
+
+ def get_query_length(url):
+     """Return the length of the query string."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return len(parsed_url.query)
+     except Exception:
+         return 0
+
+ def get_fragment_length(url):
+     """Return the length of the fragment."""
+     try:
+         parsed_url = urllib.parse.urlparse(url)
+         return len(parsed_url.fragment)
+     except Exception:
+         return 0
+
+ def healthcare_relevance_score(url):
+     """
+     Calculate a relevance score for healthcare-related URLs.
+     Higher scores indicate stronger relation to healthcare.
+     """
+     url_lower = url.lower()
+     parsed_url = urllib.parse.urlparse(url_lower)
+     domain = parsed_url.netloc
+     path = parsed_url.path
+
+     score = 0
+
+     # Check for healthcare keywords in domain
+     for keyword in HEALTHCARE_KEYWORDS:
+         if keyword in domain:
+             score += 3
+         elif keyword in path:
+             score += 1
+
+     # Check for healthcare institutions
+     for institution in HEALTHCARE_INSTITUTIONS:
+         if institution in domain:
+             score += 4
+         elif institution in path:
+             score += 2
+
+     # Check for healthcare-specific domains and TLDs
+     for healthcare_domain in HEALTHCARE_DOMAINS:
+         if healthcare_domain in domain:
+             score += 3
+
+     # Check for EHR/patient portal indicators
+     if 'portal' in domain or 'portal' in path:
+         score += 2
+     if 'patient' in domain or 'mychart' in domain:
+         score += 3
+     if 'ehr' in domain or 'emr' in domain:
+         score += 3
+
+     # Normalize score to be between 0 and 1
+     return min(score / 10.0, 1.0)
+
+ def extract_features(url):
+     """Extract all features from a given URL."""
+     features = [
+         # Core features (the original 16)
+         num_dots(url),
+         get_subdomain_level(url),
+         get_path_level(url),
+         url_length(url),
+         num_hyphens(url),
+         num_at(url),
+         num_tilde(url),
+         num_underscore(url),
+         num_percent(url),
+         num_ampersand(url),
+         num_hash(url),
+         has_https(url),
+         has_ip_address(url),
+         get_hostname_length(url),
+         get_path_length(url),
+         has_double_slash_in_path(url),
+
+         # Additional features
+         digit_ratio(url),
+         letter_ratio(url),
+         special_char_ratio(url),
+         get_query_length(url),
+         get_fragment_length(url),
+         healthcare_relevance_score(url)
+     ]
+     return features
+
+ def get_feature_names():
+     """Get names of all features in the order they are extracted."""
+     return [
+         'num_dots', 'subdomain_level', 'path_level', 'url_length',
+         'num_hyphens', 'num_at', 'num_tilde', 'num_underscore',
+         'num_percent', 'num_ampersand', 'num_hash', 'has_https',
+         'has_ip_address', 'hostname_length', 'path_length',
+         'double_slash_in_path', 'digit_ratio', 'letter_ratio',
+         'special_char_ratio', 'query_length', 'fragment_length',
+         'healthcare_relevance'
+     ]
+
+ # --- Dataset Loading and Processing ---
+
+ class URLDataset(Dataset):
+     def __init__(self, features, labels):
+         """
+         Custom PyTorch Dataset for URL features and labels.
+
+         Args:
+             features (numpy.ndarray): Feature vectors for each URL
+             labels (numpy.ndarray): Labels for each URL (0 for benign, 1 for malicious)
+         """
+         self.features = torch.tensor(features, dtype=torch.float32)
+         self.labels = torch.tensor(labels, dtype=torch.long)
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         return self.features[idx], self.labels[idx]
+
+ def load_huggingface_data(file_path):
+     """
+     Load the Hugging Face dataset from a JSON file.
+
+     Args:
+         file_path: Path to the JSON file
+
+     Returns:
+         List of tuples containing (url, label)
+     """
+     with open(file_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     url_data = []
+     for item in data:
+         url = item.get('text', '')
+         label = item.get('label', -1)
+         if url and label != -1:  # Only add entries with valid URLs and labels
+             url_data.append((url, label))
+
+     print(f"Loaded {len(url_data)} URLs from Hugging Face dataset")
+     return url_data
+
+ def load_phiusiil_data(file_path):
+     """
+     Load the PhiUSIIL dataset from a CSV file.
+
+     Args:
+         file_path: Path to the CSV file
+
+     Returns:
+         List of tuples containing (url, label)
+     """
+     df = pd.read_csv(file_path)
+
+     url_data = []
+     for _, row in df.iterrows():
+         url = row['URL']
+         label = row['label']
+         if isinstance(url, str) and url.strip() and not pd.isna(label):
+             url_data.append((url, label))
+
+     print(f"Loaded {len(url_data)} URLs from PhiUSIIL dataset")
+     return url_data
+
+ def load_kaggle_data(file_path):
+     """
+     Load the Kaggle malicious_phish.csv dataset.
+
+     Args:
+         file_path: Path to the CSV file
+
+     Returns:
+         List of tuples containing (url, label)
+     """
+     df = pd.read_csv(file_path)
+
+     url_data = []
+     for _, row in df.iterrows():
+         url = row['url']
+         type_val = row['type']
+
+         # Convert to binary classification (0 for benign, 1 for all others)
+         label = 0 if type_val.lower() == 'benign' else 1
+
+         if isinstance(url, str) and url.strip():
+             url_data.append((url, label))
+
+     print(f"Loaded {len(url_data)} URLs from Kaggle dataset")
+     return url_data
+
+ def combine_and_deduplicate(datasets):
+     """
+     Combine multiple datasets and remove duplicates by URL.
+
+     Args:
+         datasets: List of datasets, each containing (url, label) tuples
+
+     Returns:
+         Tuple of (urls, labels) with duplicates removed
+     """
+     url_to_label = {}
+
+     # Process each dataset
+     for dataset in datasets:
+         for url, label in dataset:
+             # If we've seen this URL before with a different label,
+             # prefer the malicious label (1) for safety
+             if url in url_to_label:
+                 url_to_label[url] = max(url_to_label[url], label)
+             else:
+                 url_to_label[url] = label
+
+     # Convert to lists
+     urls = list(url_to_label.keys())
+     labels = list(url_to_label.values())
+
+     print(f"After deduplication: {len(urls)} unique URLs")
+
+     # Report class distribution
+     label_counts = Counter(labels)
+     print(f"Class distribution - Benign (0): {label_counts[0]}, Malicious (1): {label_counts[1]}")
+
+     return urls, labels
+
+ def extract_all_features(urls):
+     """
+     Extract features from a list of URLs.
+
+     Args:
+         urls: List of URL strings
+
+     Returns:
+         Numpy array of feature vectors
+     """
+     feature_vectors = []
+
+     # Use tqdm for a progress bar
+     for url in tqdm.tqdm(urls, desc="Extracting features"):
+         try:
+             features = extract_features(url)
+             feature_vectors.append(features)
+         except Exception as e:
+             print(f"Error extracting features from {url}: {str(e)}")
+             # Insert a vector of zeros in case of error
+             feature_vectors.append([0] * len(get_feature_names()))
+
+     return np.array(feature_vectors, dtype=np.float32)
+
+ # --- MLP Model ---
+ class PhishingMLP(nn.Module):
+     def __init__(self, input_size=22, hidden_sizes=[22, 30, 10], output_size=1):
+         """
+         Multilayer Perceptron for Phishing URL Detection.
+
+         Args:
+             input_size: Number of input features (default: 22)
+             hidden_sizes: List of neurons in each hidden layer
+             output_size: Number of output classes (1 for binary)
+         """
+         super(PhishingMLP, self).__init__()
+
+         self.layers = nn.ModuleList()
+
+         # Input layer to first hidden layer
+         self.layers.append(nn.Linear(input_size, hidden_sizes[0]))
+         self.layers.append(nn.ReLU())
+
+         # Hidden layers
+         for i in range(len(hidden_sizes) - 1):
+             self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
+             self.layers.append(nn.ReLU())
+
+         # Output layer
+         self.layers.append(nn.Linear(hidden_sizes[-1], output_size))
+         self.layers.append(nn.Sigmoid())
+
+     def forward(self, x):
+         """Forward pass through the network."""
+         for layer in self.layers:
+             x = layer(x)
+         return x
+
+ # --- Training Functions ---
+ def train_mlp(model, train_loader, val_loader, epochs=25, learning_rate=0.001, device="cpu"):
+     """
+     Train the MLP model.
+
+     Args:
+         model: The MLP model
+         train_loader: DataLoader for training data
+         val_loader: DataLoader for validation data
+         epochs: Number of training epochs
+         learning_rate: Learning rate for optimization
+         device: Device to train on (cpu or cuda)
+
+     Returns:
+         Tuple of (trained_model, train_losses, val_losses, val_accuracies)
+     """
+     model.to(device)
+     criterion = nn.BCELoss()
+     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+
+     train_losses = []
+     val_losses = []
+     val_accuracies = []
+
+     print(f"Training on {device}...")
+     for epoch in range(epochs):
+         # Training phase
+         model.train()
+         running_loss = 0.0
+
+         for inputs, labels in train_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+
+             # Zero the parameter gradients
+             optimizer.zero_grad()
+
+             # Forward + backward + optimize
+             outputs = model(inputs)
+             loss = criterion(outputs, labels.unsqueeze(1).float())
+             loss.backward()
+             optimizer.step()
+
+             running_loss += loss.item()
+
+         # Calculate average training loss
+         epoch_train_loss = running_loss / len(train_loader)
+         train_losses.append(epoch_train_loss)
+
+         # Validation phase
+         model.eval()
+         val_loss = 0.0
+         correct = 0
+         total = 0
+
+         with torch.no_grad():
+             for inputs, labels in val_loader:
+                 inputs, labels = inputs.to(device), labels.to(device)
+                 outputs = model(inputs)
+
+                 # Calculate validation loss
+                 loss = criterion(outputs, labels.unsqueeze(1).float())
+                 val_loss += loss.item()
+
+                 # Calculate accuracy
+                 predicted = (outputs > 0.5).float()
+                 total += labels.size(0)
+                 correct += (predicted.squeeze(1) == labels.float()).sum().item()
+
+         # Calculate average validation loss and accuracy
+         epoch_val_loss = val_loss / len(val_loader)
+         val_losses.append(epoch_val_loss)
+
+         val_accuracy = 100 * correct / total
+         val_accuracies.append(val_accuracy)
+
+         # Print progress
+         print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
+
+     return model, train_losses, val_losses, val_accuracies
+
+ def evaluate_model(model, test_loader, device):
+     """
+     Evaluate the trained model on test data.
+
+     Args:
+         model: Trained model
+         test_loader: DataLoader for test data
+         device: Device to evaluate on
+
+     Returns:
+         Tuple of (accuracy, precision, recall, f1_score, healthcare_accuracy)
+     """
+     model.to(device)
+     model.eval()
+
+     correct = 0
+     total = 0
+     true_positives = 0
+     false_positives = 0
+     false_negatives = 0
+     healthcare_correct = 0
+     healthcare_total = 0
+
+     feature_idx = get_feature_names().index('healthcare_relevance')
+     healthcare_threshold = 0.5  # Threshold for considering a URL healthcare-related
+     # NOTE: test inputs are standardized, so this threshold applies to the
+     # scaled healthcare_relevance value, not the raw 0-1 score.
+
+     with torch.no_grad():
+         for inputs, labels in test_loader:
+             inputs, labels = inputs.to(device), labels.to(device)
+
+             # Forward pass
+             outputs = model(inputs)
+             # squeeze(1) keeps a 1-D batch even when the final batch has one sample
+             predicted = (outputs > 0.5).float().squeeze(1)
+
+             # Update counts
+             total += labels.size(0)
+             correct += (predicted == labels.float()).sum().item()
+
+             # Metrics calculation
+             for i in range(labels.size(0)):
+                 if labels[i] == 1 and predicted[i] == 1:
+                     true_positives += 1
+                 elif labels[i] == 0 and predicted[i] == 1:
+                     false_positives += 1
+                 elif labels[i] == 1 and predicted[i] == 0:
+                     false_negatives += 1
+
+                 # Check healthcare relevance
+                 if inputs[i, feature_idx] >= healthcare_threshold:
+                     healthcare_total += 1
+                     if predicted[i] == labels[i]:
+                         healthcare_correct += 1
+
+     # Calculate metrics
+     accuracy = 100 * correct / total
+     precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
+     recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
+     f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
+     # Healthcare-specific accuracy
+     healthcare_accuracy = 100 * healthcare_correct / healthcare_total if healthcare_total > 0 else 0.0
+
+     print(f"Overall Test Accuracy: {accuracy:.2f}%")
+     print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")
+     print(f"Healthcare URLs identified: {healthcare_total} ({healthcare_total/total*100:.2f}%)")
+     print(f"Healthcare URL Detection Accuracy: {healthcare_accuracy:.2f}%")
+
+     return accuracy, precision, recall, f1, healthcare_accuracy
+
+ def plot_training_results(train_losses, val_losses, val_accuracies):
+     """
+     Plot training metrics.
+
+     Args:
+         train_losses: List of training losses
+         val_losses: List of validation losses
+         val_accuracies: List of validation accuracies
+     """
+     plt.figure(figsize=(15, 5))
+
+     # Plot losses
+     plt.subplot(1, 2, 1)
+     plt.plot(train_losses, label='Training Loss')
+     plt.plot(val_losses, label='Validation Loss')
+     plt.xlabel('Epoch')
+     plt.ylabel('Loss')
+     plt.title('Training and Validation Loss')
+     plt.legend()
+
+     # Plot accuracy
+     plt.subplot(1, 2, 2)
+     plt.plot(val_accuracies, label='Validation Accuracy')
+     plt.xlabel('Epoch')
+     plt.ylabel('Accuracy (%)')
+     plt.title('Validation Accuracy')
+     plt.legend()
+
+     plt.tight_layout()
+     plt.savefig('training_results.png')
+     plt.show()
+
+ def analyze_healthcare_features(features, labels, pred_labels):
+     """
+     Analyze how the model performs on healthcare-related URLs.
+
+     Args:
+         features: Feature vectors
+         labels: True labels
+         pred_labels: Predicted labels
+     """
+     healthcare_idx = get_feature_names().index('healthcare_relevance')
+     healthcare_scores = features[:, healthcare_idx]
+
+     # Define thresholds
+     thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
+
+     print("\n=== Healthcare URL Analysis ===")
+     print("Healthcare relevance score distribution:")
+     for threshold in thresholds:
+         count = np.sum(healthcare_scores >= threshold)
+         percent = (count / len(healthcare_scores)) * 100
+         print(f" Score >= {threshold}: {count} URLs ({percent:.2f}%)")
+
+     # Analyze performance at different healthcare relevance levels
+     for threshold in thresholds:
+         mask = healthcare_scores >= threshold
+         if np.sum(mask) == 0:
+             continue
+
+         h_labels = labels[mask]
+         h_preds = pred_labels[mask]
+         h_accuracy = np.mean(h_labels == h_preds) * 100
+
+         benign_count = np.sum(h_labels == 0)
+         malicious_count = np.sum(h_labels == 1)
+
+         print(f"\nFor healthcare relevance >= {threshold}:")
+         print(f" URLs: {np.sum(mask)} ({benign_count} benign, {malicious_count} malicious)")
+         print(f" Accuracy: {h_accuracy:.2f}%")
+
+         # Calculate healthcare-specific metrics
+         tp = np.sum((h_labels == 1) & (h_preds == 1))
+         fp = np.sum((h_labels == 0) & (h_preds == 1))
+         fn = np.sum((h_labels == 1) & (h_preds == 0))
+
+         precision = tp / (tp + fp) if (tp + fp) > 0 else 0
+         recall = tp / (tp + fn) if (tp + fn) > 0 else 0
+         f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+         print(f" Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
+
+         # Calculate false positive rate for healthcare URLs
+         if benign_count > 0:
+             h_fpr = np.sum((h_labels == 0) & (h_preds == 1)) / benign_count
+             print(f" False Positive Rate: {h_fpr:.4f}")
+
+         # Calculate false negative rate for healthcare URLs
+         if malicious_count > 0:
+             h_fnr = np.sum((h_labels == 1) & (h_preds == 0)) / malicious_count
+             print(f" False Negative Rate: {h_fnr:.4f}")
+
+ # --- Main Function ---
+ def main():
+     """Main function to run the entire pipeline."""
+     # Configuration
+     batch_size = 32
+     learning_rate = 0.001
+     epochs = 20
+     test_size = 0.2
+     val_size = 0.2
+     random_seed = 42
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Filenames
+     huggingface_file = "urls.json"
+     phiusiil_file = "PhiUSIIL_Phishing_URL_Dataset.csv"
+     kaggle_file = "malicious_phish.csv"
+
+     # Load datasets
+     print("Loading datasets...")
+     huggingface_data = load_huggingface_data(huggingface_file)
+     phiusiil_data = load_phiusiil_data(phiusiil_file)
+     kaggle_data = load_kaggle_data(kaggle_file)
+
+     # Combine and deduplicate datasets
+     print("Combining and deduplicating datasets...")
+     urls, labels = combine_and_deduplicate([huggingface_data, phiusiil_data, kaggle_data])
+
+     # Extract features
+     print("Extracting features...")
+     features = extract_all_features(urls)
+
+     # Split into train, validation, and test sets
+     print("Splitting data...")
+     X_train_val, X_test, y_train_val, y_test = train_test_split(
+         features, labels, test_size=test_size, random_state=random_seed, stratify=labels
+     )
+
+     X_train, X_val, y_train, y_val = train_test_split(
+         X_train_val, y_train_val, test_size=val_size/(1-test_size),
+         random_state=random_seed, stratify=y_train_val
+     )
+
+     # Standardize features
+     print("Standardizing features...")
+     scaler = StandardScaler()
+     X_train = scaler.fit_transform(X_train)
+     X_val = scaler.transform(X_val)
+     X_test = scaler.transform(X_test)
+
+     # Create PyTorch datasets and dataloaders
+     print("Creating DataLoaders...")
+     train_dataset = URLDataset(X_train, y_train)
+     val_dataset = URLDataset(X_val, y_val)
+     test_dataset = URLDataset(X_test, y_test)
+
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
+
+     # Initialize and train model
+     print("Initializing model...")
+     input_size = features.shape[1]  # Number of features
+     model = PhishingMLP(input_size=input_size)
+
+     print("Training model...")
+     trained_model, train_losses, val_losses, val_accuracies = train_mlp(
+         model, train_loader, val_loader, epochs=epochs,
+         learning_rate=learning_rate, device=device
+     )
+
+     # Save trained model
+     print("Saving model...")
+     model_path = "phishing_mlp_model.pth"
+     torch.save(trained_model.state_dict(), model_path)
+     print(f"Model saved to {model_path}")
+
+     # Evaluate on test set
+     print("\nEvaluating model on test set...")
+     acc, prec, rec, f1, healthcare_acc = evaluate_model(trained_model, test_loader, device)
+
+     # Plot results
+     plot_training_results(train_losses, val_losses, val_accuracies)
+
+     # Further healthcare analysis
+     y_pred = []
+     trained_model.eval()
+     with torch.no_grad():
+         for inputs, _ in test_loader:
+             inputs = inputs.to(device)
+             outputs = trained_model(inputs)
+             # squeeze(1) keeps a 1-D batch even when the final batch has one sample
+             predicted = (outputs > 0.5).float().squeeze(1).cpu().numpy()
+             y_pred.extend(predicted.tolist())
+
+     # Note: X_test is standardized, so analyze_healthcare_features sees the
+     # scaled healthcare_relevance column rather than the raw 0-1 scores.
+     analyze_healthcare_features(X_test, np.array(y_test), np.array(y_pred))
+
+     # Show example URLs with high healthcare relevance (using raw features)
+     feature_names = get_feature_names()
+     healthcare_idx = feature_names.index('healthcare_relevance')
+     healthcare_scores = features[:, healthcare_idx]
+     high_healthcare = healthcare_scores >= 0.5
+
+     print("\n=== Healthcare URL Examples ===")
+     high_healthcare_indices = np.where(high_healthcare)[0][:5]  # Get first 5 indices
+     for idx in high_healthcare_indices:
+         print(f"URL: {urls[idx]}")
+         print(f"Healthcare Score: {healthcare_scores[idx]:.2f}")
+         print(f"Label: {'Malicious' if labels[idx] == 1 else 'Benign'}")
+         print()
+
+     # Summary
+     print("\n=== Summary ===")
+     print(f"Total URLs processed: {len(urls)}")
+     print(f"Training set: {len(X_train)} URLs")
+     print(f"Validation set: {len(X_val)} URLs")
+     print(f"Test set: {len(X_test)} URLs")
+     print(f"Model input features: {input_size}")
+     print(f"Test Accuracy: {acc:.2f}%")
+     print(f"Healthcare URL Accuracy: {healthcare_acc:.2f}%")
+     print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1-Score: {f1:.4f}")
+     print("\nTraining complete!")
+
+ if __name__ == "__main__":
+     main()
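
Note that main() above persists only the model weights, not the fitted StandardScaler, so the scaler must be saved separately (e.g. with joblib) or re-fitted before the model can score new URLs. A minimal inference sketch under that assumption; score_url is a hypothetical helper, not part of this upload:

    import numpy as np
    import torch

    def score_url(model, scaler, url):
        # Hypothetical helper: reuse extract_features() from training.py, apply
        # the StandardScaler fitted at training time, then run the MLP.
        feats = np.array([extract_features(url)], dtype=np.float32)
        feats = scaler.transform(feats)
        with torch.no_grad():
            prob = model(torch.tensor(feats, dtype=torch.float32)).item()
        return prob  # evaluate_model() treats prob > 0.5 as malicious
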
training_results.png ADDED