# project-final/app.py
import time
import subprocess
import pyshark
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
from selenium.webdriver.chrome.options import Options
import numpy as np
import joblib
import pandas as pd
import scapy.all as scapy
import requests
import gradio as gr
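# Runtime dependencies (inferred from the imports above; exact versions are an
# assumption):
#   pip install pyshark selenium webdriver-manager scapy requests gradio \
#       scikit-learn joblib pandas numpy
# System packages: tshark (the Wireshark CLI) and a Chrome/Chromium binary.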
# Load the pre-trained model and feature names
model = joblib.load('extratrees.pkl')
all_features = joblib.load('featurenames.pkl')
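# A minimal sketch of how these artifacts could have been produced offline
# (assumes a CIC-IDS-2017 feature DataFrame X and label vector y; the training
# step is not part of this app):
#
#   from sklearn.ensemble import ExtraTreesClassifier
#   clf = ExtraTreesClassifier(n_estimators=100, random_state=42).fit(X, y)
#   joblib.dump(clf, 'extratrees.pkl')
#   joblib.dump(list(X.columns), 'featurenames.pkl')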
# Capture traffic while a headless browser visits the URL; raise
# capture_duration (seconds) for a longer capture window.
def capture_packets(url, capture_duration=30, capture_file="capture.pcap"):
    try:
        # Start tshark to capture packets
        tshark_process = subprocess.Popen(
            ["tshark", "-i", "any", "-f", "tcp port 80 or tcp port 443 or port 53", "-w", capture_file],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Wait for tshark to start
        time.sleep(2)

        # Set up Chrome options
        chrome_options = Options()
        chrome_options.add_argument("--headless")  # Run Chrome in headless mode
        chrome_options.add_argument("--no-sandbox")
        chrome_options.add_argument("--disable-dev-shm-usage")

        # Use Selenium to visit the URL
        service = Service(ChromeDriverManager().install())  # Installs the driver if missing
        driver = webdriver.Chrome(service=service, options=chrome_options)
        driver.get(url)

        # Capture packets for the specified duration
        time.sleep(capture_duration)

        # Close the browser
        driver.quit()

        # Stop tshark
        tshark_process.terminate()
        tshark_process.wait()

        # Read captured packets using pyshark for detailed packet information
        packets = []
        cap = pyshark.FileCapture(capture_file)
        for packet in cap:
            packets.append(str(packet))
        cap.close()
        return packets
    except Exception as e:
        print(f"Error in capturing packets: {e}")
        return None
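# Usage sketch (assumes tshark is on PATH and the process has capture
# privileges on the "any" interface):
#
#   pkts = capture_packets("https://example.com", capture_duration=10)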
# Function to extract flow-level features from captured packets
def extract_features(capture_file):
    try:
        cap = pyshark.FileCapture(capture_file)

        # Initialize every model feature to zero
        features = {feature: 0 for feature in all_features}
        total_packets = 0
        total_bytes = 0
        start_time = None
        end_time = None
        packet_lengths = []
        protocol_counts = {'TCP': 0, 'UDP': 0, 'ICMP': 0}
        tcp_flags = {'SYN': 0, 'ACK': 0, 'FIN': 0, 'RST': 0}

        for packet in cap:
            total_packets += 1
            total_bytes += int(packet.length)
            packet_lengths.append(int(packet.length))
            timestamp = packet.sniff_time.timestamp()
            if start_time is None:
                start_time = timestamp
            end_time = timestamp

            # Count protocols and TCP flags. pyshark exposes tcp.flags as a
            # hex string (e.g. '0x0012'), so test the individual flag bits
            # rather than searching for substrings like 'SYN'.
            if hasattr(packet, 'tcp'):
                protocol_counts['TCP'] += 1
                flags = int(str(packet.tcp.flags), 16)
                if flags & 0x02:  # SYN
                    tcp_flags['SYN'] += 1
                if flags & 0x10:  # ACK
                    tcp_flags['ACK'] += 1
                if flags & 0x01:  # FIN
                    tcp_flags['FIN'] += 1
                if flags & 0x04:  # RST
                    tcp_flags['RST'] += 1
            elif hasattr(packet, 'udp'):
                protocol_counts['UDP'] += 1
            elif hasattr(packet, 'icmp'):
                protocol_counts['ICMP'] += 1
        cap.close()

        duration = end_time - start_time if start_time is not None else 0

        # Populate the extracted features
        features.update({
            "Flow Duration": duration,
            "Total Packets": total_packets,
            "Total Bytes": total_bytes,
            "Fwd Packet Length Mean": np.mean(packet_lengths) if packet_lengths else 0,
            "Bwd Packet Length Mean": 0,  # No forward/backward distinction made here
            "Flow Bytes/s": total_bytes / duration if duration else 0,
            "Flow Packets/s": total_packets / duration if duration else 0,
            "Average Packet Size": np.mean(packet_lengths) if packet_lengths else 0,
            "Min Packet Size": min(packet_lengths) if packet_lengths else 0,
            "Max Packet Size": max(packet_lengths) if packet_lengths else 0,
            "Packet Length Variance": np.var(packet_lengths) if len(packet_lengths) > 1 else 0,
            "TCP Packets": protocol_counts['TCP'],
            "UDP Packets": protocol_counts['UDP'],
            "ICMP Packets": protocol_counts['ICMP'],
            "TCP SYN Flags": tcp_flags['SYN'],
            "TCP ACK Flags": tcp_flags['ACK'],
            "TCP FIN Flags": tcp_flags['FIN'],
            "TCP RST Flags": tcp_flags['RST']
        })
        return features
    except Exception as e:
        print(f"Error in extracting features: {e}")
        return None
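# Usage sketch (assumes a pcap written by capture_packets or Scapy exists):
#
#   feats = extract_features("capture.pcap")
#   if feats:
#       print(feats["Total Packets"], feats["Flow Bytes/s"])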
# Function to compare features against the model trained on CIC-IDS-2017
def compare_with_dataset(packet_features):
    # Align the extracted features with the training feature order; anything
    # the model expects but we did not compute is zero-filled.
    packet_features_series = pd.Series(packet_features)
    packet_features_series = packet_features_series.reindex(all_features, fill_value=0)

    # Predict with a single-row DataFrame so column names match training
    prediction = model.predict(pd.DataFrame([packet_features_series]))[0]
    return "benign" if prediction == 0 else "malicious"
# Analyze the URL and predict whether its traffic looks malicious
def analyze_url(url):
    try:
        # Start sniffing with Scapy *before* issuing the request so the
        # request's own traffic is included in the capture.
        sniffer = scapy.AsyncSniffer()
        sniffer.start()
        requests.get(url, timeout=10)  # generate traffic to the target URL
        time.sleep(2)  # let trailing packets arrive
        packets = sniffer.stop()
        capture_file = 'capture.pcap'
        scapy.wrpcap(capture_file, packets)

        # Extract features from the captured packets
        packet_features = extract_features(capture_file)
        if packet_features is not None:
            prediction = compare_with_dataset(packet_features)
            # Use pyshark to capture HTTP/HTTPS/DNS packet details
            http_dns_packets = capture_packets(url)
            captured_packets = [str(packet) for packet in packets]
            return prediction, {"scapy_packets": captured_packets, "http_dns_packets": http_dns_packets}
        else:
            return "Error in feature extraction", {}
    except Exception as e:
        return str(e), {}
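# Usage sketch:
#
#   label, details = analyze_url("https://example.com")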
# Define the Gradio interface
iface = gr.Interface(
    fn=analyze_url,
    inputs=gr.Textbox(label="Enter URL"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Captured Packets")],
    title="URL Malicious Activity Detection",
    description="Enter a URL to predict if it's malicious or benign by analyzing the network traffic."
)

# Launch the interface
iface.launch(debug=True)
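# Note: raw packet capture (tshark on "any", Scapy sniffing) generally requires
# root or CAP_NET_RAW. When running inside a container, you may also need
# iface.launch(server_name="0.0.0.0") so Gradio binds to a reachable interface.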