# Spaces: Sleeping  (page-header residue captured together with the source; not part of the program)
import numpy as np | |
import xgboost as xgb | |
import gradio as gr | |
from scapy.all import rdpcap | |
from collections import defaultdict | |
import os | |
def transform_new_input(new_input):
    """Min-max scale a feature vector with fixed per-feature bounds.

    Each value is mapped with ``(x - low) / (high - low)`` using the 13
    hard-coded lower/upper bounds below. Nothing is clipped, so inputs outside
    the bounds produce results outside ``[0, 1]``.

    Args:
        new_input: Sequence or array of 13 numeric feature values, ordered like
            the bounds.

    Returns:
        numpy.ndarray containing the scaled values.
    """
    lower_bounds = np.array([
        1.0, 10.0, 856.0, 5775.0, 42.0, 26.0, 0.0, 278.0, 4.0, 1.0,
        -630355.0, 4.0, 50.0,
    ])
    upper_bounds = np.array([
        4.0, 352752.0, 271591638.0, 239241314.0, 421552.0, 3317.0,
        6302708.0, 6302708.0, 5.0, 5.0, 1746749.0, 608.0, 1012128.0,
    ])

    values = np.array(new_input)
    feature_span = upper_bounds - lower_bounds
    return (values - lower_bounds) / feature_span
class PcapProcessor:
    """Build per-port traffic statistics from the packets of a PCAP file.

    Every TCP/UDP-over-IPv4 packet is counted twice: as transmitted traffic for
    its source port and as received traffic for its destination port.
    ``process_packets`` then turns each port's totals into the 13-value
    feature list used by the IDS model.
    """

    def __init__(self, pcap_file):
        """Read all packets from ``pcap_file`` and start with empty statistics.

        Args:
            pcap_file: Path of the capture file handed to ``rdpcap``.
        """
        self.packets = rdpcap(pcap_file)
        self.start_time = None
        self.port_stats = defaultdict(self._new_port_stats)

    @staticmethod
    def _new_port_stats():
        """Return a fresh, zeroed statistics record for a single port."""
        return {
            'rx_packets': 0,
            'rx_bytes': 0,
            'tx_packets': 0,
            'tx_bytes': 0,
            'first_seen': None,
            'last_seen': None,
            'active_flows': set(),
            'packets_matched': 0,
        }

    def process_packets(self, window_size=60):
        """Aggregate all packets and return one feature list per observed port.

        Statistics are rebuilt from scratch on every call, so calling this
        method repeatedly yields the same result instead of double-counting.
        Packets without a TCP/UDP layer, or without an IPv4 layer (e.g. IPv6),
        are skipped silently.

        Args:
            window_size: Upper bound (seconds) applied to the port alive
                duration before it is used as the "delta" duration feature.

        Returns:
            List of 13-element feature lists, in order of first appearance of
            each port; empty when the capture holds no packets.
        """
        if not self.packets:
            return []

        # Reset so a second call does not add the same packets again.
        self.port_stats = defaultdict(self._new_port_stats)
        self.start_time = float(self.packets[0].time)

        for packet in self.packets:
            if not ('TCP' in packet or 'UDP' in packet):
                continue
            if 'IP' not in packet:
                # The flow tuple needs IPv4 addresses; other carriers are ignored.
                continue

            # Read every field first so a malformed packet cannot leave the
            # statistics half-updated.
            try:
                current_time = float(packet.time)
                src_port = packet.sport
                dst_port = packet.dport
                pkt_size = len(packet)
                ip_layer = packet['IP']
                flow_tuple = (ip_layer.src, ip_layer.dst, src_port, dst_port)
            except (AttributeError, IndexError, KeyError) as e:
                print(f"Error processing packet {packet}: {str(e)}")
                continue

            self._update_port_stats(src_port, pkt_size, True,
                                    current_time, flow_tuple)
            self._update_port_stats(dst_port, pkt_size, False,
                                    current_time, flow_tuple)

        return [
            self._extract_port_features(port, stats, window_size)
            for port, stats in self.port_stats.items()
            if stats['first_seen'] is not None
        ]

    def _update_port_stats(self, port, pkt_size, is_source, current_time,
                           flow_tuple):
        """Fold one packet into the statistics record of ``port``.

        Args:
            port: Port number whose record is updated.
            pkt_size: Packet length in bytes.
            is_source: True when ``port`` is the packet's source port (counts
                as transmitted), False when it is the destination (received).
            current_time: Packet timestamp as a float.
            flow_tuple: Hashable flow identifier added to the port's flow set.
        """
        stats = self.port_stats[port]

        if stats['first_seen'] is None:
            stats['first_seen'] = current_time
        stats['last_seen'] = current_time

        direction = 'tx' if is_source else 'rx'
        stats[f'{direction}_packets'] += 1
        stats[f'{direction}_bytes'] += pkt_size

        stats['active_flows'].add(flow_tuple)
        stats['packets_matched'] += 1

    def _extract_port_features(self, port, stats, window_size):
        """Convert one port's statistics into the 13 model features.

        Args:
            port: Port number the statistics belong to.
            stats: Statistics record as maintained by ``_update_port_stats``.
            window_size: Cap (seconds) applied to the alive duration before it
                is further limited to 5 for the delta-duration feature.

        Returns:
            List of 13 values in the order expected by the feature scaler.
        """
        port_alive_duration = stats['last_seen'] - stats['first_seen']
        delta_alive_duration = min(port_alive_duration, window_size)

        # Durations below one second are treated as one second, which also
        # avoids dividing by zero for ports seen in a single packet.
        total_load = (stats['rx_bytes'] + stats['tx_bytes']) / \
            max(port_alive_duration, 1)

        # NOTE: no per-interval snapshot is kept, so the "delta" byte features
        # reuse the cumulative byte totals.
        return [
            port % 4 + 1,                    # Port Number (1-4)
            stats['rx_packets'],             # Received Packets
            stats['rx_bytes'],               # Received Bytes
            stats['tx_bytes'],               # Sent Bytes
            stats['tx_packets'],             # Sent Packets
            port_alive_duration,             # Port alive Duration
            stats['rx_bytes'],               # Delta Received Bytes
            stats['tx_bytes'],               # Delta Sent Bytes
            min(delta_alive_duration, 5),    # Delta Port alive Duration
            port % 5 + 1,                    # Connection Point (1-5)
            total_load,                      # Total Load/Rate
            len(stats['active_flows']),      # Active Flow Entries
            stats['packets_matched'],        # Packets Matched
        ]
def process_pcap_for_ids(pcap_file):
    """Extract IDS feature lists from a PCAP file.

    Args:
        pcap_file: Path of the capture file to analyse.

    Returns:
        Whatever ``PcapProcessor.process_packets`` produces with its default
        window size: one feature list per observed port.
    """
    return PcapProcessor(pcap_file).process_packets()
def predict_from_features(features, model):
    """Classify one feature vector and return the matching message pair.

    The features are scaled, wrapped in a single-row ``DMatrix`` and passed to
    ``model.predict``. The first row of the output is treated as a per-class
    score vector (presumably class probabilities - confirm against the model's
    training objective).

    Args:
        features: The 13 raw feature values for one traffic pattern.
        model: Loaded booster exposing ``predict``.

    Returns:
        The ``(headline, details)`` tuple from ``get_prediction_message``.
    """
    normalized = transform_new_input(features)
    single_row = xgb.DMatrix(normalized.reshape(1, -1))

    class_scores = model.predict(single_row)[0]
    best_class = np.argmax(class_scores)

    # Normal traffic is reported when class 0 wins with more than 60% score,
    # or when no class reaches 40%.
    confident_normal = best_class == 0 and class_scores[0] > 0.6
    if confident_normal or np.max(class_scores) < 0.4:
        return get_prediction_message(0)

    return get_prediction_message(best_class)
def get_prediction_message(prediction):
    """Map a predicted class index to a ``(headline, details)`` text pair.

    Args:
        prediction: Class index; 0 through 4 have dedicated messages.

    Returns:
        Tuple of two strings. Any other index yields a generic
        "unknown pattern" pair.
    """
    # Position in this tuple is the class index.
    class_texts = (
        ("NORMAL TRAFFIC - No indication of attack.",
         "Traffic patterns appear to be within normal parameters."),
        ("ALERT: Potential BLACKHOLE attack detected.",
         "Information: BLACKHOLE attacks occur when a router maliciously drops "
         "packets it should forward. Investigate affected routes and traffic patterns."),
        ("ALERT: Potential TCP-SYN flood attack detected.",
         "Information: TCP-SYN flood is a DDoS attack exhausting server resources "
         "with half-open connections. Check connection states and implement SYN cookies."),
        ("ALERT: PORTSCAN activity detected.",
         "Information: Port scanning detected - systematic probing of system ports. "
         "Review firewall rules and implement connection rate limiting."),
        ("ALERT: Potential DIVERSION attack detected.",
         "Information: Traffic diversion detected. Verify routing integrity and "
         "check for signs of traffic manipulation or social engineering attempts."),
    )
    by_class = dict(enumerate(class_texts))

    try:
        return by_class[prediction]
    except KeyError:
        return ("Unknown Traffic Pattern", "Additional analysis required.")
def process_pcap_input(pcap_file):
    """Analyse an uploaded PCAP file and return a human-readable report.

    Args:
        pcap_file: The uploaded file as delivered by the UI: either an object
            exposing the file path as ``.name`` or the path string itself.
            ``None`` means nothing was uploaded.

    Returns:
        A string with one "Traffic Pattern N" section per extracted feature
        list, or an explanatory message when no file was given, no usable
        traffic was found, or processing failed.
    """
    # Check the input before loading the model so a missing upload gets a
    # clear message instead of an attribute error on ``None``.
    if pcap_file is None:
        return "No PCAP file provided. Please upload a .pcap or .pcapng file."

    try:
        # Fall back to the value itself when it is already a plain path.
        pcap_path = getattr(pcap_file, "name", pcap_file)

        model = xgb.Booster()
        model.load_model("m3_xg_boost.model")

        features_list = process_pcap_for_ids(pcap_path)
        if not features_list:
            return "No valid network traffic found in PCAP file."

        results = []
        for number, features in enumerate(features_list, start=1):
            result_msg, result_info = predict_from_features(features, model)
            results.append(
                f"Traffic Pattern {number}:\n{result_msg}\n{result_info}\n"
            )
        return "\n".join(results)
    except Exception as e:
        # Top-level UI boundary: report the failure in the output box rather
        # than letting the request crash.
        return f"Error processing PCAP file: {str(e)}"
def process_manual_input(port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                         port_duration, delta_rx_bytes, delta_tx_bytes,
                         delta_duration, conn_point, total_load, active_flows,
                         packets_matched):
    """Classify a manually entered feature set and return the result text.

    The 13 arguments are the model features in the order the scaler expects.
    The model is loaded from ``m3_xg_boost.model`` on every call.

    Returns:
        The headline and details joined by a newline, or an
        "Error processing manual input: ..." string if anything fails.
    """
    feature_vector = [
        port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
        port_duration, delta_rx_bytes, delta_tx_bytes, delta_duration,
        conn_point, total_load, active_flows, packets_matched,
    ]

    try:
        booster = xgb.Booster()
        booster.load_model("m3_xg_boost.model")
        headline, details = predict_from_features(feature_vector, booster)
        return f"{headline}\n{details}"
    except Exception as e:
        return f"Error processing manual input: {str(e)}"
# Script entry point: build the two-tab Gradio UI and start serving it.
if __name__ == "__main__":
    with gr.Blocks(theme="default") as interface:
        gr.Markdown("""
        # Network Intrusion Detection System
        Upload a PCAP file or use manual input to detect potential network attacks.
        """)

        # Tab 1: upload a capture file and analyse every extracted pattern.
        with gr.Tab("PCAP Analysis"):
            pcap_input = gr.File(
                label="Upload PCAP File",
                file_types=[".pcap", ".pcapng"],
            )
            pcap_output = gr.Textbox(label="Analysis Results")
            pcap_button = gr.Button("Analyze PCAP")
            pcap_button.click(
                fn=process_pcap_input,
                inputs=[pcap_input],
                outputs=pcap_output,
            )

        # Tab 2: enter the 13 features by hand, two sliders per row.
        with gr.Tab("Manual Input"):
            with gr.Row():
                port_num = gr.Slider(
                    minimum=1, maximum=4, value=1,
                    label="Port Number - The switch port through which the flow passed",
                )
                rx_packets = gr.Slider(
                    minimum=0, maximum=352772, value=0,
                    label="Received Packets - Number of packets received by the port",
                )
            with gr.Row():
                rx_bytes = gr.Slider(
                    minimum=0, maximum=2.715916e08, value=0,
                    label="Received Bytes - Number of bytes received by the port",
                )
                tx_bytes = gr.Slider(
                    minimum=0, maximum=2.392430e08, value=0,
                    label="Sent Bytes - Number of bytes sent by the port",
                )
            with gr.Row():
                tx_packets = gr.Slider(
                    minimum=0, maximum=421598, value=0,
                    label="Sent Packets - Number of packets sent by the port",
                )
                port_duration = gr.Slider(
                    minimum=0, maximum=3317, value=0,
                    label="Port alive Duration (S) - The time port has been alive in seconds",
                )
            with gr.Row():
                delta_rx_bytes = gr.Slider(
                    minimum=0, maximum=6500000, value=0,
                    label="Delta Received Bytes",
                )
                delta_tx_bytes = gr.Slider(
                    minimum=0, maximum=6500000, value=0,
                    label="Delta Sent Bytes",
                )
            with gr.Row():
                delta_duration = gr.Slider(
                    minimum=0, maximum=5, value=0,
                    label="Delta Port alive Duration (S)",
                )
                conn_point = gr.Slider(
                    minimum=1, maximum=5, value=1,
                    label="Connection Point",
                )
            with gr.Row():
                total_load = gr.Slider(
                    minimum=0, maximum=1800000, value=0,
                    label="Total Load/Rate",
                )
                active_flows = gr.Slider(
                    minimum=0, maximum=610, value=0,
                    label="Active Flow Entries",
                )
            with gr.Row():
                packets_matched = gr.Slider(
                    minimum=0, maximum=1020000, value=0,
                    label="Packets Matched",
                )

            manual_output = gr.Textbox(label="Analysis Results")
            manual_button = gr.Button("Analyze Manual Input")

            # Same component order as the parameters of process_manual_input;
            # shared by the button handler and the example rows.
            manual_inputs = [
                port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                port_duration, delta_rx_bytes, delta_tx_bytes,
                delta_duration, conn_point, total_load, active_flows,
                packets_matched,
            ]

            manual_button.click(
                fn=process_manual_input,
                inputs=list(manual_inputs),
                outputs=manual_output,
            )

            # Pre-filled example rows, one value per slider.
            example_rows = [
                [4, 350188, 14877116, 101354648, 159524, 2910, 278, 280,
                 5, 4, 0, 6, 667324],
                [2, 2326, 12856942, 31777516, 2998, 2497, 560, 560,
                 5, 2, 0, 4, 7259],
                [4, 150, 19774, 6475473, 3054, 166, 556, 6068,
                 5, 4, 502, 6, 7418],
                [2, 209, 20671, 6316631, 274, 96, 3527, 2757949,
                 5, 2, 183877, 8, 90494],
                [2, 1733, 37865130, 38063670, 3187, 2152, 0, 556,
                 5, 3, 0, 4, 14864],
            ]
            gr.Examples(examples=example_rows, inputs=list(manual_inputs))

    # Start the web server.
    interface.launch()