ELECTRA Trainer for Prompt Injection Detection
Colab notebook: https://colab.research.google.com/drive/11da3m_gYwmkURcjGn8_kp23GiM-INDrm?usp=sharing
Overview
This repository contains a fine-tuned ELECTRA model designed for detecting prompt injections in AI systems. The model classifies input prompts into two categories: benign and jailbreak. This approach aims to enhance the safety and robustness of AI applications.
Approach and Design Decisions
The primary goal of this project was to create a reliable model that can distinguish between safe and potentially harmful prompts. Key design decisions included:
Model Selection: I chose the ELECTRA model for its efficient training process and strong performance on text classification tasks. ELECTRA's pretraining gives it strong representations that can be fine-tuned effectively on limited data, which matters here because the task is narrow and task-specific data is scarce.
Data Preparation: A custom dataset was curated, consisting of diverse prompts labeled as either benign or jailbreak. The dataset was built to balance the two classes so that the model is not biased toward the majority class during training.
Long Inputs: Prompts longer than the ELECTRA model's maximum input length (512 tokens) were truncated. Although truncation discards every token beyond that limit, the model still classified the truncated prompts it was tested on correctly.
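Truncation keeps the beginning of a prompt and drops everything past the limit. The pure-Python sketch below shows that behaviour and counts how many tokens are lost. It assumes the 512-position limit of `google/electra-base-discriminator` and BERT-style `[CLS]`/`[SEP]` ids of 101/102. The helper `truncate_token_ids` is hypothetical; the real tokenizer does this internally when called with `truncation=True`.

```python
from typing import List, Tuple

# Assumed values: ELECTRA-base accepts at most 512 positions, and its uncased
# vocabulary uses BERT-style special tokens ([CLS] = 101, [SEP] = 102).
MAX_LENGTH = 512
CLS_ID = 101
SEP_ID = 102


def truncate_token_ids(content_ids: List[int], max_length: int = MAX_LENGTH) -> Tuple[List[int], int]:
    """Hypothetical helper: keep the head of the prompt, as right-side truncation does.

    Returns the model input ([CLS] + kept tokens + [SEP]) and the number of dropped tokens.
    """
    budget = max_length - 2  # reserve two positions for [CLS] and [SEP]
    kept = content_ids[:budget]
    dropped = max(0, len(content_ids) - budget)
    return [CLS_ID] + kept + [SEP_ID], dropped


# A prompt of 600 content tokens (dummy ids) loses its last 90 tokens.
ids, lost = truncate_token_ids(list(range(1000, 1600)))
print(len(ids), lost)  # 512 90
```

Because the head of the prompt is kept, any text that falls beyond the limit does not contribute to the prediction.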
Model Architecture and Training Strategy
The model is based on the google/electra-base-discriminator architecture. Here’s an overview of the training strategy:
Tokenization: I used the ELECTRA tokenizer to prepare input prompts, applying padding and truncation so that every input has a uniform length.
Training Configuration:
- Learning Rate: Set to 5e-05 for stable convergence.
- Batch Size: A batch size of 16 was chosen to balance training speed and memory usage.
- Epochs: The model was trained for 2 epochs to prevent overfitting while still allowing sufficient learning from the dataset.
Evaluation: The model’s performance was evaluated on a validation set, focusing on metrics such as accuracy, precision, recall, and F1 score.
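To make the training configuration concrete, the sketch below computes the number of optimizer steps and the learning rate at a given step. It assumes training used Hugging Face `Trainer` defaults (linear decay to zero, no warmup, no gradient accumulation, a single device). The dataset size `num_examples` is a hypothetical placeholder, since the size of the curated dataset is not stated.

```python
import math

LEARNING_RATE = 5e-5  # peak learning rate from the configuration above
BATCH_SIZE = 16
EPOCHS = 2


def total_steps(num_examples: int, batch_size: int = BATCH_SIZE, epochs: int = EPOCHS) -> int:
    # The last, partial batch of each epoch still counts as one optimizer step.
    return math.ceil(num_examples / batch_size) * epochs


def linear_lr(step: int, num_steps: int, peak: float = LEARNING_RATE) -> float:
    # Assumed scheduler: linear decay from the peak to zero with no warmup.
    return peak * max(0.0, (num_steps - step) / num_steps)


num_examples = 1000  # hypothetical dataset size, for illustration only
steps = total_steps(num_examples)
print(steps)  # 126
print(linear_lr(0, steps), linear_lr(steps // 2, steps))
```

With only two epochs, the learning rate halves by the midpoint of training and reaches zero at the final step.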
Key Results and Observations
- The model achieved high accuracy on the validation set, indicating that it can effectively distinguish benign prompts from jailbreak attempts.
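The evaluation tracks accuracy, precision, recall, and F1. The pure-Python sketch below computes all four metrics, treating jailbreak (label 1) as the positive class. The function name `binary_metrics` and the example labels are hypothetical and are not results from this model.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int], positive: int = 1) -> Dict[str, float]:
    """Accuracy, precision, recall and F1 with `positive` (jailbreak) as the positive class."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # jailbreaks caught
    fp = sum(t != positive and p == positive for t, p in pairs)  # benign prompts flagged
    fn = sum(t == positive and p != positive for t, p in pairs)  # jailbreaks missed
    correct = sum(t == p for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / len(pairs), "precision": precision, "recall": recall, "f1": f1}


# Made-up labels: 0 = benign, 1 = jailbreak.
print(binary_metrics([0, 0, 1, 1, 1, 0], [0, 1, 1, 1, 0, 0]))
```

Recall on the jailbreak class shows how many attacks slip through, which accuracy alone does not reveal.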
Instructions for Running the Inference Pipeline
To run the inference pipeline for classifying prompts, follow these steps:
Install Dependencies: Ensure you have Python installed, and then install the required libraries using pip:
pip install transformers datasets torch accelerate
import numpy as np
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# Load model directly
tokenizer = AutoTokenizer.from_pretrained("idanpers/JailBreakModel")
model = AutoModelForSequenceClassification.from_pretrained("idanpers/JailBreakModel")

# Trainer is used here only for its predict() helper; no training takes place
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    report_to="none",  # Disable W&B
    save_safetensors=False,
)

# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Usage: classify a single prompt
def classify_prompt(prompt):
    # Error handling for empty input
    if not isinstance(prompt, str) or prompt.strip() == "":
        return {"error": "Invalid input. Please provide a non-empty text prompt."}
    # Tokenize as a one-element batch; truncation keeps the input within the model's limit
    inputs = tokenizer([prompt], truncation=True)
    dataset = Dataset.from_dict({"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]})
    # Use trainer.predict to classify; predictions holds the raw logits, shape (1, 2)
    prediction_output = trainer.predict(dataset)
    # Convert the logits to softmax probabilities for confidence scores
    probs = torch.softmax(torch.tensor(prediction_output.predictions), dim=1).numpy()
    pred_label = int(np.argmax(probs, axis=1)[0])
    confidence = float(probs[0, pred_label])
    # Map prediction to label
    label = "PROMPT_INJECTION" if pred_label == 1 else "BENIGN"
    return {"label": label, "confidence": confidence}

# Accept input from the user and classify it
prompt = input("Enter a prompt for classification: ")
result = classify_prompt(prompt)
# Check for errors before accessing the classification result
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print(f"Classification Result: {result['label']}")
    print(f"Confidence Score: {result['confidence']:.2f}")
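The last part of `classify_prompt` turns the two logits returned by `trainer.predict` into a label and a confidence score. The pure-Python sketch below reproduces that post-processing so it can be checked without downloading the model. Like the code above, it assumes that index 1 means `PROMPT_INJECTION`. The helper `logits_to_result` is hypothetical.

```python
import math
from typing import Dict, List, Union


def softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_result(logits: List[float]) -> Dict[str, Union[str, float]]:
    """Hypothetical helper mirroring the label mapping in classify_prompt."""
    probs = softmax(logits)
    # Index of the most probable class (the first one wins on ties, like np.argmax).
    pred_label = max(range(len(probs)), key=probs.__getitem__)
    label = "PROMPT_INJECTION" if pred_label == 1 else "BENIGN"
    return {"label": label, "confidence": probs[pred_label]}


print(logits_to_result([-1.2, 2.3]))
```

The confidence is simply the largest softmax probability, so with two classes it never drops below 0.5.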