LaurentiuStancioiu committed on
Commit
38753f5
·
verified ·
1 Parent(s): a8fe9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -70
app.py CHANGED
@@ -3,94 +3,85 @@ import os
3
  import pandas as pd
4
  import numpy as np
5
  import openai
6
- from sklearn.manifold import TSNE
7
  import joblib
8
  import gradio as gr
9
  from typing import Optional
10
  import altair as alt
11
- load_dotenv()
12
- #plt.style.use('seaborn-poster')
13
 
14
- EMBEDDING_MODEL = "text-embedding-ada-002"
 
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
 
 
 
 
 
 
17
 
18
  def get_embedding(text: str, model=EMBEDDING_MODEL) -> list[float]:
19
  """
20
- Gets a text as an input and the embedding model used from Openai
21
- Returns the embeddings of that blurb of text
22
  """
23
  return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"]
24
 
 
 
 
 
 
25
 
 
 
26
 
27
- def get_plot(website: Optional[str]) -> alt.Chart:
28
- df = pd.read_csv("data_plot.csv")
29
- matrix = np.array(df.embeddings.apply(eval).to_list())
30
- website_embed = get_embedding(website, model=EMBEDDING_MODEL)
31
- website_embed = np.array(website_embed)
32
- matrix = np.append(matrix, website_embed.reshape(1, -1), axis=0)
33
-
34
- tsne = TSNE(n_components=2, perplexity=50, random_state=42, init='random', learning_rate=200)
35
- vis_dims = tsne.fit_transform(matrix)
36
-
37
- df_vis = pd.DataFrame(vis_dims, columns=['x', 'y'])
38
- df_vis['type'] = df['type']
39
- df_vis["url"] = df["url"]
40
- df_vis.loc[df_vis.index[-1], 'type'] = 'Our Data'
41
- df_vis.loc[df_vis.index[-1], 'url'] = website
42
- # Define color scale
43
- scale = alt.Scale(domain=['benign', 'defacement', 'phishing', 'malware', 'Our Data'],
44
- range=['red', 'darkorange', 'gold', 'turquoise', 'black'])
45
-
46
- # Create the scatter plot
47
- scatter_plot = alt.Chart(df_vis).mark_circle(size=60).encode(
48
- x='x',
49
- y='y',
50
- color=alt.Color('type', scale=scale),
51
- tooltip=['type', 'url']
52
- ).interactive()
53
-
54
- return scatter_plot
55
-
56
- def predict_label(website: Optional[str] = "") -> str:
57
- """
58
- It takes the blurb of text and predicts whether it is malicious or not
59
 
60
- """
61
- loaded_model = joblib.load("model.joblib")
62
- embedding = get_embedding(website, model = EMBEDDING_MODEL)
63
- embedding = np.array(embedding)
64
- y_predicted = loaded_model.predict(embedding.reshape(1, -1))
65
- if y_predicted[0] == "benign":
66
- return "This website is most probably safe."
67
- elif y_predicted[0] != "benign":
68
- return "This website is most probably malicious."
69
-
70
- #def my_app(website: str):
71
- # return (get_plot(website), predict_label(website))
72
-
73
- #get_plot(website = "https://www.youtube.com/watch?v=RiCQzBluTxU")
74
- #print(predict_label(website = "https://www.youtube.com/watch?v=RiCQzBluTxU"))
75
-
76
- def gradio_interface(website: Optional[str] = ""):
77
- if website == "" or website == None:
78
- pass
79
- else:
80
- prediction = predict_label(website)
81
- #plot = get_plot(website)
82
 
83
- return prediction
 
84
 
85
- interface = gr.Interface(
86
- fn=gradio_interface,
87
- inputs="text",
88
- outputs="text",
89
- live=True,
90
- title="Malicious Website Detector",
91
- description="This website comes as a helping tool for those that want to surf safely on the internet.\n Attention: Not all predictions are true and this should be taken as a demo for now."
92
- )
93
- interface.launch()
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
 
3
  import pandas as pd
4
  import numpy as np
5
  import openai
6
+ from sklearn.manifold import TSNE
7
  import joblib
8
  import gradio as gr
9
  from typing import Optional
10
  import altair as alt
 
 
11
 
12
# Configuration: call load_dotenv() first so that the OPENAI_API_KEY lookup
# below can see any value it provides. os.getenv returns None when the
# variable is not set.
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")

# Shared resources: read from disk a single time at import and reused by the
# functions below instead of being reloaded on every request.
EMBEDDING_MODEL = "text-embedding-ada-002"
df = pd.read_csv("data_plot.csv")
# NOTE(review): every cell of the "embeddings" column is handed to eval(),
# which executes whatever Python expression the CSV contains. That is only
# acceptable while data_plot.csv is a trusted, locally produced file; consider
# ast.literal_eval if the file can ever come from elsewhere.
embedding_rows = df["embeddings"].apply(eval)
matrix = np.array(embedding_rows.tolist())
loaded_model = joblib.load("model.joblib")
21
 
22
def get_embedding(text: str, model=EMBEDDING_MODEL) -> list[float]:
    """
    Request the embedding of one piece of text from the OpenAI API.

    Args:
        text: The text to embed. It is sent as a one-element input list.
        model: Name of the OpenAI embedding model to use.

    Returns:
        The "embedding" entry of the first item in the response's "data" list.
    """
    response = openai.Embedding.create(input=[text], model=model)
    first_item = response["data"][0]
    return first_item["embedding"]
28
 
29
def get_plot(website: Optional[str], matrix=matrix, df=df) -> Optional[alt.Chart]:
    """
    Build an interactive 2-D t-SNE scatter plot of the reference embeddings
    together with one extra point for the given website.

    Args:
        website: Text (normally a URL) to embed and add to the plot.
        matrix: 2-D array of reference embeddings. It must have as many rows
            as `df`, because one label and one URL are attached per row.
            Defaults to the module-level `matrix`; it is not modified.
        df: DataFrame providing the 'type' and 'url' columns for the
            reference points. Defaults to the module-level `df`; it is not
            modified.

    Returns:
        The Altair chart, or None when `website` is None, empty, or contains
        only whitespace.
    """
    # Nothing to plot for blank input; returning early also avoids an
    # embedding request and a full t-SNE fit.
    if not website or not website.strip():
        return None

    website_embed = np.array(get_embedding(website, model=EMBEDDING_MODEL))
    # np.append returns a new array, so the shared reference matrix stays intact.
    updated_matrix = np.append(matrix, website_embed.reshape(1, -1), axis=0)

    # The projection is refitted on the whole matrix for every call;
    # random_state pins the seed used for the fit.
    tsne = TSNE(n_components=2, perplexity=50, random_state=42, init='random', learning_rate=200)
    vis_dims = tsne.fit_transform(updated_matrix)

    # The new point is the last row, so its label and URL are appended last.
    # The frame is created here with a default index, so no reset is needed.
    df_vis = pd.DataFrame(vis_dims, columns=['x', 'y'])
    df_vis['type'] = df['type'].tolist() + ['Our Data']
    df_vis['url'] = df['url'].tolist() + [website]

    scale = alt.Scale(
        domain=['benign', 'defacement', 'phishing', 'malware', 'Our Data'],
        range=['red', 'darkorange', 'gold', 'turquoise', 'black'],
    )

    return (
        alt.Chart(df_vis)
        .mark_circle(size=60)
        .encode(
            x='x',
            y='y',
            color=alt.Color('type', scale=scale),
            tooltip=['type', 'url'],
        )
        .interactive()
    )
56
 
57
def predict_label(website: Optional[str] = "") -> str:
    """
    Classify a website as probably safe or probably malicious.

    The text is embedded with get_embedding() and the embedding is passed, as
    a single-row 2-D array, to the module-level `loaded_model`.

    Args:
        website: URL (or other text) to classify. None, an empty string, or a
            whitespace-only string is treated as "nothing entered".

    Returns:
        A prompt asking for a URL when nothing was entered; otherwise the
        "safe" message when the model predicts "benign", and the "malicious"
        message for any other predicted label.
    """
    # Blank input (including only spaces) would otherwise cost an embedding
    # request and produce a meaningless classification.
    if not website or not website.strip():
        return "Please enter a website URL."

    features = np.array(get_embedding(website, model=EMBEDDING_MODEL)).reshape(1, -1)
    predicted = loaded_model.predict(features)
    if predicted[0] == "benign":
        return "This website is most probably safe."
    return "This website is most probably malicious."
65
+
66
def gradio_app():
    """
    Assemble the Gradio Blocks UI and launch it.

    The page has a URL textbox, a "Predict" button, a textbox for the
    prediction message, and a plot area. Clicking the button runs
    update_output() and writes its two results to the prediction textbox and
    the plot.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Malicious Website Detector")
        gr.Markdown("This tool helps you identify potentially malicious websites. \n **Note:** This is a demonstration and results may not be accurate.")
        url_box = gr.Textbox(label="Enter website URL")
        run_button = gr.Button("Predict")
        # The prediction box is deliberately created as interactive.
        result_box = gr.Textbox(label="Prediction", interactive=True)
        chart_area = gr.Plot(label="Website Embedding Plot")

        def update_output(website):
            """Return (prediction message, chart or None) for the entered text."""
            message = predict_label(website)
            if not website:
                return message, None
            return message, get_plot(website)

        run_button.click(
            update_output,
            inputs=url_box,
            outputs=[result_box, chart_area],
        )

    demo.launch()


if __name__ == "__main__":
    gradio_app()
86
 
87