WenqingZhang commited on
Commit
f04dd6a
·
verified ·
1 Parent(s): 073e1be

Upload 18 files

Browse files
README.md CHANGED
@@ -1,13 +1,12 @@
1
  ---
2
- title: CipherClause
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: The privacy preserving AI develop by the CypherClause team !
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test
3
+ emoji: 🏆
4
+ colorFrom: red
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from requests import head
3
+ from transformer_vectorizer import TransformerVectorizer
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+ import numpy as np
6
+ from concrete.ml.deployment import FHEModelClient
7
+ import numpy
8
+ import os
9
+ from pathlib import Path
10
+ import requests
11
+ import json
12
+ import base64
13
+ import subprocess
14
+ import shutil
15
+ import time
16
+
17
+ # This repository's directory
18
+ REPO_DIR = Path(__file__).parent
19
+
20
+ subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)
21
+
22
+ # Wait 5 sec for the server to start
23
+ time.sleep(5)
24
+
25
+ # Encrypted data limit for the browser to display
26
+ # (encrypted data is too large to display in the browser)
27
+ ENCRYPTED_DATA_BROWSER_LIMIT = 500
28
+ N_USER_KEY_STORED = 20
29
+ model_names=['financial_rating','legal_rating']
30
+
31
+
32
+ FHE_MODEL_PATH = "deployment/financial_rating"
33
+ FHE_LEGAL_PATH = "deployment/legal_rating"
34
+ #FHE_LEGAL_PATH="deployment/legal_rating"
35
+
36
+ print("Loading the transformer model...")
37
+
38
+ # Initialize the transformer vectorizer
39
+ transformer_vectorizer = TransformerVectorizer()
40
+ vectorizer = TfidfVectorizer()
41
+
42
+ def clean_tmp_directory():
43
+ # Allow 20 user keys to be stored.
44
+ # Once that limitation is reached, deleted the oldest.
45
+ path_sub_directories = sorted([f for f in Path(".fhe_keys/").iterdir() if f.is_dir()], key=os.path.getmtime)
46
+
47
+ user_ids = []
48
+ if len(path_sub_directories) > N_USER_KEY_STORED:
49
+ n_files_to_delete = len(path_sub_directories) - N_USER_KEY_STORED
50
+ for p in path_sub_directories[:n_files_to_delete]:
51
+ user_ids.append(p.name)
52
+ shutil.rmtree(p)
53
+
54
+ list_files_tmp = Path("tmp/").iterdir()
55
+ # Delete all files related to user_id
56
+ for file in list_files_tmp:
57
+ for user_id in user_ids:
58
+ if file.name.endswith(f"{user_id}.npy"):
59
+ file.unlink()
60
+ mes=[]
61
+
62
+ def keygen(selected_tasks):
63
+ # Clean tmp directory if needed
64
+ clean_tmp_directory()
65
+
66
+ print("Initializing FHEModelClient...")
67
+
68
+
69
+
70
+ if not selected_tasks:
71
+ return "choose a task first" # 修改提示信息为英文
72
+ user_id = numpy.random.randint(0, 2**32)
73
+ if "legal_rating" in selected_tasks:
74
+ model_names.append('legal_rating')
75
+ # Let's create a user_id
76
+
77
+ fhe_api= FHEModelClient(FHE_LEGAL_PATH, f".fhe_keys/{user_id}")
78
+
79
+
80
+ if "financial_rating" in selected_tasks:
81
+ model_names.append('financial_rating')
82
+
83
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
84
+
85
+ # Let's create a user_id
86
+
87
+
88
+ fhe_api.load()
89
+
90
+
91
+ # Generate a fresh key
92
+ fhe_api.generate_private_and_evaluation_keys(force=True)
93
+ evaluation_key = fhe_api.get_serialized_evaluation_keys()
94
+
95
+ # Save evaluation_key in a file, since too large to pass through regular Gradio
96
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
97
+ numpy.save(f"tmp/tmp_evaluation_key_{user_id}.npy", evaluation_key)
98
+
99
+ return [list(evaluation_key)[:ENCRYPTED_DATA_BROWSER_LIMIT], user_id]
100
+
101
+
102
+
103
+
104
+
105
+ def encode_quantize_encrypt(text, user_id):
106
+ if not user_id:
107
+ raise gr.Error("You need to generate FHE keys first.")
108
+ if "legal_rating" in model_names:
109
+ fhe_api = FHEModelClient(FHE_LEGAL_PATH, f".fhe_keys/{user_id}")
110
+ encodings =vectorizer.fit_transform([text]).toarray()
111
+ if encodings.shape[1] < 1736:
112
+ # 在后面填充零
113
+ padding = np.zeros((1, 1736 - encodings.shape[1]))
114
+ encodings = np.hstack((encodings, padding))
115
+ elif encodings.shape[1] > 1736:
116
+ # 截取前1736列
117
+ encodings = encodings[:, :1736]
118
+ else:
119
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
120
+ encodings = transformer_vectorizer.transform([text])
121
+
122
+ fhe_api.load()
123
+ quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8)
124
+ encrypted_quantized_encoding = fhe_api.quantize_encrypt_serialize(encodings)
125
+
126
+ # Save encrypted_quantized_encoding in a file, since too large to pass through regular Gradio
127
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
128
+ numpy.save(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy", encrypted_quantized_encoding)
129
+
130
+ # Compute size
131
+ encrypted_quantized_encoding_shorten = list(encrypted_quantized_encoding)[:ENCRYPTED_DATA_BROWSER_LIMIT]
132
+ encrypted_quantized_encoding_shorten_hex = ''.join(f'{i:02x}' for i in encrypted_quantized_encoding_shorten)
133
+ return (
134
+ encodings[0],
135
+ quantized_encodings[0],
136
+ encrypted_quantized_encoding_shorten_hex,
137
+ )
138
+
139
+
140
+
141
+ def run_fhe(user_id):
142
+ encoded_data_path = Path(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy")
143
+ if not user_id:
144
+ raise gr.Error("You need to generate FHE keys first.")
145
+ if not encoded_data_path.is_file():
146
+ raise gr.Error("No encrypted data was found. Encrypt the data before trying to predict.")
147
+
148
+ # Read encrypted_quantized_encoding from the file
149
+ encrypted_quantized_encoding = numpy.load(encoded_data_path)
150
+
151
+ # Read evaluation_key from the file
152
+ evaluation_key = numpy.load(f"tmp/tmp_evaluation_key_{user_id}.npy")
153
+
154
+ # Use base64 to encode the encodings and evaluation key
155
+ encrypted_quantized_encoding = base64.b64encode(encrypted_quantized_encoding).decode()
156
+ encoded_evaluation_key = base64.b64encode(evaluation_key).decode()
157
+
158
+ query = {}
159
+ query["evaluation_key"] = encoded_evaluation_key
160
+ query["encrypted_encoding"] = encrypted_quantized_encoding
161
+ headers = {"Content-type": "application/json"}
162
+ if "legal_rating" in model_names:
163
+ response = requests.post(
164
+ "http://localhost:8000/predict_legal", data=json.dumps(query), headers=headers
165
+ )
166
+ else:
167
+ response = requests.post(
168
+ "http://localhost:8000/predict_sentiment", data=json.dumps(query), headers=headers
169
+ )
170
+ encrypted_prediction = base64.b64decode(response.json()["encrypted_prediction"])
171
+
172
+ # Save encrypted_prediction in a file, since too large to pass through regular Gradio
173
+ # buttons, https://github.com/gradio-app/gradio/issues/1877
174
+ numpy.save(f"tmp/tmp_encrypted_prediction_{user_id}.npy", encrypted_prediction)
175
+ encrypted_prediction_shorten = list(encrypted_prediction)[:ENCRYPTED_DATA_BROWSER_LIMIT]
176
+ encrypted_prediction_shorten_hex = ''.join(f'{i:02x}' for i in encrypted_prediction_shorten)
177
+ return encrypted_prediction_shorten_hex
178
+
179
+
180
+ def decrypt_prediction(user_id):
181
+ encoded_data_path = Path(f"tmp/tmp_encrypted_prediction_{user_id}.npy")
182
+ if not user_id:
183
+ raise gr.Error("You need to generate FHE keys first.")
184
+ if not encoded_data_path.is_file():
185
+ raise gr.Error("No encrypted prediction was found. Run the prediction over the encrypted data first.")
186
+
187
+ # Read encrypted_prediction from the file
188
+ encrypted_prediction = numpy.load(encoded_data_path).tobytes()
189
+
190
+ if "legal_rating" in model_names:
191
+ fhe_api = FHEModelClient(FHE_LEGAL_PATH, f".fhe_keys/{user_id}")
192
+
193
+ fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
194
+ fhe_api.load()
195
+
196
+ # We need to retrieve the private key that matches the client specs (see issue #18)
197
+ fhe_api.generate_private_and_evaluation_keys(force=False)
198
+
199
+ predictions = fhe_api.deserialize_decrypt_dequantize(encrypted_prediction)
200
+ print(predictions)
201
+
202
+ return {
203
+ "low_relative": predictions[0][0],
204
+ "medium_relative": predictions[0][1],
205
+ "high_relative": predictions[0][2],
206
+ }
207
+
208
+
209
+ demo = gr.Blocks()
210
+
211
+
212
+ print("Starting the demo...")
213
+ with demo:
214
+
215
+ gr.Markdown(
216
+ """
217
+
218
+ <h2 align="center">📄Cipher Clause</h2>
219
+ <p align="center">
220
+ <img width=200 src="https://www.helloimg.com/i/2024/09/28/66f7f6701bcfb.jpeg">
221
+ </p>
222
+
223
+ """
224
+ )
225
+
226
+
227
+ gr.Markdown(
228
+ """
229
+ <p align="center">
230
+ </p>
231
+ <p align="center">
232
+ </p>
233
+ """
234
+ )
235
+
236
+ gr.Markdown("## Notes")
237
+ gr.Markdown(
238
+ """
239
+ - The private key is used to encrypt and decrypt the data and shall never be shared.
240
+ - The evaluation key is a public key that the server needs to process encrypted data.
241
+ """
242
+ )
243
+ gr.Markdown(
244
+ """
245
+ <hr/>
246
+ """
247
+ )
248
+ gr.Markdown("# Step 0: Select Task")
249
+ task_checkbox = gr.CheckboxGroup(
250
+ choices=["legal_rating", "financial_rating"],
251
+ label="select_tasks"
252
+ )
253
+ gr.Markdown(
254
+ """
255
+ <hr/>
256
+ """
257
+ )
258
+ gr.Markdown("# Step 1: Generate the keys")
259
+
260
+ b_gen_key_and_install = gr.Button("Generate all the keys and send public part to server")
261
+
262
+ evaluation_key = gr.Textbox(
263
+ label="Evaluation key (truncated):",
264
+ max_lines=4,
265
+ interactive=False,
266
+ )
267
+
268
+ user_id = gr.Textbox(
269
+ label="",
270
+ max_lines=4,
271
+ interactive=False,
272
+ visible=False
273
+ )
274
+ gr.Markdown(
275
+ """
276
+ <hr/>
277
+ """
278
+ )
279
+ gr.Markdown("# Step 2: Provide a contract or clause")
280
+ gr.Markdown("## Client side")
281
+ gr.Markdown(
282
+ "Enter a contract or clause you want to analysis)."
283
+ )
284
+ text = gr.Textbox(label="Enter some words:", value="The Employee is entitled to two weeks of paid vacation annually, to be scheduled at the mutual convenience of the Employee and Employer.")
285
+ gr.Markdown(
286
+ """
287
+ <hr/>
288
+ """
289
+ )
290
+ gr.Markdown("# Step 3: Encode the message with the private key")
291
+ b_encode_quantize_text = gr.Button(
292
+ "Encode, quantize and encrypt the text with vectorizer, and send to server"
293
+ )
294
+
295
+ with gr.Row():
296
+ encoding = gr.Textbox(
297
+ label="Representation:",
298
+ max_lines=4,
299
+ interactive=False,
300
+ )
301
+ quantized_encoding = gr.Textbox(
302
+ label="Quantized representation:", max_lines=4, interactive=False
303
+ )
304
+ encrypted_quantized_encoding = gr.Textbox(
305
+ label="Encrypted quantized representation (truncated):",
306
+ max_lines=4,
307
+ interactive=False,
308
+ )
309
+ gr.Markdown(
310
+ """
311
+ <hr/>
312
+ """
313
+ )
314
+ gr.Markdown("# Step 4: Run the FHE evaluation")
315
+ gr.Markdown("## Server side")
316
+ gr.Markdown(
317
+ "The encrypted value is received by the server. Thanks to the evaluation key and to FHE, the server can compute the (encrypted) prediction directly over encrypted values. Once the computation is finished, the server returns the encrypted prediction to the client."
318
+ )
319
+
320
+ b_run_fhe = gr.Button("Run FHE execution there")
321
+ encrypted_prediction = gr.Textbox(
322
+ label="Encrypted prediction (truncated):",
323
+ max_lines=4,
324
+ interactive=False,
325
+ )
326
+ gr.Markdown(
327
+ """
328
+ <hr/>
329
+ """
330
+ )
331
+ gr.Markdown("# Step 5: Decrypt the class")
332
+ gr.Markdown("## Client side")
333
+ gr.Markdown(
334
+ "The encrypted sentiment is sent back to client, who can finally decrypt it with its private key. Only the client is aware of the original tweet and the prediction."
335
+ )
336
+ b_decrypt_prediction = gr.Button("Decrypt prediction")
337
+
338
+ labels_sentiment = gr.Label(label="level:")
339
+
340
+ # Button for key generation
341
+ b_gen_key_and_install.click(keygen, inputs=[task_checkbox], outputs=[evaluation_key, user_id])
342
+
343
+ # Button to quantize and encrypt
344
+ b_encode_quantize_text.click(
345
+ encode_quantize_encrypt,
346
+ inputs=[text, user_id],
347
+ outputs=[
348
+ encoding,
349
+ quantized_encoding,
350
+ encrypted_quantized_encoding,
351
+ ],
352
+ )
353
+
354
+ # Button to send the encodings to the server using post at (localhost:8000/predict_sentiment)
355
+ b_run_fhe.click(run_fhe, inputs=[user_id], outputs=[encrypted_prediction])
356
+
357
+ # Button to decrypt the prediction on the client
358
+ b_decrypt_prediction.click(decrypt_prediction, inputs=[user_id], outputs=[labels_sentiment])
359
+ gr.Markdown(
360
+ "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). Try it yourself and don't forget to star on Github &#11088;."
361
+ )
362
+ demo.launch(share=False)
deployment/financial_rating/client.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b65015dc1ba9f0c02eed7fa6915fedd4b68e69dc89d3421b31598a368e75e33
3
+ size 3409316
deployment/financial_rating/server.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5217aed38b47b116fa6643e41cfe97ea5f680748a0b0f4a9e034fba636a31774
3
+ size 69335
deployment/financial_rating/versions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"concrete-python": "2.5", "concrete-ml": "1.4.0", "python": "3.10.12"}
deployment/legal_rating/client.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2327bdcfd3225ed32962d1a7db19feb69f47564925a52b4a0522d63043e5455d
3
+ size 1178525
deployment/legal_rating/server.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d357a8b4985486db5a5800d8109775b7f1491bac3b137acb54684e6fa6ebad59
3
+ size 1005337
deployment/legal_rating/versions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"concrete-python": "2.5", "concrete-ml": "1.4.0", "python": "3.10.12"}
deployment/samples_for_compilation.csv ADDED
The diff for this file is too large to render. See raw diff
 
deployment/sentiment_fhe_model/client.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbde0af1d92d5c2b2e42d8d439ae75328773dac591826559fbc2043356c22388
3
+ size 3887326
deployment/sentiment_fhe_model/server.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6ab63f4cae95dd9c418df05b0041d567f485ae16ae84f02068165b3df659baf
3
+ size 3004
deployment/sentiment_fhe_model/versions.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"concrete-python": "2.5", "concrete-ml": "1.4.0", "python": "3.10.11"}
deployment/serialized_model ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ concrete-ml==1.4.0
2
+ gradio
3
+ pandas==1.4.3
4
+ transformers==4.36.0
5
+
6
+
server.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Server that will listen for GET requests from the client."""
2
+ from fastapi import FastAPI
3
+ from joblib import load
4
+ from concrete.ml.deployment import FHEModelServer
5
+ from pydantic import BaseModel
6
+ import base64
7
+ from pathlib import Path
8
+
9
+ current_dir = Path(__file__).parent
10
+
11
+ # Load the model
12
+ fhe_model = FHEModelServer("deployment/financial_rating")
13
+ fhe_legal_model = FHEModelServer("deployment/legal_rating")
14
+ class PredictRequest(BaseModel):
15
+ evaluation_key: str
16
+ encrypted_encoding: str
17
+
18
+ # Initialize an instance of FastAPI
19
+ app = FastAPI()
20
+
21
+ # Define the default route
22
+ @app.get("/")
23
+ def root():
24
+ return {"message": "Welcome to Your Sentiment Classification FHE Model Server!"}
25
+
26
+ @app.post("/predict_sentiment")
27
+ def predict_sentiment(query: PredictRequest):
28
+ fhe_model = FHEModelServer("deployment/financial_rating")
29
+
30
+ encrypted_encoding = base64.b64decode(query.encrypted_encoding)
31
+ evaluation_key = base64.b64decode(query.evaluation_key)
32
+ prediction = fhe_model.run(encrypted_encoding, evaluation_key)
33
+
34
+ # Encode base64 the prediction
35
+ encoded_prediction = base64.b64encode(prediction).decode()
36
+ return {"encrypted_prediction": encoded_prediction}
37
+
38
+ @app.post("/predict_legal")
39
+ def predict_legal(query: PredictRequest):
40
+ fhe_legal_model = FHEModelServer("deployment/legal_rating")
41
+
42
+ encrypted_encoding = base64.b64decode(query.encrypted_encoding)
43
+ evaluation_key = base64.b64decode(query.evaluation_key)
44
+ prediction = fhe_legal_model.run(encrypted_encoding, evaluation_key)
45
+
46
+ # Encode base64 the prediction
47
+ encoded_prediction = base64.b64encode(prediction).decode()
48
+ return {"encrypted_prediction": encoded_prediction}
tmp/text.txt ADDED
File without changes
transformer_vectorizer.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Let's import a few requirements
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+ import numpy
5
+
6
+ class TransformerVectorizer:
7
+ def __init__(self):
8
+ # Load the tokenizer (converts text to tokens)
9
+ self.tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
10
+
11
+ # Load the pre-trained model
12
+ self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
13
+ "cardiffnlp/twitter-roberta-base-sentiment-latest"
14
+ )
15
+ self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
+
17
+ def text_to_tensor(
18
+ self,
19
+ texts: list,
20
+ ) -> numpy.ndarray:
21
+ """Function that transforms a list of texts to their learned representation.
22
+
23
+ Args:
24
+ list_text_X (list): List of texts to be transformed.
25
+
26
+ Returns:
27
+ numpy.ndarray: Transformed list of texts.
28
+ """
29
+ # First, tokenize all the input text
30
+ tokenized_text_X_train = self.tokenizer.batch_encode_plus(
31
+ texts, return_tensors="pt"
32
+ )["input_ids"]
33
+
34
+ # Depending on the hardware used, the number of examples to be processed can be reduced
35
+ # Here we split the data into 100 examples per batch
36
+ tokenized_text_X_train_split = torch.split(tokenized_text_X_train, split_size_or_sections=50)
37
+
38
+ # Send the model to the device
39
+ transformer_model = self.transformer_model.to(self.device)
40
+ output_hidden_states_list = []
41
+
42
+ for tokenized_x in tokenized_text_X_train_split:
43
+ # Pass the tokens through the transformer model and get the hidden states
44
+ # Only keep the last hidden layer state for now
45
+ output_hidden_states = transformer_model(tokenized_x.to(self.device), output_hidden_states=True)[
46
+ 1
47
+ ][-1]
48
+ # Average over the tokens axis to get a representation at the text level.
49
+ output_hidden_states = output_hidden_states.mean(dim=1)
50
+ output_hidden_states = output_hidden_states.detach().cpu().numpy()
51
+ output_hidden_states_list.append(output_hidden_states)
52
+
53
+ self.encodings = numpy.concatenate(output_hidden_states_list, axis=0)
54
+ return self.encodings
55
+
56
+ def transform(self, texts: list):
57
+ return self.text_to_tensor(texts)
58
+
usecase.jpeg ADDED