---
|
library_name: transformers |
|
tags: |
|
- text-to-SQL |
|
- SQL |
|
- code-generation |
|
- NLQ-to-SQL |
|
- text2SQL |
|
- Security |
|
- Vulnerability detection |
|
datasets: |
|
- salmane11/SQLShield |
|
language: |
|
- en |
|
base_model: |
|
- google-bert/bert-base-uncased |
|
--- |
|
|
|
# SQLPromptShield |
|
|
|
## Model Description |
|
|
|
SQLPromptShield is a vulnerable-prompt detection model: given a text-to-SQL prompt (a natural language query), it classifies the prompt as either vulnerable (`MALICIOUS`) or benign (`SAFE`).
|
|
|
The checkpoint included in this repository is based on [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) and further fine-tuned on [SQLShield](https://huggingface.co/datasets/salmane11/SQLShield), a dataset dedicated to text-to-SQL vulnerability detection that pairs vulnerable and safe natural language queries (NLQs) with their corresponding SQL queries.
|
|
|
|
|
## Fine-tuning Procedure
|
The model was fine-tuned using the Hugging Face Transformers library. The following steps were used: |
|
|
|
1. Dataset: SQLShield dataset, which consists of labeled pairs of (natural language query, SQL query) with binary classification labels: vulnerable or benign. |
|
|
|
2. Preprocessing: |
|
|
|
- Input Format: Only the natural language query (NLQ) was used as input for classification. |
|
|
|
- Tokenization: Tokenized using bert-base-uncased tokenizer. |
|
|
|
- Max Length: 128 tokens. |
|
|
|
- Padding and truncation applied. |
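As a minimal sketch, the preprocessing steps above correspond to the following tokenizer call (illustrative only; the exact training script is not part of this card):

```python
from transformers import AutoTokenizer

# Same tokenizer as the base model used for fine-tuning
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Only the natural language query (NLQ) is used as input
nlq = "What is the Package Option of TV Channel with serial name 'Sky Radio'?"

# Max length of 128 tokens, with padding and truncation applied
encoded = tokenizer(nlq, max_length=128, padding="max_length", truncation=True)

print(len(encoded["input_ids"]))  # 128
```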
|
|
|
## Intended Use and Limitations |
|
|
|
SQLPromptShield is intended to be used as a pre-processing filter in applications where natural language queries are converted to SQL. Its main goal is to detect potentially malicious or unsafe inputs before they are passed to SQL generation models or database systems. |
|
|
|
Ideal use cases: natural language interfaces for databases (applications with integrated text-to-SQL). Note that the model inspects only the NLQ, not the generated SQL, so it should complement rather than replace standard defenses such as parameterized queries and least-privilege database accounts.
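A minimal sketch of that filtering pattern, assuming a `generate_sql` placeholder stands in for whatever text-to-SQL model the application uses (the `MALICIOUS`/`SAFE` label names follow the pipeline outputs shown below):

```python
from transformers import pipeline

shield = pipeline("text-classification", model="salmane11/SQLPromptShield")

def generate_sql(nlq: str) -> str:
    # Placeholder: swap in the application's actual text-to-SQL model
    return "SELECT ..."

def guarded_text_to_sql(nlq: str) -> str:
    """Screen the NLQ with SQLPromptShield before generating SQL."""
    verdict = shield(nlq)[0]
    if verdict["label"] == "MALICIOUS":
        raise ValueError(f"Rejected NLQ (score={verdict['score']:.3f})")
    return generate_sql(nlq)
```

Rejecting flagged prompts before they reach the SQL generator keeps malicious inputs out of the database layer entirely.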
|
|
|
|
|
## How to Use |
|
|
|
Example 1: Malicious |
|
|
|
```python |
|
from transformers import pipeline |
|
|
|
sql_prompt_shield = pipeline("text-classification", model="salmane11/SQLPromptShield") |
|
|
|
# For the following Table schema |
|
# CREATE TABLE campuses |
|
# ( |
|
# campus VARCHAR, |
|
# location VARCHAR |
|
# ) |
|
|
|
input_text = "What are the names of all campuses located at ' UNION SELECT database() #?" |
|
|
|
# A text-to-SQL model will generate: SELECT campus FROM campuses WHERE LOCATION = '' UNION SELECT database() #'

# When executed, this query leaks sensitive database information, such as the current database name
|
|
|
predicted_label = sql_prompt_shield(input_text) |
|
print(predicted_label) |
|
#[{'label': 'MALICIOUS', 'score': 0.9995930790901184}] |
|
``` |
|
|
|
|
|
Example 2: Safe |
|
|
|
```python |
|
from transformers import pipeline |
|
|
|
sql_prompt_shield = pipeline("text-classification", model="salmane11/SQLPromptShield") |
|
|
|
# For the following Table schema |
|
# CREATE TABLE tv_channel |
|
# ( |
|
# package_option VARCHAR, |
|
# series_name VARCHAR |
|
# ) |
|
|
|
input_text = "What is the Package Option of TV Channel with serial name 'Sky Radio'?" |
|
|
|
# A text-to-SQL model will generate: SELECT Package_Option FROM TV_Channel WHERE series_name = "Sky Radio"

# This is a safe query.
|
|
|
predicted_label = sql_prompt_shield(input_text) |
|
print(predicted_label) |
|
#[{'label': 'SAFE', 'score': 0.998808741569519}] |
|
``` |
|
|
|
|
|
## Cite our work |
|
|
|
Citation |