SQLQueryShield
Model Description
SQLQueryShield is a model for detecting vulnerable SQL queries. It classifies each SQL query as either malicious, i.e. vulnerable (for example, prone to SQL injection or unsafe execution), or safe to execute. The corresponding output labels are MALICIOUS and SAFE.
The checkpoint in this repository is based on microsoft/codebert-base and fine-tuned on SQLShield. SQLShield is a dataset for text-to-SQL vulnerability detection. It consists of vulnerable and safe natural-language questions (NLQs) paired with their corresponding SQL queries.
Finetuning Procedure
The model was fine-tuned with the Hugging Face Transformers library using the following setup:
- Dataset: SQLShield. Only the SQL queries from the (NLQ, SQL) pairs were used for training.
- Preprocessing:
  - Input format: raw SQL query strings.
  - Tokenization: the microsoft/codebert-base tokenizer.
  - Max length: 128 tokens.
  - Padding and truncation applied.
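The preprocessing described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' training script. The record keys "nlq", "sql" and "label" are hypothetical stand-ins for the real SQLShield column names. The label-to-id mapping is also an assumption, borrowed from the label names in the pipeline outputs shown in "How to Use". The tokenizer is passed in as an argument, so the snippet runs without downloading anything. Padding to a fixed length of 128 is one possible reading of "padding and truncation applied."

```python
from typing import Callable, Iterable

# Assumed label ids; the real id2label mapping of the checkpoint may differ.
LABEL2ID = {"SAFE": 0, "MALICIOUS": 1}
MAX_LENGTH = 128  # max sequence length stated in the preprocessing notes


def sql_only_examples(pairs: Iterable[dict]) -> list[dict]:
    """Keep only the SQL side of each (NLQ, SQL) pair, dropping the NLQ.

    The keys "nlq", "sql" and "label" are hypothetical placeholders.
    """
    return [{"text": p["sql"], "label": LABEL2ID[p["label"]]} for p in pairs]


def tokenize(examples: list[dict], tokenizer: Callable) -> dict:
    """Tokenize raw SQL strings, padding and truncating to MAX_LENGTH tokens."""
    encoded = tokenizer(
        [e["text"] for e in examples],
        max_length=MAX_LENGTH,
        padding="max_length",  # assumption: the card only says "padding"
        truncation=True,
    )
    # Attach labels so the batch can be fed to a sequence-classification model.
    encoded["labels"] = [e["label"] for e in examples]
    return encoded
```

With the real tokenizer, pass `AutoTokenizer.from_pretrained("microsoft/codebert-base")` from `transformers` as the `tokenizer` argument.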
Intended Use and Limitations
SQLQueryShield is intended as a post-generation filter or analysis tool in systems that generate or execute SQL queries. Its main role is to flag SQL queries that are potentially harmful because they contain vulnerability patterns such as SQL injection, improper string concatenation, or unsafe expressions.
Ideal use cases:
- Filtering SQL queries in Text-to-SQL applications
- Post-processing or validating user-generated SQL before execution
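A post-generation filter of this kind might look like the sketch below. It is not part of the released model. `is_safe`, `execute_if_safe` and the 0.9 confidence threshold are hypothetical choices, and the card does not recommend a threshold. `classify` is assumed to behave like the text-classification pipeline in "How to Use": given one string, it returns `[{'label': ..., 'score': ...}]`. The example uses Python's built-in `sqlite3` as a stand-in database.

```python
import sqlite3
from typing import Callable

# Hypothetical threshold: anything not confidently SAFE is rejected.
THRESHOLD = 0.9


def is_safe(query: str, classify: Callable) -> bool:
    """Return True only if the classifier labels the query SAFE with high confidence."""
    result = classify(query)[0]
    return result["label"] == "SAFE" and result["score"] >= THRESHOLD


def execute_if_safe(conn: sqlite3.Connection, query: str, classify: Callable) -> list:
    """Run a generated query only if it passes the filter; otherwise refuse to run it."""
    if not is_safe(query, classify):
        raise PermissionError(f"Query rejected by filter: {query!r}")
    return conn.execute(query).fetchall()
```

In practice, `classify` would be the `sql_query_shield` pipeline object. Making the filter reject by default, rather than accept by default, is a design choice: uncertain predictions are treated as unsafe.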
How to Use
Example 1: Malicious
from transformers import pipeline
sql_query_shield = pipeline("text-classification", model="salmane11/SQLQueryShield")
# Assuming the following table schema:
# CREATE TABLE campuses
# (
# campus VARCHAR,
# location VARCHAR
# )
query = "SELECT campus FROM campuses WHERE location = '' UNION SELECT database() --"
prediction = sql_query_shield(query)
print(prediction)
#[{'label': 'MALICIOUS', 'score': 0.9995294809341431}]
Example 2: Safe
from transformers import pipeline
sql_query_shield = pipeline("text-classification", model="salmane11/SQLQueryShield")
# Assuming the following table schema:
# CREATE TABLE tv_channel
# (
# package_option VARCHAR,
# series_name VARCHAR
# )
query = "SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'"
prediction = sql_query_shield(query)
print(prediction)
#[{'label': 'SAFE', 'score': 0.999503493309021}]
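Several queries can also be screened in one call. A Transformers text-classification pipeline called on a list of strings returns one `{'label': ..., 'score': ...}` dict per input, in the same order. The helper `partition_queries` below is a hypothetical sketch built on that behavior. It takes the classifier as a parameter, so it can be exercised with a stand-in function instead of the downloaded model.

```python
from typing import Callable


def partition_queries(queries: list[str], classify: Callable) -> tuple[list[str], list[str]]:
    """Split queries into (safe, flagged) using a single batched classifier call."""
    predictions = classify(queries)  # one prediction dict per query, in order
    safe, flagged = [], []
    for query, pred in zip(queries, predictions):
        # Anything not labeled SAFE is flagged for review.
        (safe if pred["label"] == "SAFE" else flagged).append(query)
    return safe, flagged
```

For example, `partition_queries([query1, query2], sql_query_shield)` returns the queries labeled SAFE and the rest as two separate lists.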
Cite our work
Citation