elselse commited on
Commit
9ba88a4
·
verified ·
1 Parent(s): 1d606fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -1,36 +1,50 @@
1
  import gradio as gr
 
2
  from transformers import pipeline
3
 
4
- # Load the Hugging Face model for text classification
5
  classifier = pipeline(
6
  task="text-classification",
7
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
8
  return_all_scores=True
9
  )
10
 
 
 
 
 
11
  def predict_cwe(commit_message: str):
12
  """
13
- Predict CWE(s) from a commit message using the model.
14
  """
15
  results = classifier(commit_message)[0]
16
- # Sort the results by score descending
17
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
18
- # Return top 5 predictions as a dictionary
19
- return {item["label"]: round(float(item["score"]), 4) for item in sorted_results[:5]}
20
 
21
- # Build the Gradio interface
 
 
 
 
 
 
 
 
 
22
  demo = gr.Interface(
23
  fn=predict_cwe,
24
  inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
25
  outputs=gr.Label(num_top_classes=5),
26
  title="CWE Prediction from Commit Message",
27
- description="Type a Git commit message and get the most likely CWE classes predicted by the model.",
 
28
  examples=[
29
  ["Fixed buffer overflow in input parsing"],
30
- ["SQL injection possible in user login endpoint"]
 
 
 
31
  ]
32
  )
33
 
34
  if __name__ == "__main__":
35
  demo.launch()
36
-
 
1
  import gradio as gr
2
+ import json
3
  from transformers import pipeline
4
 
5
+ # Load Hugging Face model (text classification)
6
  classifier = pipeline(
7
  task="text-classification",
8
  model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
9
  return_all_scores=True
10
  )
11
 
12
+ # Load child-to-parent mapping
13
+ with open("vulntrain/trainers/child_to_parent_mapping.json", "r") as f:
14
+ child_to_parent = json.load(f)
15
+
16
  def predict_cwe(commit_message: str):
17
  """
18
+ Predict CWE(s) from a commit message and map to parent CWEs.
19
  """
20
  results = classifier(commit_message)[0]
 
21
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
 
 
22
 
23
+ # Map predictions to parent CWE (if available)
24
+ mapped_results = {}
25
+ for item in sorted_results[:5]:
26
+ child_cwe = item["label"].replace("CWE-", "")
27
+ parent_cwe = child_to_parent.get(child_cwe, child_cwe) # default to child if no parent
28
+ mapped_results[f"CWE-{parent_cwe}"] = round(float(item["score"]), 4)
29
+
30
+ return mapped_results
31
+
32
+ # Gradio UI
33
  demo = gr.Interface(
34
  fn=predict_cwe,
35
  inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
36
  outputs=gr.Label(num_top_classes=5),
37
  title="CWE Prediction from Commit Message",
38
+ description="This tool uses a fine-tuned model to predict CWE categories from Git commit messages. "
39
+ "Predicted child CWEs are mapped to their parent CWEs if applicable.",
40
  examples=[
41
  ["Fixed buffer overflow in input parsing"],
42
+ ["SQL injection possible in login flow"],
43
+ ["Improved input validation to prevent XSS"],
44
+ ["Added try/catch to avoid null pointer crash"],
45
+ ["Patched race condition in thread lock logic"]
46
  ]
47
  )
48
 
49
  if __name__ == "__main__":
50
  demo.launch()