metadata
library_name: transformers
tags:
  - text-to-SQL
  - SQL
  - code-generation
  - NLQ-to-SQL
  - text2SQL
  - Security
  - Vulnerability detection
datasets:
  - salmane11/SQLShield
language:
  - en
base_model:
  - google-bert/bert-base-uncased

SQLPromptShield

Model Description

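SQLPromptShield is a model for detecting vulnerable text-to-SQL prompts. It classifies a natural language query (NLQ) intended for a text-to-SQL system as either malicious (vulnerable) or safe (benign), returning the label MALICIOUS or SAFE together with a confidence score.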

The checkpoint in this repository is based on google-bert/bert-base-uncased and was fine-tuned on SQLShield, a dataset for text-to-SQL vulnerability detection composed of vulnerable and safe NLQs paired with their corresponding SQL queries.

Fine-tuning Procedure

The model was fine-tuned using the Hugging Face Transformers library. The following steps were used:

  1. Dataset: SQLShield, which consists of labeled (natural language query, SQL query) pairs with binary classification labels: vulnerable or benign.

  2. Preprocessing:

    • Input Format: Only the natural language query (NLQ) was used as input for classification.

    • Tokenization: The bert-base-uncased tokenizer was used.

    • Max Length: 128 tokens.

    • Padding and truncation were applied.
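The preprocessing steps above can be sketched as a batched tokenization function. This is a minimal sketch, not the authors' training script: the column names "question" and "label" are hypothetical placeholders rather than the confirmed SQLShield schema, labels are assumed to already be integers, and padding every sequence to the full 128 tokens (padding="max_length") is an assumption, since the card only says padding was applied. The tokenizer is passed in so the function can be used with any Hugging Face tokenizer.

```python
from typing import Any, Callable, Dict, List

MAX_LENGTH = 128  # maximum sequence length stated in this card


def preprocess_batch(
    batch: Dict[str, List[Any]],
    tokenizer: Callable[..., Dict[str, Any]],
) -> Dict[str, Any]:
    """Tokenize only the NLQ column of a batch; the paired SQL query is ignored.

    Hypothetical column names: "question" (the NLQ) and "label" (integer class id).
    """
    encoded = tokenizer(
        batch["question"],
        max_length=MAX_LENGTH,
        truncation=True,       # cut NLQs longer than 128 tokens
        padding="max_length",  # assumption: pad every sequence to 128 tokens
    )
    # Attach the classification targets under the key the Trainer expects.
    encoded["labels"] = batch["label"]
    return encoded
```

With a real tokenizer, this would typically be applied as `dataset.map(lambda b: preprocess_batch(b, tokenizer), batched=True)`, where `tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")`.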

Intended Use and Limitations

SQLPromptShield is intended to be used as a pre-processing filter in applications where natural language queries are converted to SQL. Its main goal is to detect potentially malicious or unsafe inputs before they are passed to SQL generation models or database systems.

Ideal use cases: natural language interfaces to databases (applications with integrated text-to-SQL).
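As a pre-processing filter, the shield sits between the user and the SQL generation step. The sketch below shows one way to wire that gate; it assumes the classifier returns the pipeline output format shown in the examples (a list holding one dict with "label" and "score"). The names `guard_nlq`, `BlockedPromptError`, and the `threshold` parameter are hypothetical, not part of this model or of the transformers library.

```python
from typing import Callable, Dict, List

# Output shape of a "text-classification" pipeline for a single input,
# e.g. [{'label': 'MALICIOUS', 'score': 0.9995}].
Classifier = Callable[[str], List[Dict[str, object]]]


class BlockedPromptError(Exception):
    """Raised when the shield flags an NLQ as malicious (hypothetical)."""


def guard_nlq(nlq: str, classifier: Classifier, threshold: float = 0.5) -> str:
    """Return the NLQ unchanged if it passes the shield, otherwise raise.

    threshold is a hypothetical tuning knob: an NLQ is blocked only when the
    top label is MALICIOUS with a score at or above this value.
    """
    prediction = classifier(nlq)[0]
    if prediction["label"] == "MALICIOUS" and float(prediction["score"]) >= threshold:
        raise BlockedPromptError(f"blocked (score={float(prediction['score']):.3f})")
    return nlq
```

In practice, `classifier` would be `pipeline("text-classification", model="salmane11/SQLPromptShield")`, and only the value returned by `guard_nlq` would be forwarded to the text-to-SQL model. Because the top label of a binary classifier always scores at least 0.5, the default threshold blocks every MALICIOUS prediction; raising it trades fewer false positives for more missed attacks.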

How to Use

Example 1: Malicious

from transformers import pipeline

sql_prompt_shield = pipeline("text-classification", model="salmane11/SQLPromptShield")

# Given the following table schema:
# CREATE TABLE campuses
#   (
#      campus   VARCHAR,
#      location VARCHAR
#   )

input_text = "What are the names of all campuses located at ' UNION SELECT database() #?"

# A text-to-SQL model may generate: SELECT campus FROM campuses WHERE LOCATION = '' UNION SELECT database() #'
# When executed (e.g., on MySQL), the injected UNION clause exposes sensitive information, here the current database name returned by database()

predicted_label = sql_prompt_shield(input_text)
print(predicted_label)
#[{'label': 'MALICIOUS', 'score': 0.9995930790901184}]
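To see why the query generated in Example 1 is dangerous, the sketch below executes an equivalent injected query against an in-memory SQLite database using Python's standard sqlite3 module. SQLite has no database() function and does not treat # as a comment, so sqlite_version() and -- stand in for the MySQL-specific parts of the payload; the campus row is made-up sample data.

```python
import sqlite3

# In-memory database with the campuses schema from Example 1 (sample row is made up).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE campuses (campus VARCHAR, location VARCHAR)")
conn.execute("INSERT INTO campuses VALUES ('Main', 'Springfield')")

# The 'location' value carries an injected UNION clause; '--' comments out the
# trailing quote, playing the role of '#' in the MySQL payload.
location = "' UNION SELECT sqlite_version() --"

# Query text as a text-to-SQL model might emit it, with the value spliced in verbatim.
generated_sql = f"SELECT campus FROM campuses WHERE location = '{location}'"

leaked = conn.execute(generated_sql).fetchall()
print(leaked)  # one row holding the SQLite version string instead of a campus name
```

No campus matches the empty location, so the only row returned is the engine version, which is the kind of information leak SQLPromptShield is meant to stop before the query is generated.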

Example 2: Safe

from transformers import pipeline

sql_prompt_shield = pipeline("text-classification", model="salmane11/SQLPromptShield")

# Given the following table schema:
# CREATE TABLE tv_channel
#   (
#      package_option VARCHAR,
#      series_name    VARCHAR
#   ) 

input_text = "What is the Package Option of TV Channel with serial name 'Sky Radio'?"

# A text-to-SQL model may generate: SELECT Package_Option FROM TV_Channel WHERE series_name = "Sky Radio"
# This query is safe.

predicted_label = sql_prompt_shield(input_text)
print(predicted_label)
#[{'label': 'SAFE', 'score': 0.998808741569519}]

Cite our work

Citation