elselse commited on
Commit
1d606fb
·
verified ·
1 Parent(s): 88f4376

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,23 +1,36 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- MODEL_NAME = "CIRCL/cwe-parent-vulnerability-classification-roberta-base"
5
-
6
- classifier = pipeline("text-classification", model=MODEL_NAME, return_all_scores=True)
 
 
 
7
 
8
- def classify_cwe(text):
9
- results = classifier(text)[0]
10
- # Sort by confidence score descending
 
 
 
11
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
12
- return {res["label"]: round(res["score"], 4) for res in sorted_results[:5]}
 
13
 
14
- interface = gr.Interface(
15
- fn=classify_cwe,
16
- inputs=gr.Textbox(lines=5, placeholder="Enter vulnerability commit message..."),
 
17
  outputs=gr.Label(num_top_classes=5),
18
- title="CWE Vulnerability Classifier",
19
- description="Enter a vulnerability commit message to predict the most likely CWE parent family"
 
 
 
 
20
  )
21
 
22
- # Launch the Gradio app
23
- interface.launch()
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
+ # Load the Hugging Face model for text classification
5
+ classifier = pipeline(
6
+ task="text-classification",
7
+ model="CIRCL/cwe-parent-vulnerability-classification-roberta-base",
8
+ return_all_scores=True
9
+ )
10
 
11
+ def predict_cwe(commit_message: str):
12
+ """
13
+ Predict CWE(s) from a commit message using the model.
14
+ """
15
+ results = classifier(commit_message)[0]
16
+ # Sort the results by score descending
17
  sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
18
+ # Return top 5 predictions as a dictionary
19
+ return {item["label"]: round(float(item["score"]), 4) for item in sorted_results[:5]}
20
 
21
+ # Build the Gradio interface
22
+ demo = gr.Interface(
23
+ fn=predict_cwe,
24
+ inputs=gr.Textbox(lines=3, placeholder="Enter your commit message here..."),
25
  outputs=gr.Label(num_top_classes=5),
26
+ title="CWE Prediction from Commit Message",
27
+ description="Type a Git commit message and get the most likely CWE classes predicted by the model.",
28
+ examples=[
29
+ ["Fixed buffer overflow in input parsing"],
30
+ ["SQL injection possible in user login endpoint"]
31
+ ]
32
  )
33
 
34
+ if __name__ == "__main__":
35
+ demo.launch()
36
+