karimaloulou committed on
Commit
95319f8
·
verified ·
1 Parent(s): a03397e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -26
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
 
3
  from mitreattack.stix20 import MitreAttackData
4
- from descriptions import descriptions # Assurez-vous que descriptions.py est dans le même répertoire
5
 
6
  # Chemins des fichiers JSON
 
7
  enterprise_attack_path = 'enterprise-attack.json'
8
 
9
  # Charger les données ATT&CK
@@ -17,11 +19,15 @@ techniques_str = "\n".join([f"{technique['name']} ({mitre_attack_data.get_attack
17
 
18
  client = InferenceClient(model='mistralai/Mixtral-8x7B-Instruct-v0.1')
19
 
20
- def generate_system_message(log_input):
21
- description_output = descriptions(log_input)
22
- return f"""<s>[INST] Given these TTPs: {techniques_str}\n\n and here are the descriptions: {description_output}\n\nFigure out which technique is used in these logs and respond in bullet points and nothing else.[/INST]"""
23
 
24
- def respond(message, history, system_message, max_tokens, temperature, top_p):
 
 
 
 
 
 
 
25
  messages = [{"role": "system", "content": system_message}]
26
 
27
  for val in history:
@@ -31,6 +37,8 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
31
  messages.append({"role": "assistant", "content": val[1]})
32
 
33
  messages.append({"role": "user", "content": message})
 
 
34
  response = ""
35
 
36
  for message in client.chat_completion(
@@ -41,38 +49,19 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
41
  top_p=top_p,
42
  ):
43
  token = message.choices[0].delta.content
 
44
  response += token
45
  yield response
46
 
47
- def on_log_input_change(log_input):
48
- system_message = generate_system_message(log_input)
49
- return system_message
50
 
51
  demo = gr.ChatInterface(
52
  respond,
53
  additional_inputs=[
54
- gr.Textbox(label="Log Input", placeholder="Enter log here...", lines=4),
55
- gr.Textbox(value=generate_system_message(""), label="System message", interactive=False),
56
  gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
57
  gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
58
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
59
  ],
60
- title="TTP Detection Interface",
61
- description="Enter logs to detect TTPs using the model.",
62
- )
63
-
64
- # Met à jour le message système lorsque l'entrée change
65
- def update_system_message(log_input):
66
- return generate_system_message(log_input)
67
-
68
- # Fonction pour mettre à jour les valeurs de l'interface
69
- def interface_update(log_input, *args):
70
- system_message = update_system_message(log_input)
71
- return gr.update(value=system_message)
72
-
73
- # Associe l'entrée des logs à la fonction de mise à jour
74
- demo.add_component(
75
- gr.Textbox(label="Log Input", placeholder="Enter log here...", lines=4, change=interface_update)
76
  )
77
 
78
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ import os
4
  from mitreattack.stix20 import MitreAttackData
5
+ from descriptions import descriptions
6
 
7
  # Chemins des fichiers JSON
8
+ ics_attack_path = 'ics-attack.json'
9
  enterprise_attack_path = 'enterprise-attack.json'
10
 
11
  # Charger les données ATT&CK
 
19
 
20
  client = InferenceClient(model='mistralai/Mixtral-8x7B-Instruct-v0.1')
21
 
 
 
 
22
 
23
+ def respond(
24
+ message,
25
+ history: list[tuple[str, str]],
26
+ system_message,
27
+ max_tokens,
28
+ temperature,
29
+ top_p,
30
+ ):
31
  messages = [{"role": "system", "content": system_message}]
32
 
33
  for val in history:
 
37
  messages.append({"role": "assistant", "content": val[1]})
38
 
39
  messages.append({"role": "user", "content": message})
40
+ message_content = message
41
+
42
  response = ""
43
 
44
  for message in client.chat_completion(
 
49
  top_p=top_p,
50
  ):
51
  token = message.choices[0].delta.content
52
+
53
  response += token
54
  yield response
55
 
 
 
 
56
 
57
  demo = gr.ChatInterface(
58
  respond,
59
  additional_inputs=[
60
+ gr.Textbox(value=f"""<s>[INST] Given these TTPs: {techniques_str}\n\n and here are {descriptions}\n\nfigure out which technique is used in these logs and respond in bullets points and nothing else[/INST]""", label="System message"),
 
61
  gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
62
  gr.Slider(minimum=0.1, maximum=1.0, value=0.1, step=0.1, label="Temperature"),
63
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
64
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
  if __name__ == "__main__":