# app.py — Network Intrusion Detection System (Gradio app)
# Provenance: Hugging Face Space file "app.py" (commit 0cbb679, ~13 kB)
import numpy as np
import xgboost as xgb
import gradio as gr
from scapy.all import rdpcap
from collections import defaultdict
import os
def transform_new_input(new_input):
    """Min-max scale a 13-element feature vector with fixed training bounds.

    Args:
        new_input: sequence of 13 raw feature values in model feature order.

    Returns:
        numpy array of features scaled so that the training-set minimum maps
        to 0.0 and the training-set maximum maps to 1.0 (values outside the
        training range fall outside [0, 1]).
    """
    # Per-feature minimum and maximum observed in the training data.
    feature_min = np.array([
        1.0, 10.0, 856.0, 5775.0, 42.0, 26.0, 0.0, 278.0, 4.0, 1.0,
        -630355.0, 4.0, 50.0
    ])
    feature_max = np.array([
        4.0, 352752.0, 271591638.0, 239241314.0, 421552.0, 3317.0,
        6302708.0, 6302708.0, 5.0, 5.0, 1746749.0, 608.0, 1012128.0
    ])
    raw = np.array(new_input)
    return (raw - feature_min) / (feature_max - feature_min)
class PcapProcessor:
    """Accumulates per-port traffic statistics from a PCAP capture and
    derives the 13-feature vectors consumed by the IDS model."""

    def __init__(self, pcap_file):
        """Initialize PCAP processor with file path.

        Args:
            pcap_file: path to a .pcap/.pcapng file readable by scapy.
        """
        # NOTE(review): rdpcap loads the entire capture into memory —
        # confirm expected capture sizes are acceptable.
        self.packets = rdpcap(pcap_file)
        # Timestamp of the first packet; assigned in process_packets().
        self.start_time = None
        # Lazily-created statistics record for every port number observed.
        self.port_stats = defaultdict(lambda: {
            'rx_packets': 0,        # packets received by this port
            'rx_bytes': 0,          # bytes received by this port
            'tx_packets': 0,        # packets sent from this port
            'tx_bytes': 0,          # bytes sent from this port
            'first_seen': None,     # timestamp of first packet touching port
            'last_seen': None,      # timestamp of most recent packet
            'active_flows': set(),  # distinct (src IP, dst IP, sport, dport)
            'packets_matched': 0    # total packets attributed to this port
        })

    def process_packets(self, window_size=60):
        """Process all packets and extract one feature row per active port.

        Args:
            window_size: cap in seconds applied to the delta-duration
                feature (see _extract_port_features).

        Returns:
            List of 13-element feature lists, one per port that saw traffic;
            an empty list when the capture holds no packets.
        """
        if not self.packets:
            return []
        self.start_time = float(self.packets[0].time)
        # Process each packet
        for packet in self.packets:
            current_time = float(packet.time)
            if 'TCP' in packet or 'UDP' in packet:
                try:
                    src_port = packet.sport
                    dst_port = packet.dport
                    pkt_size = len(packet)
                    # Track flow information.
                    # NOTE(review): assumes an IPv4 ('IP') layer is present;
                    # IPv6 TCP/UDP packets raise here and are skipped below.
                    flow_tuple = (packet['IP'].src, packet['IP'].dst,
                                  src_port, dst_port)
                    # Update port statistics: the source port counts the
                    # packet as transmitted, the destination port counts the
                    # same packet as received.
                    self._update_port_stats(src_port, pkt_size, True,
                                            current_time, flow_tuple)
                    self._update_port_stats(dst_port, pkt_size, False,
                                            current_time, flow_tuple)
                except Exception as e:
                    # Best-effort: malformed/unsupported packets are logged
                    # and skipped rather than aborting the whole capture.
                    print(f"Error processing packet {packet}: {str(e)}")
                    continue
        # Extract features for each port that actually saw traffic.
        features_list = []
        for port, stats in self.port_stats.items():
            if stats['first_seen'] is not None:
                features = self._extract_port_features(port, stats, window_size)
                features_list.append(features)
        return features_list

    def _update_port_stats(self, port, pkt_size, is_source, current_time,
                           flow_tuple):
        """Update statistics for a given port.

        Args:
            port: TCP/UDP port number whose record is updated.
            pkt_size: packet length in bytes.
            is_source: True when `port` was the packet's source port
                (counts as transmit), False for destination (receive).
            current_time: packet timestamp in seconds.
            flow_tuple: (src IP, dst IP, sport, dport) identifying the flow.
        """
        stats = self.port_stats[port]
        if stats['first_seen'] is None:
            stats['first_seen'] = current_time
        stats['last_seen'] = current_time
        if is_source:
            stats['tx_packets'] += 1
            stats['tx_bytes'] += pkt_size
        else:
            stats['rx_packets'] += 1
            stats['rx_bytes'] += pkt_size
        stats['active_flows'].add(flow_tuple)
        stats['packets_matched'] += 1

    def _extract_port_features(self, port, stats, window_size):
        """Extract the 13 features needed for the IDS model.

        NOTE(review): the "Delta" features below reuse the cumulative byte
        counters rather than true per-window deltas — confirm this matches
        how the model was trained.
        """
        port_alive_duration = stats['last_seen'] - stats['first_seen']
        delta_alive_duration = min(port_alive_duration, window_size)
        # Calculate rates and loads; max(..., 1) guards division by zero
        # for ports whose first and last packet share a timestamp.
        total_load = (stats['rx_bytes'] + stats['tx_bytes']) / \
            max(port_alive_duration, 1)
        features = [
            min(port % 4 + 1, 4),           # Port Number (1-4)
            stats['rx_packets'],            # Received Packets
            stats['rx_bytes'],              # Received Bytes
            stats['tx_bytes'],              # Sent Bytes
            stats['tx_packets'],            # Sent Packets
            port_alive_duration,            # Port alive Duration
            stats['rx_bytes'],              # Delta Received Bytes
            stats['tx_bytes'],              # Delta Sent Bytes
            min(delta_alive_duration, 5),   # Delta Port alive Duration
            min((port % 5) + 1, 5),         # Connection Point
            total_load,                     # Total Load/Rate
            len(stats['active_flows']),     # Active Flow Entries
            stats['packets_matched']        # Packets Matched
        ]
        return features
def process_pcap_for_ids(pcap_file):
    """Extract IDS feature vectors from every active port in a PCAP file.

    Args:
        pcap_file: path to the capture file.

    Returns:
        List of 13-element feature lists, one per port with traffic.
    """
    return PcapProcessor(pcap_file).process_packets()
def predict_from_features(features, model):
    """Classify one raw 13-feature traffic record.

    Args:
        features: unscaled 13-element feature vector.
        model: loaded xgb.Booster returning a class-probability vector.

    Returns:
        (headline, details) tuple from get_prediction_message().
    """
    # Scale, then wrap as a single-row DMatrix for the booster.
    scaled = transform_new_input(features)
    row = xgb.DMatrix(scaled.reshape(1, -1))

    probabilities = model.predict(row)[0]
    prediction = np.argmax(probabilities)

    # Report normal traffic when the model is confident it is normal
    # (>60% for class 0), or when no class reaches 40% confidence.
    confident_normal = prediction == 0 and probabilities[0] > 0.6
    low_confidence = np.max(probabilities) < 0.4
    if confident_normal or low_confidence:
        return get_prediction_message(0)
    return get_prediction_message(prediction)
def get_prediction_message(prediction):
    """Map a predicted class label to a human-readable message pair.

    Args:
        prediction: integer class label produced by the IDS model (0-4).

    Returns:
        (headline, details) tuple; a generic fallback for unknown labels.
    """
    known_messages = {
        0: ("NORMAL TRAFFIC - No indication of attack.",
            "Traffic patterns appear to be within normal parameters."),
        1: ("ALERT: Potential BLACKHOLE attack detected.",
            "Information: BLACKHOLE attacks occur when a router maliciously drops "
            "packets it should forward. Investigate affected routes and traffic patterns."),
        2: ("ALERT: Potential TCP-SYN flood attack detected.",
            "Information: TCP-SYN flood is a DDoS attack exhausting server resources "
            "with half-open connections. Check connection states and implement SYN cookies."),
        3: ("ALERT: PORTSCAN activity detected.",
            "Information: Port scanning detected - systematic probing of system ports. "
            "Review firewall rules and implement connection rate limiting."),
        4: ("ALERT: Potential DIVERSION attack detected.",
            "Information: Traffic diversion detected. Verify routing integrity and "
            "check for signs of traffic manipulation or social engineering attempts.")
    }
    if prediction in known_messages:
        return known_messages[prediction]
    return ("Unknown Traffic Pattern", "Additional analysis required.")
def process_pcap_input(pcap_file):
    """Analyze an uploaded PCAP file and report one prediction per port.

    Args:
        pcap_file: gradio file object exposing a `.name` path.

    Returns:
        Human-readable analysis text, or an error description on failure.
    """
    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")

        extracted = process_pcap_for_ids(pcap_file.name)
        if not extracted:
            return "No valid network traffic found in PCAP file."

        report_sections = []
        for idx, feats in enumerate(extracted, start=1):
            headline, info = predict_from_features(feats, booster)
            report_sections.append(f"Traffic Pattern {idx}:\n{headline}\n{info}\n")
        return "\n".join(report_sections)
    except Exception as e:
        return f"Error processing PCAP file: {str(e)}"
def process_manual_input(port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                         port_duration, delta_rx_bytes, delta_tx_bytes,
                         delta_duration, conn_point, total_load, active_flows,
                         packets_matched):
    """Run the IDS model on thirteen manually entered feature values.

    Returns:
        "headline\\ndetails" string, or an error description on failure.
    """
    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")

        # Assemble the raw feature vector in model feature order.
        raw_features = [
            port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
            port_duration, delta_rx_bytes, delta_tx_bytes, delta_duration,
            conn_point, total_load, active_flows, packets_matched
        ]
        headline, info = predict_from_features(raw_features, booster)
        return f"{headline}\n{info}"
    except Exception as e:
        return f"Error processing manual input: {str(e)}"
# Main execution: build and launch the Gradio UI.
if __name__ == "__main__":
    # Create the interface: one tab for PCAP uploads, one for manual entry.
    with gr.Blocks(theme="default") as interface:
        gr.Markdown("""
        # Network Intrusion Detection System
        Upload a PCAP file or use manual input to detect potential network attacks.
        """)
        with gr.Tab("PCAP Analysis"):
            pcap_input = gr.File(
                label="Upload PCAP File",
                file_types=[".pcap", ".pcapng"]
            )
            pcap_output = gr.Textbox(label="Analysis Results")
            pcap_button = gr.Button("Analyze PCAP")
            pcap_button.click(
                fn=process_pcap_input,
                inputs=[pcap_input],
                outputs=pcap_output
            )
        with gr.Tab("Manual Input"):
            # Manual input components — one slider per model feature.
            # Slider bounds roughly mirror the training min/max used in
            # transform_new_input.
            with gr.Row():
                port_num = gr.Slider(1, 4, value=1,
                    label="Port Number - The switch port through which the flow passed")
                rx_packets = gr.Slider(0, 352772, value=0,
                    label="Received Packets - Number of packets received by the port")
            with gr.Row():
                rx_bytes = gr.Slider(0, 2.715916e08, value=0,
                    label="Received Bytes - Number of bytes received by the port")
                tx_bytes = gr.Slider(0, 2.392430e08, value=0,
                    label="Sent Bytes - Number of bytes sent by the port")
            with gr.Row():
                tx_packets = gr.Slider(0, 421598, value=0,
                    label="Sent Packets - Number of packets sent by the port")
                port_duration = gr.Slider(0, 3317, value=0,
                    label="Port alive Duration (S) - The time port has been alive in seconds")
            with gr.Row():
                delta_rx_bytes = gr.Slider(0, 6500000, value=0,
                    label="Delta Received Bytes")
                delta_tx_bytes = gr.Slider(0, 6500000, value=0,
                    label="Delta Sent Bytes")
            with gr.Row():
                delta_duration = gr.Slider(0, 5, value=0,
                    label="Delta Port alive Duration (S)")
                conn_point = gr.Slider(1, 5, value=1,
                    label="Connection Point")
            with gr.Row():
                total_load = gr.Slider(0, 1800000, value=0,
                    label="Total Load/Rate")
                active_flows = gr.Slider(0, 610, value=0,
                    label="Active Flow Entries")
            with gr.Row():
                packets_matched = gr.Slider(0, 1020000, value=0,
                    label="Packets Matched")
            manual_output = gr.Textbox(label="Analysis Results")
            manual_button = gr.Button("Analyze Manual Input")
            # Connect manual input components (order must match the
            # process_manual_input parameter order).
            manual_button.click(
                fn=process_manual_input,
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ],
                outputs=manual_output
            )
            # Example inputs: pre-filled feature rows users can load.
            gr.Examples(
                examples=[
                    [4, 350188, 14877116, 101354648, 159524, 2910, 278, 280,
                     5, 4, 0, 6, 667324],
                    [2, 2326, 12856942, 31777516, 2998, 2497, 560, 560,
                     5, 2, 0, 4, 7259],
                    [4, 150, 19774, 6475473, 3054, 166, 556, 6068,
                     5, 4, 502, 6, 7418],
                    [2, 209, 20671, 6316631, 274, 96, 3527, 2757949,
                     5, 2, 183877, 8, 90494],
                    [2, 1733, 37865130, 38063670, 3187, 2152, 0, 556,
                     5, 3, 0, 4, 14864]
                ],
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ]
            )
    # Launch the interface (blocking call serving the web UI).
    interface.launch()