qJFI committed on
Commit
0cbb679
·
verified ·
1 Parent(s): 328fdf9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +312 -0
app.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import xgboost as xgb
3
+ import gradio as gr
4
+ from scapy.all import rdpcap
5
+ from collections import defaultdict
6
+ import os
7
+
8
def transform_new_input(new_input):
    """Min-max scale a 13-feature vector with fixed per-feature bounds.

    Each value is mapped to ``(value - min) / (max - min)`` using the
    hard-coded bounds below (presumably taken from the training data --
    confirm against the training pipeline).  Values outside the bounds are
    not clipped, so the result may fall outside ``[0, 1]``.

    Args:
        new_input: Sequence or array whose last axis holds the 13 feature
            values, in the same order as the bounds.

    Returns:
        numpy.ndarray with the same shape as ``new_input``.

    Raises:
        ValueError: If the last axis of ``new_input`` does not have exactly
            13 entries.
    """
    # Lower bound of each feature, in model feature order.
    scaled_min = np.array([
        1.0, 10.0, 856.0, 5775.0, 42.0, 26.0, 0.0, 278.0, 4.0, 1.0,
        -630355.0, 4.0, 50.0,
    ])

    # Upper bound of each feature, in model feature order.
    scaled_max = np.array([
        4.0, 352752.0, 271591638.0, 239241314.0, 421552.0, 3317.0,
        6302708.0, 6302708.0, 5.0, 5.0, 1746749.0, 608.0, 1012128.0,
    ])

    new_input = np.asarray(new_input)
    # Fail loudly instead of letting NumPy broadcast a scalar or a
    # length-1 input into a plausible-looking 13-element result.
    if new_input.shape[-1:] != scaled_min.shape:
        raise ValueError(
            f"Expected {scaled_min.size} features, got input of shape "
            f"{new_input.shape}"
        )

    return (new_input - scaled_min) / (scaled_max - scaled_min)
23
+
24
class PcapProcessor:
    """Aggregate per-port traffic statistics from a packet capture.

    Every TCP/UDP packet is counted once against its source port (as
    transmitted traffic) and once against its destination port (as received
    traffic).  ``process_packets`` then converts each port's counters into
    the 13-element feature list consumed by the IDS model.
    """

    def __init__(self, pcap_file):
        """Read all packets from ``pcap_file`` with scapy's ``rdpcap``.

        Args:
            pcap_file: Capture file to load (passed straight to ``rdpcap``).
        """
        self.packets = rdpcap(pcap_file)
        # Timestamp of the first packet; filled in by process_packets().
        self.start_time = None
        self.port_stats = defaultdict(self._new_port_stats)

    @staticmethod
    def _new_port_stats():
        """Return a fresh, zeroed statistics record for one port."""
        return {
            'rx_packets': 0,
            'rx_bytes': 0,
            'tx_packets': 0,
            'tx_bytes': 0,
            'first_seen': None,
            'last_seen': None,
            'active_flows': set(),
            'packets_matched': 0,
        }

    def process_packets(self, window_size=60):
        """Scan the capture and build one feature list per observed port.

        Args:
            window_size: Upper bound (seconds) applied to a port's alive
                duration before the delta-duration feature is derived.

        Returns:
            A list of 13-element feature lists, one per port that appeared
            in at least one TCP/UDP packet, in first-seen order.  Empty when
            the capture holds no packets.
        """
        if not self.packets:
            return []

        # Start from clean counters so calling this method twice does not
        # double-count the same packets.
        self.port_stats = defaultdict(self._new_port_stats)
        self.start_time = float(self.packets[0].time)

        for packet in self.packets:
            current_time = float(packet.time)

            if not ('TCP' in packet or 'UDP' in packet):
                continue

            try:
                src_port = packet.sport
                dst_port = packet.dport
                pkt_size = len(packet)

                # TCP/UDP may ride on IPv4 or IPv6; use whichever network
                # layer is present instead of failing on IPv6 traffic.
                ip_layer = packet['IP'] if 'IP' in packet else packet['IPv6']
                flow_tuple = (ip_layer.src, ip_layer.dst, src_port, dst_port)

                self._update_port_stats(src_port, pkt_size, True,
                                        current_time, flow_tuple)
                self._update_port_stats(dst_port, pkt_size, False,
                                        current_time, flow_tuple)
            except Exception as e:
                # Best effort: report the malformed packet and keep going.
                print(f"Error processing packet {packet}: {str(e)}")
                continue

        return [
            self._extract_port_features(port, stats, window_size)
            for port, stats in self.port_stats.items()
            if stats['first_seen'] is not None
        ]

    def _update_port_stats(self, port, pkt_size, is_source, current_time,
                           flow_tuple):
        """Fold one packet into the counters of ``port``.

        Args:
            port: Port number whose record is updated.
            pkt_size: Packet length in bytes.
            is_source: True when ``port`` is the packet's source port (counts
                as transmitted), False when it is the destination (received).
            current_time: Packet timestamp in seconds.
            flow_tuple: ``(src_ip, dst_ip, src_port, dst_port)`` identifying
                the flow the packet belongs to.
        """
        stats = self.port_stats[port]

        if stats['first_seen'] is None:
            stats['first_seen'] = current_time
        stats['last_seen'] = current_time

        if is_source:
            stats['tx_packets'] += 1
            stats['tx_bytes'] += pkt_size
        else:
            stats['rx_packets'] += 1
            stats['rx_bytes'] += pkt_size

        stats['active_flows'].add(flow_tuple)
        stats['packets_matched'] += 1

    def _extract_port_features(self, port, stats, window_size):
        """Build the 13 model features for one port.

        Args:
            port: Port number the statistics belong to.
            stats: Statistics record as produced by ``_update_port_stats``.
            window_size: Cap (seconds) applied to the alive duration before
                the delta-duration feature is derived.

        Returns:
            List of 13 values in model feature order.
        """
        port_alive_duration = stats['last_seen'] - stats['first_seen']
        delta_alive_duration = min(port_alive_duration, window_size)

        # Bytes per second over the port's lifetime; the divisor is floored
        # at 1 so very short-lived ports do not divide by (near) zero.
        total_load = (stats['rx_bytes'] + stats['tx_bytes']) / \
            max(port_alive_duration, 1)

        return [
            port % 4 + 1,                    # Port Number (1-4)
            stats['rx_packets'],             # Received Packets
            stats['rx_bytes'],               # Received Bytes
            stats['tx_bytes'],               # Sent Bytes
            stats['tx_packets'],             # Sent Packets
            port_alive_duration,             # Port alive Duration
            stats['rx_bytes'],               # Delta Received Bytes
            stats['tx_bytes'],               # Delta Sent Bytes
            min(delta_alive_duration, 5),    # Delta Port alive Duration
            port % 5 + 1,                    # Connection Point (1-5)
            total_load,                      # Total Load/Rate
            len(stats['active_flows']),      # Active Flow Entries
            stats['packets_matched'],        # Packets Matched
        ]
125
+
126
def process_pcap_for_ids(pcap_file, window_size=60):
    """Extract IDS model features from a PCAP file.

    Args:
        pcap_file: Capture file to read (handed to ``PcapProcessor``).
        window_size: Cap in seconds forwarded to
            ``PcapProcessor.process_packets``; defaults to 60, the value the
            processor itself defaults to.

    Returns:
        A list of per-port feature lists, or an empty list when the capture
        contains no packets.
    """
    processor = PcapProcessor(pcap_file)
    return processor.process_packets(window_size=window_size)
131
+
132
def predict_from_features(features, model, low_confidence_threshold=0.4):
    """Classify one 13-feature vector and return its alert message.

    The features are scaled with ``transform_new_input``, wrapped in an
    ``xgb.DMatrix`` with a single row and passed to ``model.predict``.  The
    first row of the result is treated as a per-class score vector
    (assumes the booster outputs class probabilities -- confirm against how
    the model was trained).

    Args:
        features: The 13 raw feature values in model order.
        model: Loaded ``xgb.Booster`` used for prediction.
        low_confidence_threshold: When the best class score is below this
            value the sample is reported as normal traffic (class 0).

    Returns:
        The ``(headline, details)`` tuple from ``get_prediction_message``.
    """
    scaled_features = transform_new_input(features)
    features_matrix = xgb.DMatrix(scaled_features.reshape(1, -1))

    probabilities = model.predict(features_matrix)[0]
    # Plain int so the class id is an ordinary dictionary key downstream.
    prediction = int(np.argmax(probabilities))

    # No class is convincing enough: fall back to "normal".  A separate
    # "class 0 with > 60% confidence" check is unnecessary because class 0
    # is reported whenever it is the argmax, whatever its score.
    if np.max(probabilities) < low_confidence_threshold:
        prediction = 0

    return get_prediction_message(prediction)
152
+
153
def get_prediction_message(prediction):
    """Return the ``(headline, details)`` text pair for a predicted class.

    Args:
        prediction: Class id produced by the model (0 = normal traffic,
            1-4 = the attack types listed below).

    Returns:
        A 2-tuple of strings.  Class ids that are not listed map to a
        generic "Unknown Traffic Pattern" pair.
    """
    messages = {
        0: ("NORMAL TRAFFIC - No indication of attack.",
            "Traffic patterns appear to be within normal parameters."),
        1: ("ALERT: Potential BLACKHOLE attack detected.",
            "Information: BLACKHOLE attacks occur when a router maliciously drops "
            "packets it should forward. Investigate affected routes and traffic patterns."),
        2: ("ALERT: Potential TCP-SYN flood attack detected.",
            "Information: TCP-SYN flood is a DDoS attack exhausting server resources "
            "with half-open connections. Check connection states and implement SYN cookies."),
        3: ("ALERT: PORTSCAN activity detected.",
            "Information: Port scanning detected - systematic probing of system ports. "
            "Review firewall rules and implement connection rate limiting."),
        4: ("ALERT: Potential DIVERSION attack detected.",
            "Information: Traffic diversion detected. Verify routing integrity and "
            "check for signs of traffic manipulation or social engineering attempts."),
    }
    fallback = ("Unknown Traffic Pattern", "Additional analysis required.")
    return messages.get(prediction, fallback)
172
+
173
def process_pcap_input(pcap_file):
    """Analyze an uploaded PCAP file and describe each detected pattern.

    Args:
        pcap_file: Value delivered by the Gradio file component: a path
            string, an object exposing the path as ``.name``, or ``None``
            when nothing was uploaded.

    Returns:
        A human-readable report with one section per port-level traffic
        pattern, or an explanatory message when there is nothing to analyze
        or processing fails.
    """
    if pcap_file is None:
        return "No PCAP file provided. Please upload a file to analyze."

    try:
        # Depending on the Gradio version/config the upload arrives either
        # as a plain path or as a tempfile-like wrapper with a .name path.
        pcap_path = getattr(pcap_file, "name", pcap_file)

        model = xgb.Booster()
        model.load_model("m3_xg_boost.model")

        features_list = process_pcap_for_ids(pcap_path)
        if not features_list:
            return "No valid network traffic found in PCAP file."

        results = []
        for idx, features in enumerate(features_list):
            result_msg, result_info = predict_from_features(features, model)
            results.append(f"Traffic Pattern {idx + 1}:\n{result_msg}\n{result_info}\n")

        return "\n".join(results)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing.
        return f"Error processing PCAP file: {str(e)}"
190
+
191
def process_manual_input(port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                         port_duration, delta_rx_bytes, delta_tx_bytes,
                         delta_duration, conn_point, total_load, active_flows,
                         packets_matched):
    """Classify a single, manually entered feature vector.

    The thirteen arguments are the model features in the order the model
    expects them; they are forwarded unchanged to ``predict_from_features``.

    Returns:
        ``"<headline>\\n<details>"`` for the predicted class, or an error
        message string when loading the model or predicting fails.
    """
    try:
        model = xgb.Booster()
        model.load_model("m3_xg_boost.model")

        features = [
            port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
            port_duration, delta_rx_bytes, delta_tx_bytes, delta_duration,
            conn_point, total_load, active_flows, packets_matched,
        ]

        result_msg, result_info = predict_from_features(features, model)
        return f"{result_msg}\n{result_info}"
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing.
        return f"Error processing manual input: {str(e)}"
209
+
210
# Main execution: build the Gradio UI (PCAP upload tab + manual feature tab)
# and start the web server.
if __name__ == "__main__":
    # Create the interface
    with gr.Blocks(theme="default") as interface:
        gr.Markdown("""
        # Network Intrusion Detection System
        Upload a PCAP file or use manual input to detect potential network attacks.
        """)

        with gr.Tab("PCAP Analysis"):
            pcap_input = gr.File(
                label="Upload PCAP File",
                file_types=[".pcap", ".pcapng"]
            )
            pcap_output = gr.Textbox(label="Analysis Results")
            pcap_button = gr.Button("Analyze PCAP")
            pcap_button.click(
                fn=process_pcap_input,
                inputs=[pcap_input],
                outputs=pcap_output
            )

        with gr.Tab("Manual Input"):
            # Manual input components.  Count-like features get step=1 so
            # the sliders move in whole numbers; without it Gradio derives
            # the step from the slider range.
            with gr.Row():
                port_num = gr.Slider(1, 4, value=1, step=1,
                    label="Port Number - The switch port through which the flow passed")
                rx_packets = gr.Slider(0, 352772, value=0, step=1,
                    label="Received Packets - Number of packets received by the port")

            with gr.Row():
                rx_bytes = gr.Slider(0, 2.715916e08, value=0, step=1,
                    label="Received Bytes - Number of bytes received by the port")
                tx_bytes = gr.Slider(0, 2.392430e08, value=0, step=1,
                    label="Sent Bytes - Number of bytes sent by the port")

            with gr.Row():
                tx_packets = gr.Slider(0, 421598, value=0, step=1,
                    label="Sent Packets - Number of packets sent by the port")
                port_duration = gr.Slider(0, 3317, value=0, step=1,
                    label="Port alive Duration (S) - The time port has been alive in seconds")

            with gr.Row():
                delta_rx_bytes = gr.Slider(0, 6500000, value=0, step=1,
                    label="Delta Received Bytes")
                delta_tx_bytes = gr.Slider(0, 6500000, value=0, step=1,
                    label="Delta Sent Bytes")

            with gr.Row():
                delta_duration = gr.Slider(0, 5, value=0, step=1,
                    label="Delta Port alive Duration (S)")
                conn_point = gr.Slider(1, 5, value=1, step=1,
                    label="Connection Point")

            with gr.Row():
                # Load is a rate and may be fractional, so no integer step.
                total_load = gr.Slider(0, 1800000, value=0,
                    label="Total Load/Rate")
                active_flows = gr.Slider(0, 610, value=0, step=1,
                    label="Active Flow Entries")

            with gr.Row():
                packets_matched = gr.Slider(0, 1020000, value=0, step=1,
                    label="Packets Matched")

            manual_output = gr.Textbox(label="Analysis Results")
            manual_button = gr.Button("Analyze Manual Input")

            # Connect manual input components
            manual_button.click(
                fn=process_manual_input,
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ],
                outputs=manual_output
            )

            # Example inputs
            gr.Examples(
                examples=[
                    [4, 350188, 14877116, 101354648, 159524, 2910, 278, 280,
                     5, 4, 0, 6, 667324],
                    [2, 2326, 12856942, 31777516, 2998, 2497, 560, 560,
                     5, 2, 0, 4, 7259],
                    [4, 150, 19774, 6475473, 3054, 166, 556, 6068,
                     5, 4, 502, 6, 7418],
                    [2, 209, 20671, 6316631, 274, 96, 3527, 2757949,
                     5, 2, 183877, 8, 90494],
                    [2, 1733, 37865130, 38063670, 3187, 2152, 0, 556,
                     5, 3, 0, 4, 14864]
                ],
                inputs=[
                    port_num, rx_packets, rx_bytes, tx_bytes, tx_packets,
                    port_duration, delta_rx_bytes, delta_tx_bytes,
                    delta_duration, conn_point, total_load, active_flows,
                    packets_matched
                ]
            )

    # Launch the interface
    interface.launch()