Spaces:
Sleeping
Sleeping
File size: 4,175 Bytes
eb30cad 3d7830a eb30cad 3d7830a eb30cad 3d7830a eb30cad 3d7830a eb30cad 3d7830a eb30cad 3d7830a eb30cad 3c6d0fe 3a6bb00 eb30cad 3c6d0fe 3a6bb00 eb30cad 3a6bb00 eb30cad 120d185 3d7830a 26ce0ac 3d7830a 120d185 eb30cad 120d185 eb30cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import gradio as gr
import tensorflow as tf
import pickle
import numpy as np
from sklearn.preprocessing import LabelEncoder
# Load the preprocessing artifacts saved at training time.
# NOTE: unpickling can execute arbitrary code, so these files must only ever
# come from this repository — never from user-supplied uploads.
def _load_pickle(path):
    """Return the object unpickled from the file at *path* (opened in binary mode)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


preprocessing_params = _load_pickle('preprocessing_params.pkl')
fisher_information = _load_pickle('fisher_information.pkl')
label_encoder = _load_pickle('label_encoder.pkl')
url_tokenizer = _load_pickle('url_tokenizer.pkl')
html_tokenizer = _load_pickle('html_tokenizer.pkl')
# Custom loss used by the saved model; registered so Keras can (de)serialize it.
@tf.keras.utils.register_keras_serializable()
class EWCLoss(tf.keras.losses.Loss):
    """Binary cross-entropy plus an Elastic Weight Consolidation (EWC) penalty.

    The penalty is ``importance / 2 * sum_i sum(F_i * (w_i - w_prev_i) ** 2)``
    over the model's trainable weights. ``w_prev_i`` is a snapshot taken when
    the loss is constructed, and ``F_i`` is the matching entry of
    ``fisher_information``.

    Args:
        model: Model whose ``trainable_weights`` are regularised. When ``None``
            (as after ``from_config``), only the cross-entropy term is returned.
        fisher_information: Per-weight Fisher values, presumably aligned
            one-to-one with ``model.trainable_weights`` — TODO confirm.
        importance: Scale factor for the EWC penalty.
        reduction: Forwarded to ``tf.keras.losses.Loss``.
            NOTE(review): ``'auto'`` is a Keras 2 reduction name; confirm the
            installed Keras version accepts it.
        name: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, model=None, fisher_information=None, importance=1.0, reduction='auto', name=None):
        super(EWCLoss, self).__init__(reduction=reduction, name=name)
        self.model = model
        self.fisher_information = fisher_information
        self.importance = importance
        # Explicit None check: the truthiness of a Keras model is not a reliable
        # "was a model supplied" test.
        self.prev_weights = (
            [weight.numpy() for weight in model.trainable_weights]
            if model is not None
            else None
        )

    def call(self, y_true, y_pred):
        """Return per-sample cross-entropy plus the (scalar) EWC penalty."""
        standard_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
        # Without a model, a weight snapshot and Fisher values there is nothing
        # to anchor to. Fall back to plain cross-entropy instead of crashing on
        # ``None.trainable_weights``, which is the state ``from_config`` creates.
        if self.model is None or self.fisher_information is None or self.prev_weights is None:
            return standard_loss
        ewc_loss = 0.0
        for weight, fisher_info, prev_weight in zip(
            self.model.trainable_weights, self.fisher_information, self.prev_weights
        ):
            squared_shift = tf.square(weight - prev_weight)
            # Cast so that e.g. float64 Fisher arrays combine with float32 weights.
            ewc_loss += tf.reduce_sum(tf.cast(fisher_info, squared_shift.dtype) * squared_shift)
        return standard_loss + (self.importance / 2.0) * ewc_loss

    def get_config(self):
        """Return the serializable config (base config plus ``importance``)."""
        config = super().get_config()
        config.update({
            'importance': self.importance,
            'reduction': self.reduction,
            'name': self.name,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss from *config*.

        The Fisher values are re-read from ``fisher_information.pkl``, a path
        relative to the current working directory. No model is attached, so
        the rebuilt loss computes cross-entropy only.
        """
        with open('fisher_information.pkl', 'rb') as f:
            fisher_information = pickle.load(f)
        return cls(model=None, fisher_information=fisher_information, **config)
# Restore the trained network without its training configuration; the custom
# EWC loss is rebuilt explicitly below rather than deserialized from the archive.
model = tf.keras.models.load_model(
    'new_phishing_detection_model.keras',
    compile=False,
)

# Recreate the EWC loss anchored on the weights just loaded.
ewc_loss = EWCLoss(
    model=model,
    fisher_information=fisher_information,
    importance=1000,
)

# Attach optimizer, loss and metrics to the restored model.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss=ewc_loss,
    metrics=[
        'accuracy',
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ],
)
# Function to preprocess input
def preprocess_input(input_text, tokenizer, max_length):
    """Encode one string and pad/truncate it to a batch of one fixed-length row.

    Args:
        input_text: Raw text to encode.
        tokenizer: Fitted tokenizer exposing ``texts_to_sequences``.
        max_length: Target sequence length.

    Returns:
        The result of ``pad_sequences`` for the single encoded text; both
        padding and truncation are applied at the end of the sequence.
    """
    token_ids = tokenizer.texts_to_sequences([input_text])
    return tf.keras.preprocessing.sequence.pad_sequences(
        token_ids,
        maxlen=max_length,
        padding='post',
        truncating='post',
    )
# Function to get prediction
def get_prediction(input_text, input_type):
    """Return the model's score for a single URL or HTML snippet.

    The model takes two inputs: a tokenised URL and tokenised HTML. Only one is
    supplied by the user, so the other slot is filled with zeros.

    Args:
        input_text: The URL or HTML source to classify.
        input_type: ``"URL"`` or ``"HTML"``; selects the tokenizer and the
            input slot.

    Returns:
        The model's first output for this sample, as a Python float.

    Raises:
        ValueError: If ``input_type`` is neither ``"URL"`` nor ``"HTML"``.
    """
    if input_type not in ("URL", "HTML"):
        # Anything other than "URL" (including None from an unselected radio
        # button) used to be silently scored as HTML.
        raise ValueError(f"input_type must be 'URL' or 'HTML', got {input_type!r}")

    max_url_length = preprocessing_params['max_url_length']
    max_html_length = preprocessing_params['max_html_length']
    if input_type == "URL":
        encoded = preprocess_input(input_text, url_tokenizer, max_url_length)
        # Zero-filled placeholder with the same dtype as the real input.
        input_data = [encoded, np.zeros((1, max_html_length), dtype=encoded.dtype)]  # dummy HTML input
    else:
        encoded = preprocess_input(input_text, html_tokenizer, max_html_length)
        input_data = [np.zeros((1, max_url_length), dtype=encoded.dtype), encoded]  # dummy URL input
    # verbose=0 stops Keras from printing a progress bar for every request.
    prediction = model.predict(input_data, verbose=0)[0][0]
    return float(prediction)
# Gradio UI
def phishing_detection(input_text, input_type):
    """Gradio callback: classify the input and return a human-readable verdict.

    Args:
        input_text: Text from the textbox (a URL or HTML source).
        input_type: Value of the radio button, ``"URL"`` or ``"HTML"``.

    Returns:
        A verdict message with the score formatted to two decimals (scores
        above 0.5 are reported as phishing), or a prompt for missing input.
    """
    # Reject empty submissions and an unselected input type before calling the
    # model; both would otherwise produce a meaningless score.
    if input_text is None or not input_text.strip():
        return "Please enter a URL or HTML code to check."
    if input_type not in ("URL", "HTML"):
        return "Please select an input type (URL or HTML)."
    # Surrounding whitespace/newlines from the textbox are not part of the input.
    prediction = get_prediction(input_text.strip(), input_type)
    if prediction > 0.5:
        return f"Warning: This site is likely a phishing site! ({prediction:.2f})"
    return f"Safe: This site is not likely a phishing site. ({prediction:.2f})"
# Web UI: a textbox for the URL/HTML and a radio selecting which it is.
iface = gr.Interface(
    fn=phishing_detection,
    inputs=[
        gr.components.Textbox(lines=5, placeholder="Enter URL or HTML code"),
        # Default to "URL" so the callback never receives None as the type.
        gr.components.Radio(["URL", "HTML"], type="value", value="URL", label="Input Type"),
    ],
    outputs=gr.components.Textbox(label="Phishing Detection Result"),
    title="Phishing Detection with Enhanced EWC Model",
    description="Check if a URL or HTML is Phishing.",
    theme="default",
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()