# qJFI's picture
# Create app.py
# 0cbb679 verified
import numpy as np
import xgboost as xgb
import gradio as gr
from scapy.all import rdpcap
from collections import defaultdict
import os
def transform_new_input(new_input):
    """Min-max scale a 13-element raw feature vector with fixed training bounds.

    Args:
        new_input: Sequence of 13 raw feature values in model input order.

    Returns:
        numpy.ndarray of scaled features; inputs outside the training
        bounds produce values outside [0, 1].
    """
    # Per-feature minimums observed in the training data.
    feature_min = np.array([
        1.0, 10.0, 856.0, 5775.0, 42.0, 26.0, 0.0, 278.0, 4.0, 1.0,
        -630355.0, 4.0, 50.0
    ])
    # Per-feature maximums observed in the training data.
    feature_max = np.array([
        4.0, 352752.0, 271591638.0, 239241314.0, 421552.0, 3317.0,
        6302708.0, 6302708.0, 5.0, 5.0, 1746749.0, 608.0, 1012128.0
    ])
    raw = np.asarray(new_input)
    return (raw - feature_min) / (feature_max - feature_min)
class PcapProcessor:
    """Accumulates per-port traffic statistics from a PCAP capture.

    The capture is read eagerly at construction; per-port counters are
    built by process_packets() and turned into 13-feature model rows.
    """

    def __init__(self, pcap_file):
        """Load the capture and initialize empty per-port counters.

        Args:
            pcap_file: Path to a .pcap/.pcapng file readable by scapy.
        """
        self.packets = rdpcap(pcap_file)
        self.start_time = None  # timestamp of the first packet, set in process_packets
        # Every port lazily gets its own fresh counter dict on first access.
        self.port_stats = defaultdict(lambda: {
            'rx_packets': 0,
            'rx_bytes': 0,
            'tx_packets': 0,
            'tx_bytes': 0,
            'first_seen': None,
            'last_seen': None,
            'active_flows': set(),
            'packets_matched': 0
        })

    def process_packets(self, window_size=60):
        """Walk every packet and return one feature row per active port.

        Args:
            window_size: Cap (seconds) applied to the delta-duration feature.

        Returns:
            List of 13-element feature lists, one per port that saw traffic;
            empty list when the capture holds no packets.
        """
        if not self.packets:
            return []
        self.start_time = float(self.packets[0].time)
        for pkt in self.packets:
            ts = float(pkt.time)
            # Only TCP/UDP traffic carries the port numbers we aggregate on.
            if not ('TCP' in pkt or 'UDP' in pkt):
                continue
            try:
                sport = pkt.sport
                dport = pkt.dport
                size = len(pkt)
                # A flow is identified by its (src IP, dst IP, src port, dst port).
                flow = (pkt['IP'].src, pkt['IP'].dst,
                        sport, dport)
                # Count the packet once for each side of the conversation.
                self._update_port_stats(sport, size, True, ts, flow)
                self._update_port_stats(dport, size, False, ts, flow)
            except Exception as e:
                # Malformed or non-IP packets are reported and skipped.
                print(f"Error processing packet {pkt}: {str(e)}")
                continue
        # Emit a feature row for every port that actually saw traffic.
        return [
            self._extract_port_features(port, stats, window_size)
            for port, stats in self.port_stats.items()
            if stats['first_seen'] is not None
        ]

    def _update_port_stats(self, port, pkt_size, is_source, current_time,
                           flow_tuple):
        """Fold one packet's contribution into a single port's counters.

        Args:
            port: Port number whose statistics are updated.
            pkt_size: Packet length in bytes.
            is_source: True when `port` was the packet's source port (TX side).
            current_time: Packet timestamp in seconds (float).
            flow_tuple: (src_ip, dst_ip, src_port, dst_port) flow identifier.
        """
        entry = self.port_stats[port]
        if entry['first_seen'] is None:
            entry['first_seen'] = current_time
        entry['last_seen'] = current_time
        # Source ports transmit; destination ports receive.
        if is_source:
            entry['tx_packets'] += 1
            entry['tx_bytes'] += pkt_size
        else:
            entry['rx_packets'] += 1
            entry['rx_bytes'] += pkt_size
        entry['active_flows'].add(flow_tuple)
        entry['packets_matched'] += 1

    def _extract_port_features(self, port, stats, window_size):
        """Build the 13-feature vector the IDS model expects for one port.

        Args:
            port: Port number the statistics belong to.
            stats: Counter dict produced by _update_port_stats.
            window_size: Cap (seconds) for the delta-duration feature.

        Returns:
            List of 13 numeric features in model input order.
        """
        alive = stats['last_seen'] - stats['first_seen']
        delta_alive = min(alive, window_size)
        # Bytes per second over the port's lifetime (floor of 1 s avoids /0).
        load = (stats['rx_bytes'] + stats['tx_bytes']) / max(alive, 1)
        return [
            min(port % 4 + 1, 4),           # Port Number (1-4)
            stats['rx_packets'],            # Received Packets
            stats['rx_bytes'],              # Received Bytes
            stats['tx_bytes'],              # Sent Bytes
            stats['tx_packets'],            # Sent Packets
            alive,                          # Port alive Duration
            stats['rx_bytes'],              # Delta Received Bytes (totals used as proxy)
            stats['tx_bytes'],              # Delta Sent Bytes (totals used as proxy)
            min(delta_alive, 5),            # Delta Port alive Duration
            min((port % 5) + 1, 5),         # Connection Point
            load,                           # Total Load/Rate
            len(stats['active_flows']),     # Active Flow Entries
            stats['packets_matched']        # Packets Matched
        ]
def process_pcap_for_ids(pcap_file):
    """Extract per-port IDS feature rows from a PCAP file.

    Args:
        pcap_file: Path to the capture file on disk.

    Returns:
        List of 13-element feature lists, one per port with traffic.
    """
    return PcapProcessor(pcap_file).process_packets()
def predict_from_features(features, model):
    """Classify one raw feature row with the XGBoost IDS model.

    Args:
        features: Sequence of 13 raw (unscaled) feature values.
        model: Loaded xgboost Booster returning per-class probabilities.

    Returns:
        (headline, details) message tuple from get_prediction_message.
    """
    # Scale into the model's training range, then wrap as a 1-row matrix.
    scaled = transform_new_input(features)
    dmatrix = xgb.DMatrix(scaled.reshape(1, -1))
    # First (only) row of the prediction is the class probability vector.
    probabilities = model.predict(dmatrix)[0]
    top_class = np.argmax(probabilities)
    # Confident "normal" prediction: report normal traffic.
    if top_class == 0 and probabilities[0] > 0.6:  # 60% confidence threshold
        return get_prediction_message(0)
    # No class is confident enough: default to normal rather than alert.
    if np.max(probabilities) < 0.4:  # Low confidence threshold
        return get_prediction_message(0)
    return get_prediction_message(top_class)
def get_prediction_message(prediction):
    """Map a predicted class index (0-4) to a (headline, details) pair.

    Args:
        prediction: Class index produced by the model.

    Returns:
        Tuple of (headline, details) strings; a generic "unknown" pair
        for any unrecognized class index.
    """
    catalog = {
        0: ("NORMAL TRAFFIC - No indication of attack.",
            "Traffic patterns appear to be within normal parameters."),
        1: ("ALERT: Potential BLACKHOLE attack detected.",
            "Information: BLACKHOLE attacks occur when a router maliciously drops "
            "packets it should forward. Investigate affected routes and traffic patterns."),
        2: ("ALERT: Potential TCP-SYN flood attack detected.",
            "Information: TCP-SYN flood is a DDoS attack exhausting server resources "
            "with half-open connections. Check connection states and implement SYN cookies."),
        3: ("ALERT: PORTSCAN activity detected.",
            "Information: Port scanning detected - systematic probing of system ports. "
            "Review firewall rules and implement connection rate limiting."),
        4: ("ALERT: Potential DIVERSION attack detected.",
            "Information: Traffic diversion detected. Verify routing integrity and "
            "check for signs of traffic manipulation or social engineering attempts.")
    }
    try:
        return catalog[prediction]
    except KeyError:
        return ("Unknown Traffic Pattern", "Additional analysis required.")
def process_pcap_input(pcap_file):
    """Analyze every traffic pattern in an uploaded PCAP file.

    Args:
        pcap_file: Gradio file object; its .name attribute is the on-disk path.

    Returns:
        Human-readable analysis text, or an error description on failure.
    """
    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")
        rows = process_pcap_for_ids(pcap_file.name)
        if not rows:
            return "No valid network traffic found in PCAP file."
        reports = []
        for idx, row in enumerate(rows):
            headline, details = predict_from_features(row, booster)
            reports.append(f"Traffic Pattern {idx + 1}:\n{headline}\n{details}\n")
        return "\n".join(reports)
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return f"Error processing PCAP file: {str(e)}"
def process_manual_input(port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                         port_duration, delta_rx_bytes, delta_tx_bytes,
                         delta_duration, conn_point, total_load, active_flows,
                         packets_matched):
    """Classify a single manually-entered 13-feature row.

    Parameters mirror the model's feature order exactly.

    Returns:
        Formatted prediction text, or an error description on failure.
    """
    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")
        row = [
            port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
            port_duration, delta_rx_bytes, delta_tx_bytes, delta_duration,
            conn_point, total_load, active_flows, packets_matched
        ]
        headline, details = predict_from_features(row, booster)
        return f"{headline}\n{details}"
    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return f"Error processing manual input: {str(e)}"
# Main execution — build and launch the Gradio UI when run as a script.
if __name__ == "__main__":
    # Create the interface: two tabs, PCAP upload and manual feature entry.
    with gr.Blocks(theme="default") as interface:
        gr.Markdown("""
        # Network Intrusion Detection System
        Upload a PCAP file or use manual input to detect potential network attacks.
        """)
        with gr.Tab("PCAP Analysis"):
            # Uploads restricted to capture formats scapy can read.
            pcap_input = gr.File(
                label="Upload PCAP File",
                file_types=[".pcap", ".pcapng"]
            )
            pcap_output = gr.Textbox(label="Analysis Results")
            pcap_button = gr.Button("Analyze PCAP")
            pcap_button.click(
                fn=process_pcap_input,
                inputs=[pcap_input],
                outputs=pcap_output
            )
        with gr.Tab("Manual Input"):
            # Manual input components — slider bounds roughly mirror the
            # training-data ranges used by transform_new_input.
            with gr.Row():
                port_num = gr.Slider(1, 4, value=1,
                    label="Port Number - The switch port through which the flow passed")
                rx_packets = gr.Slider(0, 352772, value=0,
                    label="Received Packets - Number of packets received by the port")
            with gr.Row():
                rx_bytes = gr.Slider(0, 2.715916e08, value=0,
                    label="Received Bytes - Number of bytes received by the port")
                tx_bytes = gr.Slider(0, 2.392430e08, value=0,
                    label="Sent Bytes - Number of bytes sent by the port")
            with gr.Row():
                tx_packets = gr.Slider(0, 421598, value=0,
                    label="Sent Packets - Number of packets sent by the port")
                port_duration = gr.Slider(0, 3317, value=0,
                    label="Port alive Duration (S) - The time port has been alive in seconds")
            with gr.Row():
                delta_rx_bytes = gr.Slider(0, 6500000, value=0,
                    label="Delta Received Bytes")
                delta_tx_bytes = gr.Slider(0, 6500000, value=0,
                    label="Delta Sent Bytes")
            with gr.Row():
                delta_duration = gr.Slider(0, 5, value=0,
                    label="Delta Port alive Duration (S)")
                conn_point = gr.Slider(1, 5, value=1,
                    label="Connection Point")
            with gr.Row():
                total_load = gr.Slider(0, 1800000, value=0,
                    label="Total Load/Rate")
                active_flows = gr.Slider(0, 610, value=0,
                    label="Active Flow Entries")
            with gr.Row():
                packets_matched = gr.Slider(0, 1020000, value=0,
                    label="Packets Matched")
            manual_output = gr.Textbox(label="Analysis Results")
            manual_button = gr.Button("Analyze Manual Input")
            # Connect manual input components in the model's feature order.
            manual_button.click(
                fn=process_manual_input,
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ],
                outputs=manual_output
            )
            # Example inputs — one preset row per slider set, same order
            # as the model features.
            gr.Examples(
                examples=[
                    [4, 350188, 14877116, 101354648, 159524, 2910, 278, 280,
                     5, 4, 0, 6, 667324],
                    [2, 2326, 12856942, 31777516, 2998, 2497, 560, 560,
                     5, 2, 0, 4, 7259],
                    [4, 150, 19774, 6475473, 3054, 166, 556, 6068,
                     5, 4, 502, 6, 7418],
                    [2, 209, 20671, 6316631, 274, 96, 3527, 2757949,
                     5, 2, 183877, 8, 90494],
                    [2, 1733, 37865130, 38063670, 3187, 2152, 0, 556,
                     5, 3, 0, 4, 14864]
                ],
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ]
            )
    # Launch the interface (blocks until the server is stopped).
    interface.launch()