Switching to skops format, adding train.py
Browse files- .gitattributes +1 -0
- config.json +2 -2
- prompt_protect_model.skops +3 -0
- train.py +127 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
prompt_protect_model.skops filter=lfs diff=lfs merge=lfs -text
|
config.json
CHANGED
@@ -11,9 +11,9 @@
|
|
11 |
]
|
12 |
},
|
13 |
"model": {
|
14 |
-
"file": "skops
|
15 |
},
|
16 |
-
"model_format": "
|
17 |
"task": "text-classification"
|
18 |
}
|
19 |
}
|
|
|
11 |
]
|
12 |
},
|
13 |
"model": {
|
14 |
+
"file": "prompt_protect_model.skops"
|
15 |
},
|
16 |
+
"model_format": "skops",
|
17 |
"task": "text-classification"
|
18 |
}
|
19 |
}
|
prompt_protect_model.skops
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2ef64d9cd521fb3e2b88bfd5ccc3cc75fd02054ea11788e47a7d945874591630
|
3 |
+
size 2100826
|
train.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Released under the MIT License by thevgergroup
|
2 |
+
# Copyright (c) 2024 thevgergroup
|
3 |
+
|
4 |
+
|
5 |
+
from sklearn.pipeline import Pipeline
|
6 |
+
|
7 |
+
from skops import card, hub_utils
|
8 |
+
|
9 |
+
from datasets import load_dataset
|
10 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
11 |
+
|
12 |
+
from sklearn.linear_model import LogisticRegression
|
13 |
+
from sklearn.metrics import classification_report
|
14 |
+
import os
|
15 |
+
from skops.io import dump
|
16 |
+
from pathlib import Path
|
17 |
+
from tempfile import mkdtemp, mkstemp
|
18 |
+
import sklearn
|
19 |
+
from argparse import ArgumentParser
|
20 |
+
|
21 |
+
|
22 |
+
# Define the default values (each is overridden by the CLI flags in __main__).
data = "deepset/prompt-injections"  # HF dataset id with train/test splits and text/label columns
save_directory = "models"  # NOTE(review): differs from the argparse default "models/thevgergroup" — confirm which is intended
model_name = "prompt_protect_model"  # file stem; ".skops" is appended at save time
repo_id = "thevgergroup/prompt_protect"  # hub repo used only when --upload is passed
upload = False  # push to the hub only when True
commit_message = "Initial commit"  # commit message for the optional hub push

# Module-level train/test splits, populated as a side effect of split_data().
X_train, X_test, y_train, y_test = None, None, None, None
32 |
+
|
33 |
+
|
34 |
+
def load_data(data):
    """Fetch *data* from the Hugging Face hub and return the raw DatasetDict."""
    return load_dataset(data)
38 |
+
|
39 |
+
|
40 |
+
def split_data(dataset):
    """Populate the module-level X/y train/test splits from *dataset*.

    The deepset prompt-injections dataset ships pre-split into train and
    test; replace this routine with your own splitting logic for other
    datasets.
    """
    global X_train, X_test, y_train, y_test
    train_frame = dataset['train'].to_pandas()
    test_frame = dataset['test'].to_pandas()
    X_train, y_train = train_frame['text'], train_frame['label']
    X_test, y_test = test_frame['text'], test_frame['label']
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
def train_model(X_train, y_train):
    """Fit and return a TF-IDF + logistic-regression sklearn Pipeline.

    Step names ("vectorize", "lgr") are part of the persisted model's
    interface and are kept stable.
    """
    pipeline = Pipeline(steps=[
        ("vectorize", TfidfVectorizer(max_features=5000)),
        ("lgr", LogisticRegression()),
    ])
    pipeline.fit(X_train, y_train)
    return pipeline
65 |
+
|
66 |
+
def evaluate_model(model):
    """Score *model* on the module-level held-out split; return a text report."""
    global X_train, X_test, y_train, y_test
    predictions = model.predict(X_test)
    return classification_report(y_test, predictions)
71 |
+
|
72 |
+
|
73 |
+
if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument("--data", type=str, default="deepset/prompt-injections", help="Dataset to use for training, expects a huggingface dataset with train and test splits and text / label columns")
    parser.add_argument("--save_directory", type=str, default="models/thevgergroup", help="Directory to save the model to")
    parser.add_argument("--model_name", type=str, default="prompt_protect_model", help="Name of the model file, will have .skops extension added to it")
    parser.add_argument("--repo_id", type=str, default="thevgergroup/prompt_protect", help="Repo to push the model to")
    parser.add_argument("--upload", action="store_true", help="Upload the model to the hub, must be a contributor to the repo")
    parser.add_argument("--commit-message", type=str, default="Initial commit", help="Commit message for the model push")

    args = parser.parse_args()

    # BUG FIX: the assignments were previously guarded by
    # `if any(vars(args).values()):`. Every option has a truthy default, so
    # the guard was always true — and had it ever been false, the parsed CLI
    # values would have been silently discarded. Assign unconditionally.
    data = args.data
    save_directory = args.save_directory
    model_name = args.model_name
    repo_id = args.repo_id
    upload = args.upload
    commit_message = args.commit_message

    # Train and evaluate on the (pre-split) dataset.
    dataset = load_data(data)
    split_data(dataset)
    model = train_model(X_train=X_train, y_train=y_train)
    report = evaluate_model(model)
    print(report)

    # Persist the fitted pipeline in skops format.
    # (os.path.join(save_directory) with a single argument was a no-op.)
    model_path = save_directory
    print("Saving model to", model_path)
    os.makedirs(model_path, exist_ok=True)

    model_file = os.path.join(model_path, f"{model_name}.skops")
    dump(model, file=model_file)

    if upload:
        # Stage a throwaway local repo, then push the model plus this
        # training script to the Hugging Face hub.
        local_repo = mkdtemp(prefix="skops-")
        print("Creating local repo at", local_repo)
        hub_utils.init(
            model=model_file,
            dst=local_repo,
            # BUG FIX: a pip version pin needs '==', not '='.
            requirements=[f"scikit-learn=={sklearn.__version__}"],
            task="text-classification",
            data=X_test.to_list(),
        )

        hub_utils.add_files(__file__, dst=local_repo, exist_ok=True)

        hub_utils.push(source=local_repo, repo_id=repo_id, commit_message=commit_message)