import time
import subprocess

import joblib
import numpy as np
import pandas as pd
import pyshark
import requests
import scapy.all as scapy
import gradio as gr
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager

# Load the pre-trained model and the feature names it was trained on
model = joblib.load('extratrees.pkl')
all_features = joblib.load('featurenames.pkl')
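# For reference, a minimal sketch of how these two artifacts could have been
# produced offline with scikit-learn (hypothetical; the real training code is
# not part of this Space, and X_train/y_train stand in for the CIC-IDS-2017
# training split):
#
#   from sklearn.ensemble import ExtraTreesClassifier
#   clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
#   clf.fit(X_train, y_train)
#   joblib.dump(clf, 'extratrees.pkl')
#   joblib.dump(list(X_train.columns), 'featurenames.pkl')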
# Capture packets with tshark while a headless browser visits the URL
def capture_packets(url, capture_duration=30, capture_file="capture.pcap"):
    try:
        # Start tshark to capture HTTP/HTTPS/DNS traffic on all interfaces
        tshark_process = subprocess.Popen(
            ["tshark", "-i", "any", "-f", "tcp port 80 or tcp port 443 or port 53", "-w", capture_file],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Give tshark a moment to start capturing
        time.sleep(2)

        # Set up Chrome options for a headless, container-friendly browser
        chrome_options = Options()
        chrome_options.add_argument("--headless")
        chrome_options.add_argument("--no-sandbox")
        chrome_options.add_argument("--disable-dev-shm-usage")

        # Use Selenium to visit the URL (webdriver_manager installs the driver)
        service = Service(ChromeDriverManager().install())
        driver = webdriver.Chrome(service=service, options=chrome_options)
        try:
            driver.get(url)
            # Keep the page open so its traffic is captured
            time.sleep(capture_duration)
        finally:
            # Always close the browser, even if the page load fails
            driver.quit()

        # Stop tshark and wait for the capture file to be flushed
        tshark_process.terminate()
        tshark_process.wait()

        # Read the captured packets back with pyshark for detailed dissection
        packets = []
        cap = pyshark.FileCapture(capture_file)
        for packet in cap:
            packets.append(str(packet))
        cap.close()
        return packets
    except Exception as e:
        print(f"Error in capturing packets: {e}")
        return None
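# Example usage (commented out so the Space does not capture at import time;
# "https://example.com" is just a placeholder):
#
#   pkts = capture_packets("https://example.com", capture_duration=10)
#   if pkts:
#       print(f"Captured {len(pkts)} packets")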
# Function to extract flow-level features from a capture file
def extract_features(capture_file):
    try:
        cap = pyshark.FileCapture(capture_file)

        # Initialize every model feature to 0 so unobserved ones stay defined
        features = {feature: 0 for feature in all_features}
        total_packets = 0
        total_bytes = 0
        start_time = None
        end_time = None
        packet_lengths = []
        protocol_counts = {'TCP': 0, 'UDP': 0, 'ICMP': 0}
        tcp_flags = {'SYN': 0, 'ACK': 0, 'FIN': 0, 'RST': 0}

        for packet in cap:
            total_packets += 1
            total_bytes += int(packet.length)
            packet_lengths.append(int(packet.length))
            timestamp = float(packet.sniff_time.timestamp())
            if start_time is None:
                start_time = timestamp
            end_time = timestamp

            # Count protocols; tcp.flags is a hex bitmask (FIN=0x01, SYN=0x02,
            # RST=0x04, ACK=0x10), so test the bits rather than substrings
            if hasattr(packet, 'tcp'):
                protocol_counts['TCP'] += 1
                flags = int(str(packet.tcp.flags), 16)
                if flags & 0x02:
                    tcp_flags['SYN'] += 1
                if flags & 0x10:
                    tcp_flags['ACK'] += 1
                if flags & 0x01:
                    tcp_flags['FIN'] += 1
                if flags & 0x04:
                    tcp_flags['RST'] += 1
            elif hasattr(packet, 'udp'):
                protocol_counts['UDP'] += 1
            elif hasattr(packet, 'icmp'):
                protocol_counts['ICMP'] += 1
        cap.close()

        duration = end_time - start_time if start_time is not None and end_time is not None else 0

        # Populate the extracted features
        features.update({
            "Flow Duration": duration,
            "Total Packets": total_packets,
            "Total Bytes": total_bytes,
            "Fwd Packet Length Mean": np.mean(packet_lengths) if packet_lengths else 0,
            "Bwd Packet Length Mean": 0,  # No forward/backward distinction in this capture
            "Flow Bytes/s": total_bytes / duration if duration else 0,
            "Flow Packets/s": total_packets / duration if duration else 0,
            "Average Packet Size": np.mean(packet_lengths) if packet_lengths else 0,
            "Min Packet Size": min(packet_lengths) if packet_lengths else 0,
            "Max Packet Size": max(packet_lengths) if packet_lengths else 0,
            "Packet Length Variance": np.var(packet_lengths) if len(packet_lengths) > 1 else 0,
            "TCP Packets": protocol_counts['TCP'],
            "UDP Packets": protocol_counts['UDP'],
            "ICMP Packets": protocol_counts['ICMP'],
            "TCP SYN Flags": tcp_flags['SYN'],
            "TCP ACK Flags": tcp_flags['ACK'],
            "TCP FIN Flags": tcp_flags['FIN'],
            "TCP RST Flags": tcp_flags['RST']
        })
        return features
    except Exception as e:
        print(f"Error in extracting features: {e}")
        return None
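# Example usage (commented out; assumes a capture.pcap written by one of the
# capture helpers above):
#
#   feats = extract_features("capture.pcap")
#   if feats:
#       print(feats["Total Packets"], feats["Flow Bytes/s"])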
# Function to classify the extracted features with the model trained on CIC-IDS-2017
def compare_with_dataset(packet_features):
    # Align the extracted features with the model's expected feature order,
    # zero-filling anything that was not observed in the capture
    packet_features_series = pd.Series(packet_features)
    packet_features_series = packet_features_series.reindex(all_features, fill_value=0)
    # Predict with the loaded model (a single-row DataFrame keeps feature names)
    prediction = model.predict(pd.DataFrame([packet_features_series]))[0]
    return "benign" if prediction == 0 else "malicious"
# Analyze the URL and predict whether its traffic looks malicious
def analyze_url(url):
    try:
        # Sniff in the background with Scapy while the URL is requested, so the
        # capture actually contains the request's traffic (sniffing after the
        # request would miss it, and a fixed packet count could block forever)
        sniffer = scapy.AsyncSniffer()
        sniffer.start()
        requests.get(url, timeout=10)
        time.sleep(2)  # Let trailing packets (e.g. TLS teardown) arrive
        packets = sniffer.stop()
        capture_file = 'capture.pcap'
        scapy.wrpcap(capture_file, packets)

        # Extract features from the captured packets
        packet_features = extract_features(capture_file)
        if packet_features is not None:
            prediction = compare_with_dataset(packet_features)
            # Use pyshark/tshark to capture HTTP/HTTPS/DNS packet details
            http_dns_packets = capture_packets(url)
            captured_packets = [str(packet) for packet in packets]
            return prediction, {"scapy_packets": captured_packets, "http_dns_packets": http_dns_packets}
        else:
            return "Error in feature extraction", {}
    except Exception as e:
        return str(e), {}
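# Example usage (commented out): the full pipeline can also be run outside
# Gradio ("https://example.com" is a placeholder):
#
#   verdict, traces = analyze_url("https://example.com")
#   print(verdict)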
# Define the Gradio interface
iface = gr.Interface(
    fn=analyze_url,
    inputs=gr.Textbox(label="Enter URL"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Captured Packets")],
    title="URL Malicious Activity Detection",
    description="Enter a URL to predict if it's malicious or benign by analyzing the network traffic it generates."
)

# Launch the interface
iface.launch(debug=True)