Alejandro Velez commited on
Commit
db1d3e7
·
1 Parent(s): 1b531a5

config fix and cleanup

Browse files
Files changed (48) hide show
  1. config.json +4 -4
  2. examples/cell_classification.ipynb +0 -0
  3. examples/extract_and_plot_cell_embeddings.ipynb +0 -0
  4. examples/gene_classification.ipynb +0 -0
  5. examples/in_silico_perturbation.ipynb +0 -159
  6. examples/multitask_cell_classification.ipynb +0 -420
  7. examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb +0 -365
  8. examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py +0 -167
  9. examples/tokenizing_scRNAseq_data.ipynb +0 -91
  10. geneformer/__init__.py +0 -34
  11. geneformer/classifier.py +0 -1563
  12. geneformer/classifier_utils.py +0 -648
  13. geneformer/collator_for_classification.py +0 -667
  14. geneformer/emb_extractor.py +0 -863
  15. geneformer/evaluation_utils.py +0 -287
  16. geneformer/in_silico_perturber.py +0 -1579
  17. geneformer/in_silico_perturber_stats.py +0 -1104
  18. geneformer/mtl/__init__.py +0 -1
  19. geneformer/mtl/collators.py +0 -76
  20. geneformer/mtl/data.py +0 -162
  21. geneformer/mtl/eval_utils.py +0 -88
  22. geneformer/mtl/imports.py +0 -43
  23. geneformer/mtl/model.py +0 -121
  24. geneformer/mtl/optuna_utils.py +0 -27
  25. geneformer/mtl/train.py +0 -380
  26. geneformer/mtl/train_utils.py +0 -161
  27. geneformer/mtl/utils.py +0 -129
  28. geneformer/mtl_classifier.py +0 -363
  29. geneformer/perturber_utils.py +0 -919
  30. geneformer/pretrainer.py +0 -640
  31. geneformer/tokenizer.py +0 -685
  32. gf-12L-30M-i2048/config.json +0 -23
  33. gf-12L-30M-i2048/pytorch_model.bin +0 -3
  34. gf-12L-30M-i2048/training_args.bin +0 -3
  35. gf-12L-95M-i4096/config.json +0 -24
  36. gf-12L-95M-i4096/generation_config.json +0 -5
  37. gf-12L-95M-i4096/model.safetensors +0 -3
  38. gf-12L-95M-i4096/training_args.bin +0 -3
  39. gf-20L-95M-i4096/config.json +0 -24
  40. gf-20L-95M-i4096/generation_config.json +0 -5
  41. gf-20L-95M-i4096/model.safetensors +0 -3
  42. gf-20L-95M-i4096/training_args.bin +0 -3
  43. gf-20L-95M-i4096_config.json +0 -24
  44. gf-20L-95M-i4096_generation_config.json +0 -5
  45. gf-6L-30M-i2048/config.json +0 -23
  46. gf-6L-30M-i2048/model.safetensors +0 -3
  47. gf-6L-30M-i2048/pytorch_model.bin +0 -3
  48. gf-6L-30M-i2048/training_args.bin +0 -3
config.json CHANGED
@@ -6,14 +6,14 @@
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
- "hidden_size": 512,
10
  "initializer_range": 0.02,
11
- "intermediate_size": 1024,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
- "num_attention_heads": 8,
16
- "num_hidden_layers": 12,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
 
6
  "classifier_dropout": null,
7
  "hidden_act": "relu",
8
  "hidden_dropout_prob": 0.02,
9
+ "hidden_size": 896,
10
  "initializer_range": 0.02,
11
+ "intermediate_size": 1792,
12
  "layer_norm_eps": 1e-12,
13
  "max_position_embeddings": 4096,
14
  "model_type": "bert",
15
+ "num_attention_heads": 14,
16
+ "num_hidden_layers": 20,
17
  "pad_token_id": 0,
18
  "position_embedding_type": "absolute",
19
  "torch_dtype": "float32",
examples/cell_classification.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
examples/extract_and_plot_cell_embeddings.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
examples/gene_classification.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
examples/in_silico_perturbation.ipynb DELETED
@@ -1,159 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "id": "e10ac0c9-40ce-41fb-b6fa-3d62b76f2e57",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "from geneformer import InSilicoPerturber\n",
11
- "from geneformer import InSilicoPerturberStats\n",
12
- "from geneformer import EmbExtractor"
13
- ]
14
- },
15
- {
16
- "cell_type": "markdown",
17
- "id": "cbd6851c-060e-4967-b816-e605ffe58b23",
18
- "metadata": {
19
- "tags": []
20
- },
21
- "source": [
22
- "### in silico perturbation in deletion mode to determine genes whose deletion in the dilated cardiomyopathy (dcm) state significantly shifts the embedding towards non-failing (nf) state"
23
- ]
24
- },
25
- {
26
- "cell_type": "code",
27
- "execution_count": null,
28
- "id": "c53e98cd-c603-4878-82ba-db471181bb55",
29
- "metadata": {},
30
- "outputs": [],
31
- "source": [
32
- "# first obtain start, goal, and alt embedding positions\n",
33
- "# this function was changed to be separate from perturb_data\n",
34
- "# to avoid repeating calcuations when parallelizing perturb_data\n",
35
- "cell_states_to_model={\"state_key\": \"disease\", \n",
36
- " \"start_state\": \"dcm\", \n",
37
- " \"goal_state\": \"nf\", \n",
38
- " \"alt_states\": [\"hcm\"]}\n",
39
- "\n",
40
- "filter_data_dict={\"cell_type\":[\"Cardiomyocyte1\",\"Cardiomyocyte2\",\"Cardiomyocyte3\"]}\n",
41
- "\n",
42
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
43
- "# (otherwise the EmbExtractor will use the current default model dictionary)\n",
44
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
45
- "embex = EmbExtractor(model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
46
- " num_classes=3,\n",
47
- " filter_data=filter_data_dict,\n",
48
- " max_ncells=1000,\n",
49
- " emb_layer=0,\n",
50
- " summary_stat=\"exact_mean\",\n",
51
- " forward_batch_size=256,\n",
52
- " nproc=16)\n",
53
- "\n",
54
- "state_embs_dict = embex.get_state_embs(cell_states_to_model,\n",
55
- " \"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
56
- " \"path/to/input_data\",\n",
57
- " \"path/to/output_directory\",\n",
58
- " \"output_prefix\")"
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": null,
64
- "id": "981e1190-62da-4543-b7d3-6e2a2d6a6d56",
65
- "metadata": {
66
- "tags": []
67
- },
68
- "outputs": [],
69
- "source": [
70
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
71
- "# (otherwise the InSilicoPerturber will use the current default model dictionary)\n",
72
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
73
- "isp = InSilicoPerturber(perturb_type=\"delete\",\n",
74
- " perturb_rank_shift=None,\n",
75
- " genes_to_perturb=\"all\",\n",
76
- " combos=0,\n",
77
- " anchor_gene=None,\n",
78
- " model_type=\"CellClassifier\", # if using previously fine-tuned cell classifier model\n",
79
- " num_classes=3,\n",
80
- " emb_mode=\"cell\",\n",
81
- " cell_emb_style=\"mean_pool\",\n",
82
- " filter_data=filter_data_dict,\n",
83
- " cell_states_to_model=cell_states_to_model,\n",
84
- " state_embs_dict=state_embs_dict,\n",
85
- " max_ncells=2000,\n",
86
- " emb_layer=0,\n",
87
- " forward_batch_size=400,\n",
88
- " nproc=16)"
89
- ]
90
- },
91
- {
92
- "cell_type": "code",
93
- "execution_count": null,
94
- "id": "0525a663-871a-4ce0-a135-cc203817ffa9",
95
- "metadata": {},
96
- "outputs": [],
97
- "source": [
98
- "# outputs intermediate files from in silico perturbation\n",
99
- "\n",
100
- "isp.perturb_data(\"../fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224\", # example 30M fine-tuned model\n",
101
- " \"path/to/input_data\",\n",
102
- " \"path/to/isp_output_directory\",\n",
103
- " \"output_prefix\")"
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": null,
109
- "id": "f8aadabb-516a-4dc0-b307-6de880e64e26",
110
- "metadata": {},
111
- "outputs": [],
112
- "source": [
113
- "# OF NOTE: token_dictionary_file must be set to the gc-30M token dictionary if using a 30M series model\n",
114
- "# (otherwise the InSilicoPerturberStats will use the current default model dictionary)\n",
115
- "# 30M token dictionary: https://huggingface.co/ctheodoris/Geneformer/blob/main/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl\n",
116
- "ispstats = InSilicoPerturberStats(mode=\"goal_state_shift\",\n",
117
- " genes_perturbed=\"all\",\n",
118
- " combos=0,\n",
119
- " anchor_gene=None,\n",
120
- " cell_states_to_model=cell_states_to_model)"
121
- ]
122
- },
123
- {
124
- "cell_type": "code",
125
- "execution_count": null,
126
- "id": "ffecfae6-e737-43e3-99e9-fa37ff46610b",
127
- "metadata": {},
128
- "outputs": [],
129
- "source": [
130
- "# extracts data from intermediate files and processes stats to output in final .csv\n",
131
- "ispstats.get_stats(\"path/to/isp_output_directory\", # this should be the directory \n",
132
- " None,\n",
133
- " \"path/to/isp_stats_output_directory\",\n",
134
- " \"output_prefix\")"
135
- ]
136
- }
137
- ],
138
- "metadata": {
139
- "kernelspec": {
140
- "display_name": "Python 3 (ipykernel)",
141
- "language": "python",
142
- "name": "python3"
143
- },
144
- "language_info": {
145
- "codemirror_mode": {
146
- "name": "ipython",
147
- "version": 3
148
- },
149
- "file_extension": ".py",
150
- "mimetype": "text/x-python",
151
- "name": "python",
152
- "nbconvert_exporter": "python",
153
- "pygments_lexer": "ipython3",
154
- "version": "3.10.15"
155
- }
156
- },
157
- "nbformat": 4,
158
- "nbformat_minor": 5
159
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/multitask_cell_classification.ipynb DELETED
@@ -1,420 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "866f100c-e11a-4e7b-a37c-831775d845a7",
6
- "metadata": {},
7
- "source": [
8
- "# Geneformer Multi-Task Cell Classifier Tutorial\n",
9
- "\n",
10
- "This tutorial demonstrates how to use the Geneformer Multi-Task Cell Classifier and optimizatize hyperparameter for fine-tuning"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "id": "311ba456-b44d-40c7-941d-3fc03bcda85a",
16
- "metadata": {},
17
- "source": [
18
- "## 1. Installation and Imports\n",
19
- "\n",
20
- "First import the necessary modules."
21
- ]
22
- },
23
- {
24
- "cell_type": "code",
25
- "execution_count": 3,
26
- "id": "cd9defdc-0524-4c3b-a741-27117ed3a5be",
27
- "metadata": {},
28
- "outputs": [],
29
- "source": [
30
- "from geneformer import MTLClassifier"
31
- ]
32
- },
33
- {
34
- "cell_type": "markdown",
35
- "id": "790e9c3c-f6d9-44b3-b9a5-05725760f4fd",
36
- "metadata": {},
37
- "source": [
38
- "## 2. Set up Paths and Parameters\n",
39
- "\n",
40
- "Now, let's set up the necessary paths and parameters for our classifier. We'll also define our task columns, which are specific columns from our dataset that represent the classification tasks we want to train the model on."
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": null,
46
- "id": "04a04197-8e45-47f8-a86f-202209ea10ae",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "# Define paths\n",
51
- "pretrained_path = \"/path/to/pretrained/Geneformer/model\" \n",
52
- "# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
53
- "train_path = \"/path/to/train/data.dataset\"\n",
54
- "val_path = \"/path/to/val/data.dataset\"\n",
55
- "test_path = \"/path/to/test/data.dataset\"\n",
56
- "results_dir = \"/path/to/results/directory\"\n",
57
- "model_save_path = \"/path/to/model/save/path\"\n",
58
- "tensorboard_log_dir = \"/path/to/tensorboard/log/dir\"\n",
59
- "\n",
60
- "# Define tasks and hyperparameters\n",
61
- "# task_columns should be a list of column names from your dataset\n",
62
- "# Each column represents a specific classification task (e.g. cell type, disease state)\n",
63
- "task_columns = [\"cell_type\", \"disease_state\"] # Example task columns\n",
64
- "\n",
65
- "hyperparameters = {\n",
66
- " \"learning_rate\": {\"type\": \"float\", \"low\": 1e-5, \"high\": 1e-3, \"log\": True},\n",
67
- " \"warmup_ratio\": {\"type\": \"float\", \"low\": 0.005, \"high\": 0.01},\n",
68
- " \"weight_decay\": {\"type\": \"float\", \"low\": 0.01, \"high\": 0.1},\n",
69
- " \"dropout_rate\": {\"type\": \"float\", \"low\": 0.0, \"high\": 0.7},\n",
70
- " \"lr_scheduler_type\": {\"type\": \"categorical\", \"choices\": [\"cosine\"]},\n",
71
- " \"task_weights\": {\"type\": \"float\", \"low\": 0.1, \"high\": 2.0}\n",
72
- "}"
73
- ]
74
- },
75
- {
76
- "cell_type": "markdown",
77
- "id": "31857690-a739-435a-aefd-f171fafc1b78",
78
- "metadata": {},
79
- "source": [
80
- "In the code above, we've defined `task_columns` as `[\"cell_type\", \"disease_state\"]`. This means our model will be trained to classify cells based on two tasks:\n",
81
- "1. Identifying the cell type\n",
82
- "2. Determining the disease state\n",
83
- "3. Note: \"unique_cell_id\" is a required column in the dataset for logging and inference purposes\n",
84
- "\n",
85
- "These column names should correspond to actual columns in your dataset. Each column should contain the labels for that specific classification task.\n",
86
- "\n",
87
- "For example, your dataset might look something like this:\n",
88
- "\n",
89
- " | unique_cell_id | input_ids | ... | cell_type | disease_state |\n",
90
- " |----------------|-----------|-----|-----------|---------------|\n",
91
- " | cell1 | ... | ... | neuron | healthy |\n",
92
- " | cell2 | ... | ... | astrocyte | diseased |\n",
93
- " | ... | ... | ... | ... | ... |\n",
94
- "The model will learn to predict classes within 'cell_type' and 'disease_state' "
95
- ]
96
- },
97
- {
98
- "cell_type": "markdown",
99
- "id": "b9e3050a-6162-4c01-b6fd-8784bf4ab1e4",
100
- "metadata": {},
101
- "source": [
102
- "## 3. Initialize the MTLClassifier\n",
103
- "\n",
104
- "Now, let's create an instance of the MTLClassifier with our defined parameters and task columns."
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": null,
110
- "id": "e27caac9-670c-409d-9313-50201c665cb9",
111
- "metadata": {},
112
- "outputs": [],
113
- "source": [
114
- "mc = MTLClassifier(\n",
115
- " task_columns=task_columns, # Our defined classification tasks\n",
116
- " study_name=\"MTLClassifier_example\",\n",
117
- " pretrained_path=pretrained_path,\n",
118
- " train_path=train_path,\n",
119
- " val_path=val_path,\n",
120
- " test_path=test_path,\n",
121
- " model_save_path=model_save_path,\n",
122
- " results_dir=results_dir,\n",
123
- " tensorboard_log_dir=tensorboard_log_dir,\n",
124
- " hyperparameters=hyperparameters,\n",
125
- " n_trials=15, # Number of trials for hyperparameter optimization (at least 50 suggested)\n",
126
- " epochs=1, # Number of training epochs (1 suggested to prevent overfitting)\n",
127
- " batch_size=8, # Adjust based on available GPU memory\n",
128
- " seed=42\n",
129
- ")"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "0d729444-e3ad-4584-9659-0c464ac97462",
135
- "metadata": {},
136
- "source": [
137
- "## 4. Run Hyperparameter Optimization\n",
138
- "\n",
139
- "Now, let's run the Optuna study to optimize our hyperparameters for both classification tasks."
140
- ]
141
- },
142
- {
143
- "cell_type": "code",
144
- "execution_count": null,
145
- "id": "9298aa3e-6a52-4aa8-b9ff-b63d97beac93",
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "mc.run_optuna_study()"
150
- ]
151
- },
152
- {
153
- "cell_type": "markdown",
154
- "id": "af23075d-d07b-43d3-bc5d-4df4d5d7199b",
155
- "metadata": {},
156
- "source": [
157
- "## 5. Evaluate the Model on Test Data\n",
158
- "\n",
159
- "After optimization, we can evaluate our model on the test dataset. This will provide performance metrics for both classification tasks. CSV containing following keys will be generated in specified results directiory \"Cell ID, task(1...n) True,task(1.,.n) Pred,task(1...n) Probabilities\""
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "id": "461bf8d3-b964-4ff4-994f-9f3d313d4614",
166
- "metadata": {},
167
- "outputs": [],
168
- "source": [
169
- "mc.load_and_evaluate_test_model()"
170
- ]
171
- },
172
- {
173
- "cell_type": "markdown",
174
- "id": "31cfeb2d-6673-4b02-a79c-2533cc5e4d28",
175
- "metadata": {},
176
- "source": [
177
- "## 6. (Optional) Manual Hyperparameter Tuning\n",
178
- "\n",
179
- "If you prefer to set hyperparameters manually, you can use the following approach:"
180
- ]
181
- },
182
- {
183
- "cell_type": "code",
184
- "execution_count": null,
185
- "id": "8ee6b99f-42e9-4abf-a292-aa9047735e0e",
186
- "metadata": {},
187
- "outputs": [],
188
- "source": [
189
- "manual_hyperparameters = {\n",
190
- " \"learning_rate\": 0.001,\n",
191
- " \"warmup_ratio\": 0.01,\n",
192
- " \"weight_decay\": 0.1,\n",
193
- " \"dropout_rate\": 0.1,\n",
194
- " \"lr_scheduler_type\": \"cosine\",\n",
195
- " \"task_weights\": [1, 1], # Weights for each task (cell_type, disease_state)\n",
196
- " \"max_layers_to_freeze\": 2\n",
197
- "}\n",
198
- "\n",
199
- "mc_manual = MTLClassifier(\n",
200
- " task_columns=task_columns,\n",
201
- " study_name=\"mtl_manual\",\n",
202
- " pretrained_path=pretrained_path,\n",
203
- " train_path=train_path,\n",
204
- " val_path=val_path,\n",
205
- " test_path=test_path,\n",
206
- " model_save_path=model_save_path,\n",
207
- " results_dir=results_dir,\n",
208
- " tensorboard_log_dir=tensorboard_log_dir,\n",
209
- " manual_hyperparameters=manual_hyperparameters,\n",
210
- " use_manual_hyperparameters=True,\n",
211
- " epochs=10,\n",
212
- " batch_size=32,\n",
213
- " seed=42\n",
214
- ")\n",
215
- "\n",
216
- "mc_manual.run_manual_tuning()"
217
- ]
218
- },
219
- {
220
- "cell_type": "markdown",
221
- "id": "dbaac008-fc00-4b71-8e78-89b2d922d9d8",
222
- "metadata": {},
223
- "source": [
224
- "# Geneformer In Silico Perturber Tutorial (MTL Quantized)\n",
225
- "This demonstrates how to use the Geneformer In Silico Perturber with a Multi-Task Learning (MTL) model in a quantized configuration to optimize runtime and memory."
226
- ]
227
- },
228
- {
229
- "cell_type": "code",
230
- "execution_count": null,
231
- "id": "2e15ad57-736c-48f0-be87-39cf5015bc5c",
232
- "metadata": {},
233
- "outputs": [],
234
- "source": [
235
- "from geneformer import InSilicoPerturber, EmbExtractor, InSilicoPerturberStats"
236
- ]
237
- },
238
- {
239
- "cell_type": "code",
240
- "execution_count": null,
241
- "id": "43c18140-151e-4d44-95b4-a9b3a47172cf",
242
- "metadata": {},
243
- "outputs": [],
244
- "source": [
245
- "# Define paths\n",
246
- "model_directory = \"/path/to/model/save/path\"\n",
247
- "input_data_file = \"/path/to/input/data.dataset\"\n",
248
- "output_directory = \"/path/to/output/directory\"\n",
249
- "output_prefix = \"mtl_quantized_perturbation\"\n",
250
- "\n",
251
- "# Define parameters\n",
252
- "perturb_type = \"delete\" # or \"overexpress\"\n",
253
- "\n",
254
- "# Define cell states to model\n",
255
- "cell_states_to_model = {\n",
256
- " \"state_key\": \"disease_state\", \n",
257
- " \"start_state\": \"disease\", \n",
258
- " \"goal_state\": \"control\"\n",
259
- "}\n",
260
- "\n",
261
- "# Define filter data\n",
262
- "filter_data_dict = {\n",
263
- " \"cell_type\": [\"Fibroblast\"]\n",
264
- "}"
265
- ]
266
- },
267
- {
268
- "cell_type": "markdown",
269
- "id": "3010d0bf-b23c-45c1-ac12-8c472dc8b7a1",
270
- "metadata": {},
271
- "source": [
272
- "## 3. Extract State Embeddings\n",
273
- "\n",
274
- "Before we initialize the InSilicoPerturber, we need to extract the state embeddings using the EmbExtractor."
275
- ]
276
- },
277
- {
278
- "cell_type": "code",
279
- "execution_count": null,
280
- "id": "215f0a90-8041-417d-a5d3-b2483626c3b2",
281
- "metadata": {},
282
- "outputs": [],
283
- "source": [
284
- "# Initialize EmbExtractor\n",
285
- "embex = EmbExtractor(\n",
286
- " filter_data_dict=filter_data_dict,\n",
287
- " max_ncells=1000, # Number of cells to extract embeddings for\n",
288
- " emb_layer=0, # Use the second to last layer\n",
289
- " emb_mode = \"cls\",\n",
290
- " summary_stat=\"exact_mean\",\n",
291
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
292
- " nproc=4\n",
293
- ")\n",
294
- "\n",
295
- "# Extract state embeddings\n",
296
- "state_embs_dict = embex.get_state_embs(\n",
297
- " cell_states_to_model,\n",
298
- " model_directory=model_directory,\n",
299
- " input_data_file=input_data_file,\n",
300
- " output_directory=output_directory,\n",
301
- " output_prefix=output_prefix\n",
302
- ")"
303
- ]
304
- },
305
- {
306
- "cell_type": "markdown",
307
- "id": "23f14e36-4529-4fb2-8af9-7f4875cf81e3",
308
- "metadata": {},
309
- "source": [
310
- "## 4. Initialize the InSilicoPerturber\n",
311
- "\n",
312
- "Now that we have our state embeddings, let's create an instance of the InSilicoPerturber with MTL and quantized configurations."
313
- ]
314
- },
315
- {
316
- "cell_type": "code",
317
- "execution_count": null,
318
- "id": "09f985a1-91bc-4e8d-8001-a3663531b570",
319
- "metadata": {},
320
- "outputs": [],
321
- "source": [
322
- "# Initialize InSilicoPerturber\n",
323
- "isp = InSilicoPerturber(\n",
324
- " perturb_type=perturb_type,\n",
325
- " genes_to_perturb=\"all\", # Perturb all genes\n",
326
- " model_type=\"MTLCellClassifier-Quantized\", # Use quantized MTL model\n",
327
- " emb_mode=\"cls\", # Use CLS token embedding\n",
328
- " cell_states_to_model=cell_states_to_model,\n",
329
- " state_embs_dict=state_embs_dict,\n",
330
- " max_ncells=1000, # Number of cells to perturb (larger number increases power)\n",
331
- " emb_layer=0, \n",
332
- " forward_batch_size=8, # Adjust based on available GPU memory\n",
333
- " nproc=1\n",
334
- ")"
335
- ]
336
- },
337
- {
338
- "cell_type": "markdown",
339
- "id": "cfcc2c1e-fd7f-4a36-99fc-ac7f43e5be6b",
340
- "metadata": {},
341
- "source": [
342
- "## 5. Run In Silico Perturbation\n",
343
- "\n",
344
- "Run the in silico perturbation on the dataset."
345
- ]
346
- },
347
- {
348
- "cell_type": "code",
349
- "execution_count": null,
350
- "id": "cf030c09-8ae4-45a7-aaf7-3fc2af4fe296",
351
- "metadata": {},
352
- "outputs": [],
353
- "source": [
354
- "# Run perturbation and output intermediate files\n",
355
- "isp.perturb_data(\n",
356
- " model_directory=model_directory,\n",
357
- " input_data_file=input_data_file,\n",
358
- " output_directory=output_directory,\n",
359
- " output_prefix=output_prefix\n",
360
- ")"
361
- ]
362
- },
363
- {
364
- "cell_type": "markdown",
365
- "id": "bb8ec074-6f2f-422b-a973-37ed32a15c38",
366
- "metadata": {},
367
- "source": [
368
- "## 6. Process Results with InSilicoPerturberStats\n",
369
- "\n",
370
- "After running the perturbation, we'll use InSilicoPerturberStats to process the intermediate files and generate the final statistics."
371
- ]
372
- },
373
- {
374
- "cell_type": "code",
375
- "execution_count": null,
376
- "id": "0a748043-43fc-47ad-ace5-f0ae3dd34674",
377
- "metadata": {},
378
- "outputs": [],
379
- "source": [
380
- "# Initialize InSilicoPerturberStats\n",
381
- "ispstats = InSilicoPerturberStats(\n",
382
- " mode=\"goal_state_shift\",\n",
383
- " genes_perturbed=\"all\",\n",
384
- " combos=0,\n",
385
- " anchor_gene=None,\n",
386
- " cell_states_to_model=cell_states_to_model\n",
387
- ")\n",
388
- "\n",
389
- "# Process stats and output final .csv\n",
390
- "ispstats.get_stats(\n",
391
- " input_data_file,\n",
392
- " None,\n",
393
- " output_directory,\n",
394
- " output_prefix\n",
395
- ")"
396
- ]
397
- }
398
- ],
399
- "metadata": {
400
- "kernelspec": {
401
- "display_name": "Python 3 (ipykernel)",
402
- "language": "python",
403
- "name": "python3"
404
- },
405
- "language_info": {
406
- "codemirror_mode": {
407
- "name": "ipython",
408
- "version": 3
409
- },
410
- "file_extension": ".py",
411
- "mimetype": "text/x-python",
412
- "name": "python",
413
- "nbconvert_exporter": "python",
414
- "pygments_lexer": "ipython3",
415
- "version": "3.11.5"
416
- }
417
- },
418
- "nbformat": 4,
419
- "nbformat_minor": 5
420
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/pretraining_new_model/obtain_nonzero_median_digests.ipynb DELETED
@@ -1,365 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "charged-worcester",
6
- "metadata": {},
7
- "source": [
8
- "# Obtain non-zero median expression value of each gene across Genecorpus-30M"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "28e87f2a-a33e-4fe3-81af-ad4cd62fcc1b",
14
- "metadata": {},
15
- "source": [
16
- "#### Upon request, we are providing the code that we used for obtaining the non-zero median expression value of each gene across the broad range of cell types represented in Genecorpus-30M that we use as a normalization factor to prioritize genes that uniquely distinguish cell state.\n",
17
- "\n",
18
- "#### Please read the important information below before using this code.\n",
19
- "\n",
20
- "#### If using Geneformer, to ensure consistency of the normalization factor used for each gene for all future datasets, <ins>**users should use the Geneformer transcriptome tokenizer to tokenize their datasets and should not re-calculate this normalization factor for their individual dataset** </ins>. This code for re-calculating the normalization factor should only be used by users who are pretraining a new model from scratch with a new pretraining corpus other than Genecorpus-30M.\n",
21
- "\n",
22
- "#### It is critical that this calculation is performed on a large-scale pretraining corpus that has tens of millions of cells from a broad range of human tissues. <ins>**The richness of variable cell states in the pretraining corpus is what allows this normalization factor to accomplish the goal of prioritizing genes that uniquely distinguish cell states.** </ins> This normalization factor for each gene is calculated once from the large-scale pretraining corpus and is used for all future datasets presented to the model. \n",
23
- "\n",
24
- "#### Of note, as discussed in the Methods, we only included droplet-based sequencing platforms in the pretraining corpus to assure expression value unit comparability for the calculation of this normalization factor. Users wishing to pretrain a new model from scratch with a new pretraining corpus should choose either droplet-based or plate-based platforms for calculating this normalization factor, or they should exercise caution that including both platforms may cause unintended effects on the results. Once the normalization factor is calculated however, data from any platform can be used with the model because the expression value units will be consistent within each individual cell.\n",
25
- "\n",
26
- "#### Please see the Methods in the manuscript for a description of the procedure enacted by this code, an excerpt of which is below for convenience:\n",
27
- "\n",
28
- "#### \"To accomplish this, we first calculated the non-zero median value of expression of each detected gene across all cells passing quality filtering from the entire Genecorpus-30M. We aggregated the transcript count distribution for each gene in a memory-efficient manner by scanning through chunks of .loom data using loompy, normalizing the gene transcript counts in each cell by the total transcript count of that cell to account for varying sequencing depth and updating the normalized count distribution of the gene within the t-digest data structure developed for accurate online accumulation of rank-based statistics. We then normalized the genes in each single-cell transcriptome by the non-zero median value of expression of that gene across Genecorpus-30M and ordered the genes by the rank of their normalized expression in that specific cell. Of note, we opted to use the non-zero median value of expression rather than include zeros in the distribution so as not to weight the value by tissue representation within Genecorpus-30M, assuming that a representative range of transcript values would be observed within the cells in which each gene was detected. This normalization factor for each gene is calculated once from the pretraining corpus and is used for all future datasets presented to the model. The provided tokenizer code includes this normalization procedure and should be used for tokenizing new datasets presented to Geneformer to ensure consistency of the normalization factor used for each gene.\""
29
- ]
30
- },
31
- {
32
- "cell_type": "code",
33
- "execution_count": 1,
34
- "id": "textile-destruction",
35
- "metadata": {},
36
- "outputs": [],
37
- "source": [
38
- "import os\n",
39
- "import numpy as np\n",
40
- "import loompy as lp\n",
41
- "import pandas as pd\n",
42
- "import crick\n",
43
- "import pickle\n",
44
- "import math\n",
45
- "from tqdm.notebook import tqdm"
46
- ]
47
- },
48
- {
49
- "cell_type": "markdown",
50
- "id": "4af8cfef-05f2-47e0-b8d2-71ca025059c7",
51
- "metadata": {
52
- "tags": []
53
- },
54
- "source": [
55
- "### The following code is an example of how the nonzero median expression values are obtained for a single input file. This calculation should be run as a script to be parallelized for all dataset files."
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": 30,
61
- "id": "physical-intro",
62
- "metadata": {},
63
- "outputs": [],
64
- "source": [
65
- "input_file = \"study1.loom\"\n",
66
- "current_database = \"database1\"\n",
67
- "\n",
68
- "rootdir = f\"/path/to/{current_database}/data/\"\n",
69
- "output_file = input_file.replace(\".loom\", \".gene_median_digest_dict.pickle\")\n",
70
- "outdir = rootdir.replace(\"/data/\", \"/tdigest/\")\n",
71
- "\n",
72
- "with lp.connect(f\"{rootdir}{input_file}\") as data:\n",
73
- " # define coordinates of protein-coding or miRNA genes\n",
74
- " coding_miRNA_loc = np.where((data.ra.gene_type == \"protein_coding\") | (data.ra.gene_type == \"miRNA\"))[0]\n",
75
- " coding_miRNA_genes = data.ra[\"ensembl_id\"][coding_miRNA_loc]\n",
76
- " \n",
77
- " # initiate tdigests\n",
78
- " median_digests = [crick.tdigest.TDigest() for _ in range(len(coding_miRNA_loc))]\n",
79
- " \n",
80
- " # initiate progress meters\n",
81
- " progress = tqdm(total=len(coding_miRNA_loc))\n",
82
- " last_view_row = 0\n",
83
- " progress.update(0)\n",
84
- " \n",
85
- " for (ix, selection, view) in data.scan(items=coding_miRNA_loc, axis=0):\n",
86
- " # define coordinates of cells passing filter\n",
87
- " filter_passed_loc = np.where(view.ca.filter_pass == 1)[0]\n",
88
- " subview = view.view[:, filter_passed_loc]\n",
89
- " # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision\n",
90
- " subview_norm_array = subview[:,:]/subview.ca.n_counts*10_000\n",
91
- " # if integer, convert to float to prevent error with filling with nan\n",
92
- " if np.issubdtype(subview_norm_array.dtype, np.integer):\n",
93
- " subview_norm_array = subview_norm_array.astype(np.float32)\n",
94
- " # mask zeroes from distribution tdigest by filling with nan\n",
95
- " nonzero_data = np.ma.masked_equal(subview_norm_array, 0.0).filled(np.nan)\n",
96
- " # update tdigests\n",
97
- " [median_digests[i+last_view_row].update(nonzero_data[i,:]) for i in range(nonzero_data.shape[0])]\n",
98
- " # update progress meters\n",
99
- " progress.update(view.shape[0])\n",
100
- " last_view_row = last_view_row + view.shape[0]\n",
101
- " \n",
102
- "median_digest_dict = dict(zip(coding_miRNA_genes, median_digests))\n",
103
- "with open(f\"{outdir}{output_file}\", \"wb\") as fp:\n",
104
- " pickle.dump(median_digest_dict, fp)"
105
- ]
106
- },
107
- {
108
- "cell_type": "markdown",
109
- "id": "190a3754-aafa-4ccf-ba97-951c94ea3030",
110
- "metadata": {
111
- "tags": []
112
- },
113
- "source": [
114
- "### After the above code is run as a script in parallel for all datasets to obtain the nonzero median tdigests for their contained genes, the following code can be run to merge the tdigests across all datasets."
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": 2,
120
- "id": "distributed-riding",
121
- "metadata": {},
122
- "outputs": [],
123
- "source": [
124
- "# merge new tdigests into total tdigest dict\n",
125
- "def merge_digest(dict_key_ensembl_id, dict_value_tdigest, new_tdigest_dict):\n",
126
- " new_gene_tdigest = new_tdigest_dict.get(dict_key_ensembl_id)\n",
127
- " if new_gene_tdigest is not None:\n",
128
- " dict_value_tdigest.merge(new_gene_tdigest)\n",
129
- " return dict_value_tdigest\n",
130
- " elif new_gene_tdigest is None:\n",
131
- " return dict_value_tdigest"
132
- ]
133
- },
134
- {
135
- "cell_type": "code",
136
- "execution_count": null,
137
- "id": "distinct-library",
138
- "metadata": {},
139
- "outputs": [],
140
- "source": [
141
- "# use tdigest1.merge(tdigest2) to merge tdigest1, tdigest2, ...tdigestn\n",
142
- "# then, extract median by tdigest1.quantile(0.5)\n",
143
- "\n",
144
- "databases = [\"database1\", \"database2\", \"...databaseN\"]\n",
145
- "\n",
146
- "# obtain gene list\n",
147
- "gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
148
- "func_gene_list = [i for i in gene_info[(gene_info[\"gene_type\"] == \"protein_coding\") | (gene_info[\"gene_type\"] == \"miRNA\")][\"ensembl_id\"]]\n",
149
- "\n",
150
- "# initiate tdigests\n",
151
- "median_digests = [crick.tdigest.TDigest() for _ in range(len(func_gene_list))]\n",
152
- "total_tdigest_dict = dict(zip(func_gene_list, median_digests))\n",
153
- "\n",
154
- "# merge tdigests\n",
155
- "for current_database in databases:\n",
156
- " rootdir = f\"/path/to/{current_database}/tdigest/\"\n",
157
- " \n",
158
- " for subdir, dirs, files in os.walk(rootdir):\t\n",
159
- " for file in files:\n",
160
- " if file.endswith(\".gene_median_digest_dict.pickle\"):\n",
161
- " with open(f\"{rootdir}{file}\", \"rb\") as fp:\n",
162
- " tdigest_dict = pickle.load(fp)\n",
163
- " total_tdigest_dict = {k: merge_digest(k,v,tdigest_dict) for k, v in total_tdigest_dict.items()}\n",
164
- "\n",
165
- "# save dict of merged tdigests\n",
166
- "with open(f\"/path/to/total_gene_tdigest_dict.pickle\", \"wb\") as fp:\n",
167
- " pickle.dump(total_tdigest_dict, fp)\n",
168
- "\n",
169
- "# extract medians and save dict\n",
170
- "total_median_dict = {k: v.quantile(0.5) for k, v in total_tdigest_dict.items()}\n",
171
- "with open(f\"/path/to/total_gene_median_dict.pickle\", \"wb\") as fp:\n",
172
- " pickle.dump(total_median_dict, fp)\n",
173
- "\n",
174
- "# save dict of only detected genes' medians \n",
175
- "detected_median_dict = {k: v for k, v in total_median_dict.items() if not math.isnan(v)}\n",
176
- "with open(f\"/path/to/detected_gene_median_dict.pickle\", \"wb\") as fp:\n",
177
- " pickle.dump(detected_median_dict, fp)"
178
- ]
179
- },
180
- {
181
- "cell_type": "markdown",
182
- "id": "e8e17ad6-79ac-4f34-aa0c-1eaa1bace2e5",
183
- "metadata": {
184
- "tags": []
185
- },
186
- "source": [
187
- "### The below code displays some characteristics of the genes detected in the pretraining corpus."
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": 38,
193
- "id": "decent-switzerland",
194
- "metadata": {},
195
- "outputs": [],
196
- "source": [
197
- "gene_detection_counts_dict = {k: v.size() for k, v in total_tdigest_dict.items()}"
198
- ]
199
- },
200
- {
201
- "cell_type": "code",
202
- "execution_count": 44,
203
- "id": "polished-innocent",
204
- "metadata": {},
205
- "outputs": [
206
- {
207
- "name": "stderr",
208
- "output_type": "stream",
209
- "text": [
210
- "/home1/ct68/miniconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
211
- " warnings.warn(msg, FutureWarning)\n"
212
- ]
213
- },
214
- {
215
- "data": {
216
- "image/png": "iVBORw0KGgoAAAANSUhEUgAABR8AAAMRCAYAAABlG8GWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAABcSAAAXEgFnn9JSAAC/KUlEQVR4nOzdd5hjZ3X48e/Z7l2vK240G0wzBgOmmmp6NT/TQmjBlCS0ACGE3gklJARCLyGYGgi9hRqwgYBpxnRMMTZgsI1x2+Lt5/fHe8d7dUfSSBpdaWb2+3kePaN7dcs7M1ea0dF5z4nMRJIkSZIkSZLGbdm0ByBJkiRJkiRpaTL4KEmSJEmSJKkVBh8lSZIkSZIktcLgoyRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa0w+ChJkiRJkiSpFQYfJUmSJEmSJLXC4KMkSZIkSZKkVhh8lCRJkiRJktQKg4+SJEmSJEmSWrFi2gOQJEmTERFXAY4GDgf2A1YDG4BLgPOA0zPzT9Ma3yAi4mTgkbVVj8rMk/tsfwTwm9qqczLziDbGJqm3iDib8toz4xqZefZ0RqM9XUQcD3yltsq/DZLUIoOPkiQtYRFxLHAScF863/j32v4c4AvAe4GvZWa2OkANLCJOAe4wzmNmZozzeJImw9cDSdJi4rRrSZL6iIjlEXFpRGR1e/uA271m0mNtjOeGEfFF4HvA3zFA4LFyOPDXwKnAryLiERHh/wuSJEmSRuKbCUmS+rsJsE9t+Ss9trtZY7tT2hpQP1E8AzgduEufTRO4mDLtuld24zWBdwPfGOsgJUmSJO0xnHYtSVJ/zWltp/TY7k61+7uAr7Yymj6qDMX/AB7V5eE/AB8DPgt8F7gwM3dW+60GrgPcFrgf5XtZXtv3ei0OW6M7DXjXtAchaUHw9UCStGAZfJQkqb/ja/fPzMw/9NiuHnz8QWZe3N6QenoDswOPG4FXAK/JzMu77ZSZW4EfVbc3R8SRwHMptSKtAbZwnZmZb5n2IKRB2Myjdb4eSJIWLKddS5LUQ5VJeNvaqq5TriNiFXDrubZrU0Q8Gnh8Y/X5wO0y8+W9Ao/dZOavM/PRlO/prDEOU5IkSdIexuCjJEm93RjYr7bcK6h4K2BtbfmUdobTXURcGWg2uLkEuG1mnjHqcTPzNErNyy+MPDhJkiRJezSDj5Ik9bZY6j0+n85mNwBPycxfzffAmXkZ8JfzPY4kSZKkPZM1HyVJ6u342v2fZOYFPbarBx+/n5mXtjekThFxILPrPH41M989rnNk5q5R9ouI/SlZoYcAB1G6av8J+A1wWmZuG9cY2xQRhwI3Ao4A9gVWAZcDlwG/BX49jkCveouIFcAtgBsABwJbKE2UvtfWzz4i1lGu3+sA+1M+WPhjZg7U1CMiDgeOpVz7BwKbgQuAnwI/zMxeXeYHHV8ARwLHAIdRPoCI6jwXAedQ6gCeN+Lx96mOfR1KBvhaYCuwCTgXOBv4aWZun8/3MR8RsS9wG+DawN7ApZTr4quZeeGYzrEGuD1wdeBgys/gt8C3MvO34zjHYlS7/q5L+dnsQ0lsuRi4kPLc/E1L574ecDRwJeAAYAfld/8r4EeZ+acxnWcl5TXgBpTXgMspz+Fv+ZovScMx+ChJUhdVvcfb1Vad0mO7vYBbzrVdix4GrG6se9OEx3CFiFgOPBL4a+DmdHbNrtsYEZ8BXpyZP5vU+AZVBbseW91uOsD2FwFfAz4MvH/UgO1CEBF3BT5H5wyZl2TmC4c4xq2BU+n8X/M1mfm0Hts3A3HXyMyzI2It8EzgiZQAXrd9Twf+KTM/NsT4TgLeWVt1amYeXz12beBFwAOY/dyCPh2FI2I/4O8p2cLX6TOE8yLivcArMvOiQcddnWN/4OnAwylBn7m2/y3wReA9mXnqANvfHXgycDfmfq+wJSK+A3wMOLlfo62IOBs4vLbqGpl5dp/tXwTUr7l3ZeZJ1WPXBP4JeCCwssvuGRFfAp6VmafP8T30Ov9VgJdRroO9e2zzDcpr2Bd6jPnFmfmiUc6/EFXX3v2Be1MCsl2fk7XtzwXeAbx+vsHgqhHaPwInAFfus2lGxA+BjwLvyMxzRzjXeuA5wOPoLL1S3+ZnwAsy88PDHl+S9kROu5Yk7bEiInvdgJ2UTIcZT+yx3WZKJtyMf+hz3ONb+Dbu21i+iBIImLiIuC0lq+sdlGyRXoFHKG/mHwz8KCJeXmXRLAhVxtrpwJsZIPBYOQD4f8B7mD0FflHJzC8CL2msfl4VlJpTRFwJ+CCdgatvUoKIA4uIawDfA15A/yDHscBHI+JDVZbayKrGTT8CHkr3wGO/ff+W0qDpBfQPPAIcSgkgnhURDxziHHcDfkkJjMwZeKxcHXgM8O9zHHtNRHyQEni+F4MlKayhfEjzb8wuU9GKiHgI8GPgIXQPPELJAL0r8K2I+KsRz/FzygcpXQOPlVsDn4+If68+sFqyIuIY4DzgP4D7MUfgsXIVyvPhVxFxnxHPuyYi3kT5ffwt/QOPUH73NwJeDHx6hPPdiPIa8Cx6BB4rRwEfiog3L/XfvSSNgy+UkiQtUlV23m0aq78xjenMEfEI4H/pHnRJyhTljV0eWw48G/ivhfAGLiIOoGQw3rDHJpspUwo3TWxQ0/FS4PO15WXAeyPiqv12qn6H7wPq210I/MWQ03OvRLmertdYv5Ey9bGbBwIfj4ihgoYzIuKRlMB5c/9LgJ5jj4jlEfF64C10fmAxYydlKurWLo/tC/x3RPzdAOO7DfApugd9EthA+Vl3O88gPgz8RY/HtgJ/pjyPp5bVGxEPp1xfe9VW76L8fLu97q0ATo6IOw5xjkcC76V70HHmXDsb65/M7KZfS81aOj9oq9tOuT66vcZDuc4/GREPHuaEVTO1rwKPp3cw/DLKtd/1EEOe7waUxnKHNx66jN6v+Y+jBFglSX1M/Z98SZI0sqMomUd135n0IKqMlnfR+cb0YuBVwHHAmszcNzPXUzJJHgB8o3GYB1Ma50zby4Cr1ZYTeDdlCur+mbkuMw/KzL0p3+/RlAyskyn1LJeEatr4w4Hf1VZfiRIo65VtBuV3eLf6oYCHZ+bvhxzCG4BrVPd/TalreqXMXJ+ZaykZVU+m1F+ruzvwyiHPBSU7cKZcwS7KlOw7AKszc39KQPIISjZU0yuBJzXW/Yoy/fr6wMrMPCAz11TneQKlHuOMAF4bEXeiv7fR+Ry7DHg5pbzBuszcp7o211ACZzenBG0+Se+AbRlACQrdu7H6a5Tp41fJzDWZeaXM3JcSBDoCuA/wauDMOcY9LjegZN0Fpebnayh1QFdVP9/V1TZvoDNAGsB/VCUh+qqy3t5O53ukXcBbKdncqzPzAMrv4UaU17iZYO+TgXuM/N0tHpcDn6Fc87cB9svMVdX1sR5YR8kIfTWdwciZ38ORg5ykKmnyGcp1XHcxJTP7lpTfx76ZuQ/ld3ITyvPrS8wOEM9lL8qsgZk
PED5G+X2uq86xN+V15x8oH0jUPSci5sp2lqQ9Wsyz1rUkSYtWRDyux0PLgNexe9rwl4EPddluNfDa2vIngc/2OeUnM/MPQw6zp4i4H6WuVd39MvPj4zrHAGM4HPg+nRlfXwQekZnnz7Hv8ygZdjN2ATfLzO/32edkylTIGY/KzJP7bH8EpcHNjHMy84ge266iBLP2ra3+y8z8YK/jd9n/AcDHM7NvsGcUEXEKnVNbr6iB15aIuBUl86gecPz3zHxql23vQsmWrAduXpqZc2YFdan5OONTlN/B5h77HVidsz49fhdw28z8Zp/znURnzccZG4D7ZuYpc425Os59gU80Vv8r8Nx+GcgRsTclg69eNuEPwJGZuaXL9rcAvlVbdQlwq8wcKPBXZfTeOTO7vY4REf8D3LO26s3AEwdtihMRtwf+1K9+6xhqPs74DXCvzPx5n30fQfnQoO7/ZeYn++yzjDLN/8a11RuBe2fmV/vsdzQl2HVol4dbq/k46deDiLguJUD9jkGbqlV1Mz9JKY0w4z8z8zED7Hsyna/1AB+nvOZfMsD+R1Cey6/r8fjxlCzHps2UD0x6li+pMiT/j84SGz1r2kqSgMz05s2bN2/evNVuwM0oGVszt4f32O4Oje3uOeFxPqFx/gRuN+Ex/Gfj/F+lZKMMuv8bG/t/YI7tT25sf9Ic2x/R2P7sPtter7Htt6d9LTbGd0qX3/d8bicOeN6ndNn3AY1trkIJ3Na3+RKwbMBzdBvfDylZs3PteyXg/Ma+n5ljn5N6nPM+Q/w+llHqL9b3f9UQ+6+hBLvq+z+ux7aPG/U8A47lvNqxtwH7tnD9nt34Ho6YY/sXdfn9XAZce8DzfbKx77vn2P5eo14PlIDl9i77v2jcP8faOafyejDCOA+mTMmeOc8WShZ5v32OoXyIUB/fhwd9PRlwXMf3+Dn85YD7P62x32/b+l178+bN21K4Oe1akqTZmtMfv9xju+Nr93cCX29lNL2t77JuoIyUcajq/z28tmoH8NjMHKbm3HMpAYUZD6yy2abhgMbyr6YyigUmM/+d8sa/7j+jdIWeqT36AeCg2uN/AB6a8+v6/eTskgXYZXwXUq6juntUDWuG8enMHKZBxQOBa9WWfwU8b9Cdq+/tHxure2Vjt31t1o9/YQ6Y2TYFr8zMXw647dsay83pu03Nn/0nBr0eMvMMdk/bV01mXkCppTpjNWVadj/PprNe4x8pf1varjf6xcz8wIDbvpPyN2/G1SLikBbGJElLgsFHSZJmqwcfz8zeU6WPr93/Xmb2Knrflm6NNSbZCOVBdE7H/Vxm/mKYA2SZPve52qrllO6503BJY/kmC6EJzgLxaKD+u90H+HBVl+2VwG1rj+2gZA816zEO46c54NTnynvpDGIvY3YNw7k0g1VzeVhj+S05ZLOnzPwyJetwxjHVFOmmSxrLg3ZhH1T9+IfM1VhoSnYx3O+oWVf2Or2ez1UA/c6N1cMGE9885PZ7ktMay7fqtWFVU/a+jdWvywGmWo/BwL/DzLwYaJYZaDbIkiRVenUNkyRpj1S98akHUrpmPVYddetvoE5pcVi9dMswXDfB89+hsfy5rlvN7Xt0dtk9jlLba9LOpDQzmKlfeT3grRHxtCkElgdxGqXRz6jOGHTDzNwQEQ+k1B2c6TR8DKUj9XGNzZ+bmV+bx7hgdh3Fuca3JSK+QMlGnHErSvORgQ4BnDro+aogVjNIPur1/31211sMSiONZu3YZvDmMRFxBvDWMWWDnQacUN1fRgks/2X2qck4BT+uslwHkpkXRcSl7K7huoySLd4tq/MYSjfnGVvonfHe63w/j4izgGsOs98YTez1oK7qSH09SjOx9ZRyAs0u081mLFejt1vS+buA8uHCJAz8GlA5C7hhbXm/8Q1FkpYWg4+SJHW6BZ0BvK/02O5WdHaaPqWtAfWxscu6fbusa0sze+V6fZr49HNMY/mwEcczL5m5MyLeSmdH48cCD4qIj1A6r34tMxdKV+szM/MtkzpZZv4oIh5Pqbs5oxl4/BTwL2M43ekj7lMPPt5oiH3PyczL5t7sCtehs8kSwJ0iYpSs3Ss1lmdd/5l5ekScxu7n3HJKZt4zIuK/KYHPb2WPxjwDeCO7g49QAkC/jIjPUgLBp2Tmr0c89ricPcI+G+h8TdyH7sHHZsbajzNzR5ft5vJ9phd8nNjrQUTcjZL5e29glDIZzedOXTOr95zM/P0I5xjWZZl50ZD7ND+U2qfrVpIkg4+SJDXUp1wnvYOPx9fuT6PeI5Q6WE3dpmyOXZX5dVBj9ZPGdPhp1XwEeAlwezprku1LmXb8aICI+AWl0+nXgC9n5jmTHuS0ZOa7qgBbt261ZwOPzMwcw6lG+Zk29xnmOvrzkOfq1tm4a1fdEfQa9yMoU4nrz7sjgGdUtx0R8X3KtflVSsDw4kFOmJmfj4jXAH9fW72CEpA8ASAizqvO/zXg1OzTlb4ll4ywz87G8vIe2zWDYb1Kbcyl22vykhER16JMfb/jPA/VrV7xjObflUnV3r1khH0Gvb4kaY9n8FGStEeJiJtRuln3Us+cuojSAKXbdvev3f8T8LAe2/0hMz857DgH1C0T6RiGnLI6ov1pr3Z0c8rdxGTm5RFxZ+AVlG7iq7psdp3q9iiAiPg28HbgXZm5fVJjnaLnU7pFN99oP2bQYNcAhslCnNHMaOuXXdXULYu4nzYD5F2v/8z8VUTclFKXrls9yxWUpio3B54KbK+mor82M78010kz82kR8XPgZczOxoQScL1/dSMizgbeTanHN2zwdhTjCGr3sl9jedQyC6Nct4tCRNyA0sF+HE1V+v3taD63LhnD+QbR5vUlSXs8g4+SpD3NfYAXDrjtgQxWgP7QPtudCrQVfPwppe5jvfHMXB1dx6VbUG5cukZxJ6XqQvz3EfFq4K+AE4Fj6Z3Vcovq9oyIeEhmfm8iA52CKBH2t9D9Z/F4hqyT18cogYBJXjdTuf4z83fAfSLiWMq1eS/g2j02X0kJUt47Ij5HyUrt2wQoM98WEe8HHkxpKHVbeteRPQJ4AfDUiHhiZk6qLl8bmvVzR/39tnldTE1VC/kDzA48/gD4KPAdSubxecDlwNZ6LdKIOJ7eswjmYlBQkpYAg4+SJC1Smbk9Ir5B5xS420TEqmG77o6gW6bTUZn585bPOzFVnbGXAy+PiPWUenvHAbehBGWaGWrXBr4cEbfNzB9NdLCT84/M7kQ744ER8ZTM/PcxnGeU2qXNemvjysLspnn9n5+Z3aZityIzT6fUuHxqRBxGKRNwa8p1eVNmB4fvAXwpIm6dmX2zPKvH3wG8owo63YRy3d+WUpLg4MYu+wDviYgVmXnyvL6x6WleK/uNeJxR91voHgYcXVveATxqiIDz3kOcq9lUaJgMZknSAtXWdClJkjQZzazKA4D7tX3SKrjZnGJ4rbbPOy2ZuSEzv5iZL8nMu1N+zicAn29sug+Dd1heVKpajy9rrD6rsfwvEXHLMZzu8DHs0+ZU4GbToUOqAPXEZeYfM/MjmfkPmXlLSn
Dwr4GfNTa9ISV4PMyxt2fmtzPz3zPzQZQs71tS6v41G7K8NiIWa6CoWavxqBGPM+p+C939G8uvHDLTtVnHsZ/mc2vJ/l2RpD2JwUdJ0h4lM1+UmdHtBvxPbdMz+mz31dp2p/Xarrod3/K39D6gmeX4hJbPOaPZcOL4CZ136jJza2Z+OjPvQWd3bIDbR8TVpzGutkTEwZRpl/VZM6dSaozWO1OvBP47Iubb+OjYMezzg3mOoZ+fAVsa6+7Q4vkGlpkXZeZ/ULp9f6rx8CPmeeysgpF/S8m4rgcg96WzY/Zi8t3G8lUj4irDHCAiVgM3HtuIFpYbN5bfPeT+txhi2+bv4vCIuOqQ55MkLTAGHyVJAiJiOXC72qpTemy3hpL5M2PUOlZjkZl/At7VWH37iPircZ2j6mzdzRcbyw+IiD2xpMurmJ05daNpDKQN1e//fcCVa6vPBx6SmZsoTZouqT12deDd0aMD04D+35BjXAPcrbH6tHmcv6+qLmizw/2D2zrfKKrmR89orL7GuDI0M/PrwEcaqxfldV/Vwmx2VX7YkIe5H73rYy52zan2A3ejr/623muIc30H2NRY9/Ah9pckLUAGHyVJKm4O1N+U9woqHkdng5epBh8rL2Z2d9Z/j4h5T1eLiH2A/+rx8EeAXbXlI4DHzPeci01mJrPfjC+lIMQLgbvUlncBD83MPwJk5m+oOn/X3Bt45jzOef2IGCaT8OF01nzcBXxmHucfxH83lh8SEddv+ZzD+k2XdeO8NpvHX8zX/fsay0+NiIFqj1Yfujx3/ENaMJrZ9fsNse9DKR9IDKQKmn+8sfrvBv1dSJIWJoOPkiQV9aYtu+icWt1ru+3A/7U2ogFl5rnAPzRW7wd8PSJGzkSKiFtRptTevcd5f06Zilv3rxFxk3mcc2qdrkfN2qyacjQDvefNf0TTFxF3A57XWP2izOzoap2ZHwde3djunyLi9vM4/eurqaxzjfFKzK5F+fkqKNqmk4Gza8vLgQ9FxH6jHrDX9T+PjOJmMHQnjZp688xWbh5/MV/3b6O8ps84DPiPKnNvLv8C3KCVUS0Mv28sDzS9vio/8doRzvdKOrtcX5nSAMn3rpK0SPkCLklSUQ8qnpGZl/TY7vja/W9l5ubWRjSEzHw75c1z3SHA1yLi2RGx16DHiohrRsQ7KIHVI+fY/PnApbXlvYH/jYgHDHq+6pyHRcQLmV2jbpKeEBGfjYh7DPkm9xXAlWrLGylTBxe1qs7a++j8f/HzzA70zXgWncH45cAHqnqRo7ghJZjX89qNiAOBz9E5LTT7jHFsqgytpzdWX58S9B8qEBURx0TE2ylZzN28OyLeHhHHDHHMdcwO/HwtM3c21t0gIn4YEY+p9hn0+P8PuE9j9ULIBB9JZv6BEvSqeyDwyYi4Wrd9IuLAiDgZeGq1qlkHdKn4cmP5ZRHR929DlQX8VUpzrqFk5o+BdzZWPwD48BDZqEdExJOHPbckqR17Yl0mSZI6RMQq4Da1Vaf02G4vFlC9xy6eAOxFZ1OJ9cDLgSdFxEeBzwLfAy6cCUJU2WXXpvwM7g/cmRI4mlNmnhURD6ZMcZ3ZZ3/Km8TTgP+gvAH9dWbuqs4XlGDRDYGbUrJojqMEub430nc+HsuAe1S3P0XEJyi/4zOAX1UdvgGIiEOA2wNPqr7Wvb2qhdi260bE4+Z5jK9k5pnNlVU23AfpDKr+Hnj4zO+xKTN3VNfC99nd3fYw4P0Rcbde+/XwLcpz7QTgRxHxT8AnM/OianyHUQJDz2N2Pbo3ZOZEMpIz8yPV2OrZoUcDZ0TExyglC76RmVdkBFaZdIdTmvUcR6lved3q4Wb26Iy1wEOAx0bEmcDHgG9Srs3f155by4BrUK7hpwHXbBznNT2Of0PKc/X1EfF5SpD5dOAn9Wu5qhd5c+CvKK8z9cD09+idMb5YvBS4J3Cz2rp7Ab+KiC9RmqFcSHmNuxElK3wmYPt7SimKp9T2rWfvta211wPgLcDj2f37PgT4bkS8DPjvzPwtXPG6cTPKVOu/BVZV25/C8A3JnkRpInXj2rr7AXeIiNcDnwZ+UH0IMJOBfn3Kc+r+wJ2AHwOvG/K8kqQWGHyUJKkEOdbWlnsFFW/N7jdT/babiszcGRGPBH4OvITOAOKVKW/mnjSzeURcXG2znv6zIZpdrZvn/XwVdHonnXUzb1XdAHZFxKXVeeY630JwEPDY6gZARGwGNlOCDb2y8b7L5Gq/1X++o3oU0C3Y8M+U633GDuDBmXlhv4Nl5rkR8TBKNuLM7/jOlLqRLxxiXE+i1FS8BiX79p0AEbGBcs2u7bHfl5jdZKVtL6CM6VnAzLTp5ZTg6AMBImIHJUN4DfOvi3hdOjusZ/Vz2U6pe7myx35vzMxPznHsvYATqxsAEbENuIxS67ZXs5o/UwLTkwy2jV1mbo+IuwNfoHwwMmMVJQjZq3HKRZRA+f0a6yeZCdna60Fm/jgiXkNneY/9KNPN/yUiNgFbKUHZZumAzwP/ypDBx8y8PCLuTcmGr3eyP4Da60n1dyUo1+bUynZIkvpb6P/4S5I0CfUp1zuBrw2w3VZK5tGCksXLKW+c+wVHg/Imbl96/z/wE+ABmXnHHo/Xz/sR4Bb07jC8jPLGtN/5dgE/mOtcLZorcLKWkgnYK/D438AdM/PysY5qwiLiRErWXN2zMvMbg+yfmV+kZJDVPa+qHzmoCylBy2YgZD29A48fA+5bdaKemOo59xxKBuNve2y2AjiQ/oHHzXQPBEP/azMoQccD6R543Ao8PzOf1OWxuY4NJfB2JXoHHs8AjqtqwC56VXbtHSkZc80p6t18i/L9n0Fn0yPo7AK/2D0DeEePx9ZR/p40g38fomQh7hjlhNVU+NtRPnzolTm9L+Xn3i3wOEy2tSSpRQYfJUmaXe/x0h7bHV+7f9qkgxzDyMwfZOadKFMk3wj8bsBdzwbeRHkzfYPM/OgQ5/x5Zh5HCRp9nM5akL1spmTGPB04PDOn2S37DcBtKTUcv0kJ2sxlM+UN9h0y88GZubHF8bUuIq5JaaRS94nM7DUduJeXULIQZywD3hsRVxn0AFXDmGOrY/25z6Y/AB6YmfefZuA3Mz9FaTz0KODrzO4Q3M0FlOntJwGHVrVbu3kYJbj5VuCnDDad93zKNX39zPynPuP+AXAU8I+U5+IlAxx7F2Uq7SOBm2bmLwfYZ9HIzA2Z+RTKNN4XUj5U+QMlu3QTJbv8XZRMyOMy8xfVrs0SABdPZsTty8xdmflY4EHM/SHRtykfXP3FfOsiZ+bmzHw0pTTAe+j/WgDl2vwW5Xoe5gMPSVKLYpHPjpAkSQOqGojcALg6ZcrcKkpzlIuBPwLfy8y53tgNc77llHpd16JkZe1PCchspLyRP5NSC3J7r2NMU1UL9HqUab9XpmR+LaeM/8+UINBPMnOQIKX6iIjmP6TXyMyza4+voGTW3pByLW2hXEPfW6iBr1qN2KtSxryeEqy+DDiHcv3/bpSpylXTjetRajoeTMk8S2AD5bn8I0qd0qEzv6qarNeqblejZJatr
sZ+KfAL4Id9PqTZY0XELyk/txk3rJqnLDkRcS3K9X0oJRt8I+W6/nZmntvieZdR/q5ch5KRux9wOeXv2C+BH/VpGCdJmhKDj5IkSZqquYKP0kJXdTj/UW3VRmDfUQLAkiQtNU67liRJkqT5eX5j+csGHiVJKgw+SpIkSRIQEatH2OepwF80Vr9pLAOSJGkJMPgoSZIkScXLI+JjEXGPqu5rTxFxREScDLym8dC3gS+0NUBJkhabFdMegCRJkiQtEMuBE6vbhoj4FqWW4/mUOo57A4dRmq3cvNq+bgPw0FEaCUmStFQZfJQkSZKk2dYDd6lugzgPuH9m/rq9IUmStPg47VqSJEmSirOArUPusx14F3CzzPzm+IckSdLiFs4IkCRJ0jRFRPMf0mtk5tnTGIsUEeuBuwG3Bm4EHA4cBKwFErgE+DPwQ+BrwCcz83dTGawkSYuAwUdJkiRJkiRJrXDatSRJkiRJkqRWGHyUJEmSJEmS1AqDj5IkSZIkSZJaYfBRkiRJkiRJUisMPkqSJEmSJElqhcFHSZIkSZIkSa1YMe0BSJKk3iLiFOAOtVV3zMxTpjOa6YqIVcAxwDWBw4B1wE7gEuBi4EzgJ5m5Y1pjnISIyPpyZsYc258MPLK26lGZefL4RyZNX0QcD3ylturUzDx+KoPRghAR1wKuC1wN2AdYBWxk99+OnwC/zszsdQxJ0vwYfJSkJaRLkKGXncBllH+8fwl8G/hcZv5fa4OTRhARa4GHAX8B3A5YPccul0fEd4H/Bj6YmX9qeYiSFrkh/3ZuAC4FzgFOB04FPpOZ20c894uAF3Z56NOZecKIx2wG0e6amV8a5Vi1Y34CuG9j9ecy857zOW4bImIZcB/gwcA9gAMG2G1DRJwOfBL4SGae0+IQJWmP47RrSdozLQf2B64B3A14HvD1iPhBRNx7qiNrSUTcOCJeVLudNO0xqbeIWBkRTwd+B7wNuAtzBx4B9qIEKV8P/CEi3hURh7c3Ukl7kOXAfsDhwO2BpwIfA34fEU+PiOVjPNd9IuLWYzzeyCLiYOBeXR66W0RcZdLj6SciHgL8CvgE8FAGCzwCrKfMMng1cHZEfDUi7t7OKCVpz2PwUZJUdwzw6Yh47bQH0oIbU7JLZm4nTXMw6i0ijqRk4/4L/d84bgH+DGzt8fgK4K+AMyPi/411kJK028GU16tTI2L9GI/7ijEeaz4eQfcZc8sor7FTFxEHRcRngfdTPljtJSmzPi6lZLL2cjvgcxHxubENUpL2YE67lqSl7ZfAv3VZv4IS1LkJJaNs78bjT4mIyzLzBS2PT+oQETcDPs/soGMCXwI+U339bWZuqO13KOV6vitlqt2Va/uupv+bUUmq6/W3cybz8SjgzsChjcdvA3w8Iu6ambvGMI7bR8Q9MnPaAbCT+jz2KKYcJK0+sPoi3V/nzwA+DXwZ+DFwUWburPZbARwB3ILyv9D9KL/fuhu0MWZJ2tMYfJSkpe0PmfmWfhtExAGUNw5/03joeRHxscz8fmuj05z2pEYJEXFd4AuUkgB1XwWenpnf6bVvZp4HfBb4bEQ8A/hL4KWUN5Z7tMw8CTN9tYeoGnL1bcI0gEH+dq4CHgf8M7Cm9tCdgIcD757nGGa8PCI+P61mKBFxczoDcDPjmPkZXzsibjOtmtHVB09fBq7eeOhHwHMz81O99q2ak/2qur0/Ip5AyfJ8Nn5gJUlj5bRrSdrDZeZFmfm3lCljdUH3IvjS2EXEXpQaXc3A45spHb57Bh6bMnNHZr6Xkp30Gna/WZakscjMbZn5OkpdwaZnz+PQ59P5mnUT4EHzON58Paqx/GU6u4l322YiqsYyH2Z24PEjwC36BR67ycwtmfl2yt+O5wHbxjJQSZLBR0nSFZ4D/Kax7h5VUEhq20uB6zbWvTEznzDq9MXqjeTTgAcCm+c7QElqysyPUbKu664XEc2A2KB+BXygse6lY25mM5CIWAM8pLH6XdWt7i8iYt1kRtXhKZSp7nX/DfxFZm4Z9aCZuTUzXwbcitn/F0mSRmDwUZIEXDH96B2N1auBBdFtU0tX9Sb97xqrfwY8fRzHz8yPAv8xjmNJUhf/3WXdLedxvBcAO2rL12E62YUn0lkDcSPwUUq24Yba+vXAAyY2KiAi9qVkJ9adCzxuTPU2qcrO3GMcx5KkPZ01HyVJdd/osu7wUQ4UEYdQ3nxdg/LGZAvwg8z84gD7XplSAP5g4EBgE/An4BfA6dOqfTWXqvbUzSnjPojShflPlOYF350pcj9tVTbrbYDrAftS3kSeD/xfZv5+CkN6CrCqse6J88lcaRrlzegkr8OIuAbl2rkysBa4CPgpcFpmLripfxGxkpIVdAPKVPnLgQuAb2Xmr8Z0jqD8/K8DHFatPh84IzN/MI5zjFPV5fjWlN/hQZROun8Cfgd8c5zXc+O8y4GbAkdTrtUVlC7wH8nMP7Vxzi5jWEv53q9LCVZdDpwFfC0z/zzA/uuB46r996F0Iz4H+Epmbmpn1GP1oy7rDh71YJn5q4h4B/C3tdUvjIj3tnUd9dAMeH545vcRER9uPP4oxlfnchCPZnZjsmdk5sXjPMl8r78qI/Q4dr8uLKe8Lvye8je3laz8iDiI8nf+msBelL8pvwO+mpmXjfE8AdyI8jp9EOV/iospfw++nZm/G9e5JC1ymenNmzdv3pbIDTiZUitq5nbKkPsf1dg/Kf/Mz3WeF9UeuytwCrCry7F6jgdYCTwJ+GGX/eq3C4C3AFcb4Ps5fo5jzXU7YoBz7AU8jdJRs9v3PHP7M/C2QcbdOP4pjeMcP8f2J/X6mQOHUGoobuozztOAO03wmp15I1Yfw0+n+Bwa+3U4x/nuDXy7z3kuBV4N7Ffbp2ObAc5xcmOfk4Z83pxde2w9pUHVxX3G/FPggfP4mewFvAj4Q59znAU8AVjWY8ynjHr+EcZ7L0oNvG19xrsZ+CRwyyGPfUSv3zflTf4rgQt7nPP4MX1//a6HQ6rnweYeY9hKyTq+Uo9jX4MSsLq8x/6XA68F9p3HeOe8Fro8R4a6foBrdxn7cwfc90WN/b5erb9yl5/r0wY8ZnMsdxnh935VSgC96zUF3KHx2C7gmhN83jVfoy8AVk3q/AOM736U+phb+7wubAE+M8LrQvN6fVHtsZtQuns3f3czt+2UTN1rz/P7uxZltsz5fb6/mb8HTwBWTvt34s2bt+nenHYtSarr1iE0B9oxYkVEvJnSrfgOPY7Va99bAmcCrwduOMfmB1GyQX4REc8a9BxtiIgTKfW5Xk355L/f93wA8NeUcTenGLcuIu5EeRPwOEpmXS+3BP43Ip47kYGVjJArNdY1p/9PxCSvw4hYHRHvo7xJvHmfTfehBLd/FBE3GvY841Sd/0fAs+icitl0FPChiHhz1RBimHMcA/yE0uzqsD6bXgN4I/CViGhmP01ERBwcEV+mBA+OpwSue9kLOAE4LSLeX2UKzufct6A8n59JycqduIi4QzWGv6V8f92sAh4DfCcirtXY/4GUANIj6OwWXbeGkhn9zSqzfKHap8u6
eWXMZeYfgDc0Vj+7yhKdhJPoLNF1DnBqbfmrdNZDDOCR7Q8LIuJwZr9GvzcXQJZ4RBwVEd+lTE+/I7Oz+utWUz68OC0i/rPqoD6fc/8D8B3Kh1q9XntXUBoYfT8i7jbCOVZFxOsppVEezdwZvkdRXqt/EhHXH/Z8kpYOg4/SFETEdSLiARHxxIh4dkQ8NiJOiIjrD/tGTRqzbm/uLhxw37dRAlt1OykZUj2nvEbECZTsgGv02OQSOmtfzVgDvCIi/mMaz5sq4PRRSnZK0y7KuLtNp1oDvC4imt3FW1MFHv+H2VPULqFkXnTzTxFxUovDmnGHLutOmcB5O0zyOqzeYH6E7l1yoWR7bWysuyrwpYg4cpBzjFtE3ICS3Xd446HL6B1keRyldt2g57gx/X8Hl1KyiOpuT7m2ewWvWhER16ZkCd+xxyYbKb/Hbh5CCZo2g+6DnvuGwBeZ/dqziXkGvIYYwy2Y/ZqyizK1s9vz5AjgMzNB1yrw+AFg79o2/f5eHAV8bAH/j9Ttg4HfdFk3rFdSrvsZVwL+YQzHHcRJjeV3Z+YVH0ZW95vTrB9ZTcNtW7e/G1+dwHn7qoJ536SUQehmI52/z7pHAV8c9YOJ6gPDf6XMJpixk97PyXXAJyLiekOc40DgS5TZAd3Kt23rc75rA9+IiGaDIEl7iIX6B1wLUEQsi4ijI+KREfH6iPhmRGyOiKzdjp/2OBeqiFgZEX8fET+iZNZ8mPKJ9suBt1OmY/0E+HNE/JcdhjUl3ZrLnDPAfg9kd+2nDcCLKbXgVmXmAZTAwE1ovFGJiKMob0Cb/2x/ilLkfU1m7k/JHDiK0hG5+Yb+McCze4zrF8Djq1vzTdIva4/1unWtVRYRT6RMPa2/yfoD8Hzg2Or73j8z11GyAv4K+HHjME+PiElkiRxKmWK1mvJG5D8pb9xWV2PcCziS8rNtBiJfExH7tzy+YxvLWynZUBMzgeuw6RWUzJS631OCdYdm5trMXE/JaHsk5W8GlMDDewc8xzjtBXyMUtuR6v49gHWZuW9m7g1chRIUuaSx73Mi4jpznaCqi/ZRZmfxfZGSMbguM/fLzDXA1YEnA+dV29ySkik5EdXf548zO0j6Y0oW3wGZuT4z11KyNx9PqbVWdwvgfSMGav6L3Zl2pwL3p0xL3rv6XRxICRz9cYRjD2Iv4IOU58sOyrTrW1Je9w6kvN7fjhKsrrsO8Mwq2HEyJUiymRJgO4YyLXPm78U9gWZdz1tRMq0Wogc3lndSgtPzkpkXUQJKdU8bNXA9qIi4PeXvQl23eo7vonN2xOHAndoaV82Nu6z77gTO21NEHEv5X37f2urNlOfH8ZTXsPWZuR8l8HcPStZ03e0pWYLDuhvl7xKU4OaL2P2cOpDyt+vmlNeOujXAWwc5QVXn95OU53bdqcDDgKtk5ura+W5EeY9Tb0y0L/Dhqia4pD3NtOd9e1scN0qGxkb61/QYW32hpXajfAL64wF+fvVb1/pI3rz1uzGPulWUT7HPauy/BdhrgPPUa/tcdcDzLaPUSazvvxN49Bz7XZsSEG3WMLrZHPudNOrPpnGcmzG7htN7gfUD/Hzf0thvE3DYHPudMszrbJfvc+b2J+A2c+x7x+pnWd/vyS1fsz9onO+MCT9nJn0d3orZtbi+0O/6oQSO39fj95oDfI/N5+tJc2x/fI9zbQLuN8e+N6C8+a3v928DjPHfu5yvb307yhvZr/UY6yktXjOv73K+t9GnphmlVubnu+z31DnOdUSv3zvwDxN6jvS6Hi4GbttnvxWU4Ep9nwuBr1f3zwau12f/vYHvN/b//gjjnfNa6PIcGfj6oXz41vzZfHqI/V/U2PfrXX4Ozbp6fZ9TXcYzVM1H4J39xtTY9tTGtu+dwDX5iea1OInnQp/x7Mvs/5++Bxw5wL6PogTx6/ved8jrtX7Ouf6neH6X/Y4ZYJyvbuyzGXjoAPtdi/IBWn3fj07z9+XNm7fp3Mx81KBuSvmUTkOKiOMo08iOrq3+HSUI8Q+Ufzr+jvLG6zT6TE+VWvYyZmfyfDYze00dbLoUuGsO3i35RGZPVXtmZv5nv50y85fAXeic0rwCmFSNwlfRWcPpg8AjMnNDj+0ByMwdlAyoT9dWr6XUM2vbDkrQ6P/6bZSZX6E0pKl7YGujKpp1/ebsjDtmJzLZ6/D5dM48+TlwYr/rJzO3UjIgpz2t8DGZ+bF+G2TmjymZz3V9r6GqZuPfNFb/a2b+2xznupRSL+3sftuNU1V3sDnWTwF/m5nbe+1X/X5PZHZW77MiYvUIQ/nXzHz1CPuN00Mz8+u9Hqxe855A5/81B1I68G4FTsjMn/fZfyOzXx9vPK3SA01V7bunMDsbeRuDZ0HPqfo5vKyx+gkRcbVxnaMuIvZm9nP2XX12ObmxfP+I2Hesg5qtWXLg4pbPN5en0Pn/05mUgO+v59oxM9/J7OvlOSOM4ffA3TNzroznlzE7q3iu1+hr0vlcTOBBmfn+uQaVmb+ivE7XO2yfWM04kLQHMfioUWylFDN+C9OZ/rVoVMXVP8/u6VGXUQqzH5GZj8/Mf8vMkzPzDZn51Mw8jvIP1XMp/7xKrYuI/SPiLcAzGg8ls4MI/bw0M88dYvvmm8ofAK8ZZMcq8PNPjdX3jYheteLGoqpzVq/xdjHwxMzMQfavtnsanW/G/3oCdcze2S9I0PC2xvKxLY+v2TyhVz2stkzsOoyIIyhT7er+LjO71QZtnmsmeD2tD6i+mJkfGHDbd9JZ8+tqc0yzO4nOmo1/ZMBp1FVQ72kDjmscHk/nhw+XA08Y5DWg+iCnWRf3EOAvhxzDnxiilmZLPpmZn51ro8w8h5Lp2PTmzPzRAPt/FfhtY/XNBhvivFw5Ih7X5fbEiHhORLynGtdrKZnJM3ZSgvRzfm9DegudJVBW016pgQfRWYtzC6V0Ry8forPW6F7MnoY+bns3li9p+Xw9VTUam03kHp+ZwwREX0PJnJxxy6oG7jCemZlz1ujOzF2UDvR1/ZqeATydzlqS78nM5pTxfuf8NSXJYkZQ3g9J2oMYfNSg3k35pP+mlKlht8jMxwP/O91hLVxVHad3sPuN9Qbgbpn5tuoPf1eZeX5mvjwzL+u1jTSEfm+gnh8RH6W8ger2T+DLMvOMAc+zndnZDz1FxD7AbRurX5+ZOwc9BiVDr16jcBmzAzvj9rDG8vsyc6hMvSpgVa9NdQClNlObmtmMPVWZa/XXn3VAKxk2lWbW10QaZsBUrsP70Pm/15mZ+aVBT5SZP6Vk0k/DMNfQxZROqHX9mhrcvbF88iAB2ZpPAsN88DEf92osf2SIbG8y85vAt+Y45lzeM0RGelvePsS23+myrhn86KdZy2/gBhnzcG3KNd+8vYGSOfZwSuC47kxK9v/YP5jP0sX5RY3VJ0XEdcd9LnbXb57x8SrLuKsqM/Ojcxxj3Ob9dyMift+oW9/vdkqfQ92dUo93xo+rWQQDqz5c+nBj9fFDHOIi+geIm77RWO75nKo+fGx+QPK6Ic41o1lv8vgRjiFpETP4qIFk5gsy8+2
...[base64 PNG data omitted; rendered figure: distribution plot, '# Cells Expressing Each Protein-Coding or miRNA Gene']\n",
217
- "text/plain": [
218
- "<Figure size 1500x750 with 1 Axes>"
219
- ]
220
- },
221
- "metadata": {
222
- "needs_background": "light"
223
- },
224
- "output_type": "display_data"
225
- }
226
- ],
227
- "source": [
228
- "gene_detection_counts = [i for i in gene_detection_counts_dict.values()]\n",
229
- "import seaborn as sns\n",
230
- "import matplotlib.pyplot as plt\n",
231
- "plt.figure(figsize=(10,5), dpi=150)\n",
232
- "plt.rcParams.update({'font.size': 18})\n",
233
- "count_plot = sns.distplot(gene_detection_counts).set_title(f\"# Cells Expressing Each\\nProtein-Coding or miRNA Gene\")"
234
- ]
235
- },
236
- {
237
- "cell_type": "code",
238
- "execution_count": 47,
239
- "id": "missing-bradley",
240
- "metadata": {},
241
- "outputs": [
242
- {
243
- "data": {
244
- "text/plain": [
245
- "27454"
246
- ]
247
- },
248
- "execution_count": 47,
249
- "metadata": {},
250
- "output_type": "execute_result"
251
- }
252
- ],
253
- "source": [
254
- "len(gene_detection_counts)"
255
- ]
256
- },
257
- {
258
- "cell_type": "code",
259
- "execution_count": 55,
260
- "id": "perfect-signal",
261
- "metadata": {},
262
- "outputs": [
263
- {
264
- "data": {
265
- "text/plain": [
266
- "25424"
267
- ]
268
- },
269
- "execution_count": 55,
270
- "metadata": {},
271
- "output_type": "execute_result"
272
- }
273
- ],
274
- "source": [
275
- "len([i for i in gene_detection_counts if i > 0])"
276
- ]
277
- },
278
- {
279
- "cell_type": "code",
280
- "execution_count": 56,
281
- "id": "faced-theory",
282
- "metadata": {},
283
- "outputs": [
284
- {
285
- "data": {
286
- "text/plain": [
287
- "22735"
288
- ]
289
- },
290
- "execution_count": 56,
291
- "metadata": {},
292
- "output_type": "execute_result"
293
- }
294
- ],
295
- "source": [
296
- "len([i for i in gene_detection_counts if i > 100])"
297
- ]
298
- },
299
- {
300
- "cell_type": "code",
301
- "execution_count": 57,
302
- "id": "tough-workplace",
303
- "metadata": {},
304
- "outputs": [
305
- {
306
- "data": {
307
- "text/plain": [
308
- "21167"
309
- ]
310
- },
311
- "execution_count": 57,
312
- "metadata": {},
313
- "output_type": "execute_result"
314
- }
315
- ],
316
- "source": [
317
- "len([i for i in gene_detection_counts if i > 1000])"
318
- ]
319
- },
320
- {
321
- "cell_type": "code",
322
- "execution_count": 49,
323
- "id": "cooperative-camcorder",
324
- "metadata": {},
325
- "outputs": [
326
- {
327
- "data": {
328
- "text/plain": [
329
- "173152.0299000284"
330
- ]
331
- },
332
- "execution_count": 49,
333
- "metadata": {},
334
- "output_type": "execute_result"
335
- }
336
- ],
337
- "source": [
338
- "gene_detection_event_digest = crick.tdigest.TDigest()\n",
339
- "gene_detection_event_digest.update(gene_detection_counts)\n",
340
- "gene_detection_event_digest.quantile(0.5)"
341
- ]
342
- }
343
- ],
344
- "metadata": {
345
- "kernelspec": {
346
- "display_name": "Python 3 (ipykernel)",
347
- "language": "python",
348
- "name": "python3"
349
- },
350
- "language_info": {
351
- "codemirror_mode": {
352
- "name": "ipython",
353
- "version": 3
354
- },
355
- "file_extension": ".py",
356
- "mimetype": "text/x-python",
357
- "name": "python",
358
- "nbconvert_exporter": "python",
359
- "pygments_lexer": "ipython3",
360
- "version": "3.10.11"
361
- }
362
- },
363
- "nbformat": 4,
364
- "nbformat_minor": 5
365
- }
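The deleted notebook's final cell estimates the median number of cells in which each gene is detected by streaming the per-gene counts through a t-digest (crick.tdigest.TDigest, via update() and quantile(0.5), as shown above). Below is a minimal self-contained sketch of that pattern; the chunked loop and the synthetic Poisson counts are illustrative stand-ins, not the notebook's actual data.

import numpy as np
from crick.tdigest import TDigest

# stand-in for gene_detection_counts (27,454 genes in the notebook above)
rng = np.random.default_rng(0)
detection_counts = rng.poisson(lam=500, size=27_454).astype(float)

digest = TDigest()
for chunk in np.array_split(detection_counts, 10):  # stream counts in chunks
    digest.update(chunk)  # same call the notebook makes on the full array

# approximate median detection count per gene (the notebook's real data
# yielded ~173152 at this quantile)
print(digest.quantile(0.5))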
 
examples/pretraining_new_model/pretrain_geneformer_w_deepspeed.py DELETED
@@ -1,167 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- # run with (an illustrative ds_config.json sketch follows this script):
5
- # deepspeed --num_gpus=12 --num_nodes=3 pretrain_geneformer_w_deepspeed.py --deepspeed ds_config.json
6
-
7
- import datetime
8
-
9
- # imports
10
- import os
11
-
12
- os.environ["NCCL_DEBUG"] = "INFO"
13
- os.environ["OMPI_MCA_opal_cuda_support"] = "true"
14
- os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
15
-
16
- import pickle
17
- import random
18
- import subprocess
19
-
20
- import numpy as np
21
- import pytz
22
- import torch
23
- from datasets import load_from_disk
24
- from transformers import BertConfig, BertForMaskedLM, TrainingArguments
25
-
26
- from geneformer import GeneformerPretrainer
27
-
28
- seed_num = 0
29
- random.seed(seed_num)
30
- np.random.seed(seed_num)
31
- seed_val = 42
32
- torch.manual_seed(seed_val)
33
- torch.cuda.manual_seed_all(seed_val)
34
-
35
- # set local time/directories
36
- timezone = pytz.timezone("US/Eastern")
37
- rootdir = "/parent_ouput_directory"
38
-
39
- # set model parameters
40
- # model type
41
- model_type = "bert"
42
- # max input size
43
- max_input_size = 2**11 # 2048
44
- # number of layers
45
- num_layers = 6
46
- # number of attention heads
47
- num_attn_heads = 4
48
- # number of embedding dimensions
49
- num_embed_dim = 256
50
- # intermediate size
51
- intermed_size = num_embed_dim * 2
52
- # activation function
53
- activ_fn = "relu"
54
- # initializer range, layer norm, dropout
55
- initializer_range = 0.02
56
- layer_norm_eps = 1e-12
57
- attention_probs_dropout_prob = 0.02
58
- hidden_dropout_prob = 0.02
59
-
60
-
61
- # set training parameters
62
- # total number of examples in Genecorpus-30M after QC filtering:
63
- num_examples = 27_406_208
64
- # number gpus
65
- num_gpus = 12
66
- # batch size for training and eval
67
- geneformer_batch_size = 12
68
- # max learning rate
69
- max_lr = 1e-3
70
- # learning schedule
71
- lr_schedule_fn = "linear"
72
- # warmup steps
73
- warmup_steps = 10_000
74
- # number of epochs
75
- epochs = 3
76
- # optimizer
77
- optimizer = "adamw"
78
- # weight_decay
79
- weight_decay = 0.001
80
-
81
-
82
- # output directories
83
- current_date = datetime.datetime.now(tz=timezone)
84
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':','')}"
85
- run_name = f"{datestamp}_geneformer_30M_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
86
- training_output_dir = f"{rootdir}/models/{run_name}/"
87
- logging_dir = f"{rootdir}/runs/{run_name}/"
88
- model_output_dir = os.path.join(training_output_dir, "models/")
89
-
90
-
91
- # ensure not overwriting previously saved model
92
- model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
93
- if os.path.isfile(model_output_file):
94
- raise Exception("Model already saved to this directory.")
95
-
96
-
97
- # make training and model output directories
98
- subprocess.call(f"mkdir {training_output_dir}", shell=True)
99
- subprocess.call(f"mkdir {model_output_dir}", shell=True)
100
-
101
-
102
- # load gene_ensembl_id:token dictionary (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/token_dictionary.pkl)
103
- with open("token_dictionary.pkl", "rb") as fp:
104
- token_dictionary = pickle.load(fp)
105
-
106
- # model configuration
107
- config = {
108
- "hidden_size": num_embed_dim,
109
- "num_hidden_layers": num_layers,
110
- "initializer_range": initializer_range,
111
- "layer_norm_eps": layer_norm_eps,
112
- "attention_probs_dropout_prob": attention_probs_dropout_prob,
113
- "hidden_dropout_prob": hidden_dropout_prob,
114
- "intermediate_size": intermed_size,
115
- "hidden_act": activ_fn,
116
- "max_position_embeddings": max_input_size,
117
- "model_type": model_type,
118
- "num_attention_heads": num_attn_heads,
119
- "pad_token_id": token_dictionary.get("<pad>"),
120
- "vocab_size": len(token_dictionary), # genes+2 for <mask> and <pad> tokens
121
- }
122
-
123
- config = BertConfig(**config)
124
- model = BertForMaskedLM(config)
125
- model = model.train()
126
-
127
- # define the training arguments
128
- training_args = {
129
- "learning_rate": max_lr,
130
- "do_train": True,
131
- "do_eval": False,
132
- "group_by_length": True,
133
- "length_column_name": "length",
134
- "disable_tqdm": False,
135
- "lr_scheduler_type": lr_schedule_fn,
136
- "warmup_steps": warmup_steps,
137
- "weight_decay": weight_decay,
138
- "per_device_train_batch_size": geneformer_batch_size,
139
- "num_train_epochs": epochs,
140
- "save_strategy": "steps",
141
- "save_steps": np.floor(
142
- num_examples / geneformer_batch_size / 8
143
- ), # 8 saves per epoch
144
- "logging_steps": 1000,
145
- "output_dir": training_output_dir,
146
- "logging_dir": logging_dir,
147
- }
148
- training_args = TrainingArguments(**training_args)
149
-
150
- print("Starting training.")
151
-
152
- # define the trainer
153
- trainer = GeneformerPretrainer(
154
- model=model,
155
- args=training_args,
156
- # pretraining corpus (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048.dataset)
157
- train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),
158
- # file of lengths of each example cell (e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/blob/main/genecorpus_30M_2048_lengths.pkl)
159
- example_lengths_file="genecorpus_30M_2048_lengths.pkl",
160
- token_dictionary=token_dictionary,
161
- )
162
-
163
- # train
164
- trainer.train()
165
-
166
- # save model
167
- trainer.save_model(model_output_dir)
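The launch command at the top of this deleted script points at a ds_config.json that is not part of the repository. The sketch below writes a minimal illustrative DeepSpeed configuration; the keys are standard DeepSpeed options, but the specific values, and whether the original run used ZeRO stage 2 or fp16 at all, are assumptions rather than a reconstruction of the actual file.

import json

ds_config = {
    # must match per_device_train_batch_size (geneformer_batch_size = 12 above)
    "train_micro_batch_size_per_gpu": 12,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},          # assumed; mixed precision is typical
    "zero_optimization": {"stage": 2},  # assumed; shards optimizer state
}

with open("ds_config.json", "w") as fp:
    json.dump(ds_config, fp, indent=2)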
 
examples/tokenizing_scRNAseq_data.ipynb DELETED
@@ -1,91 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
6
- "metadata": {
7
- "tags": []
8
- },
9
- "source": [
10
- "## Tokenizing .loom or .h5ad single cell RNA-seq data to rank value encoding .dataset format"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "id": "1fe86f48-5578-47df-b373-58c21ec170ab",
16
- "metadata": {},
17
- "source": [
18
- "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
19
- "\n",
20
- "#### The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.\n",
21
- "\n",
22
- "#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
23
- "\n",
24
- "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
25
- "\n",
26
- "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
- "\n",
28
- "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
29
- ]
30
- },
31
- {
32
- "cell_type": "markdown",
33
- "id": "32c69493-4e5a-4b07-8dc1-958ff2ee7d0b",
34
- "metadata": {},
35
- "source": [
36
- "**********************************************************************************************************\n",
37
-    "#### OF NOTE: PLEASE ENSURE THE CORRECT TOKEN DICTIONARY AND GENE MEDIAN FILES ARE USED FOR THE CORRECT MODEL.\n",
38
- "#### 95M: current defaults; 30M: https://huggingface.co/ctheodoris/Geneformer/tree/main/geneformer/gene_dictionaries_30m\n",
39
- "\n",
40
- "#### ADDITIONALLY:\n",
41
-    "#### The 95M model series requires the special_token argument to be set to True and model_input_size to be 4096 (current defaults).\n",
42
-    "#### The 30M model series requires the special_token argument to be set to False and model_input_size to be 2048."
43
- ]
44
- },
45
- {
46
- "cell_type": "code",
47
- "execution_count": null,
48
- "id": "080fdd9c-0c48-4d5d-a254-52b6c53cdf78",
49
- "metadata": {},
50
- "outputs": [],
51
- "source": [
52
- "from geneformer import TranscriptomeTokenizer"
53
- ]
54
- },
55
- {
56
- "cell_type": "code",
57
- "execution_count": null,
58
- "id": "37205758-aa52-4443-a383-0638519ee8a9",
59
- "metadata": {},
60
- "outputs": [],
61
- "source": [
62
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
63
- "tk.tokenize_data(\"loom_data_directory\", \n",
64
- " \"output_directory\", \n",
65
- " \"output_prefix\", \n",
66
- " file_format=\"loom\")"
67
- ]
68
- }
69
- ],
70
- "metadata": {
71
- "kernelspec": {
72
- "display_name": "Python 3 (ipykernel)",
73
- "language": "python",
74
- "name": "python3"
75
- },
76
- "language_info": {
77
- "codemirror_mode": {
78
- "name": "ipython",
79
- "version": 3
80
- },
81
- "file_extension": ".py",
82
- "mimetype": "text/x-python",
83
- "name": "python",
84
- "nbconvert_exporter": "python",
85
- "pygments_lexer": "ipython3",
86
- "version": "3.10.15"
87
- }
88
- },
89
- "nbformat": 4,
90
- "nbformat_minor": 5
91
- }
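
The deleted notebook's note on model/dictionary pairing deserves emphasis: tokenizing for the 30M model series needs non-default settings. A sketch of that call follows; special_token and model_input_size are the argument names given in the deleted markdown, while the keyword names and paths for the 30M dictionary files are assumptions based on the gene_dictionaries_30m directory linked above:

    from geneformer import TranscriptomeTokenizer

    # sketch for the 30M model series; dictionary paths are placeholders
    tk = TranscriptomeTokenizer(
        {"cell_type": "cell_type", "organ_major": "organ"},
        nproc=16,
        model_input_size=2048,  # 30M series input size (95M default is 4096)
        special_token=False,    # 30M series does not use special tokens
        gene_median_file="gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl",  # assumed
        token_dictionary_file="gene_dictionaries_30m/token_dictionary_gc30M.pkl",   # assumed
    )
    tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix", file_format="loom")
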
geneformer/__init__.py DELETED
@@ -1,34 +0,0 @@
1
- # ruff: noqa: F401
2
- import warnings
3
- from pathlib import Path
4
-
5
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip
6
-
7
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl"
8
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl"
9
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl"
10
- ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl"
11
-
12
- from . import (
13
- collator_for_classification,
14
- emb_extractor,
15
- in_silico_perturber,
16
- in_silico_perturber_stats,
17
- pretrainer,
18
- tokenizer,
19
- )
20
- from .collator_for_classification import (
21
- DataCollatorForCellClassification,
22
- DataCollatorForGeneClassification,
23
- )
24
- from .emb_extractor import EmbExtractor, get_embs
25
- from .in_silico_perturber import InSilicoPerturber
26
- from .in_silico_perturber_stats import InSilicoPerturberStats
27
- from .pretrainer import GeneformerPretrainer
28
- from .tokenizer import TranscriptomeTokenizer
29
-
30
- from . import classifier # noqa # isort:skip
31
- from .classifier import Classifier # noqa # isort:skip
32
-
33
- from . import mtl_classifier # noqa # isort:skip
34
- from .mtl_classifier import MTLClassifier # noqa # isort:skip
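
The module-level constants above point at the dictionaries packaged with the 95M models; loading them directly mirrors what the deleted classifier.py does when no custom token_dictionary_file is supplied:

    import pickle
    from geneformer import TOKEN_DICTIONARY_FILE

    with open(TOKEN_DICTIONARY_FILE, "rb") as f:
        gene_token_dict = pickle.load(f)  # Ensembl ID -> token
    token_gene_dict = {v: k for k, v in gene_token_dict.items()}  # token -> Ensembl ID
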
geneformer/classifier.py DELETED
@@ -1,1563 +0,0 @@
1
- """
2
- Geneformer classifier.
3
-
4
- **Input data:**
5
-
6
- | Cell state classifier:
7
- | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
8
-
9
- | Gene classifier:
10
- | Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
11
-
12
- **Usage:**
13
-
14
- .. code-block :: python
15
-
16
- >>> from geneformer import Classifier
17
- >>> cc = Classifier(classifier="cell", # example of cell state classifier
18
- ... cell_state_dict={"state_key": "disease", "states": "all"},
19
- ... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
20
- ... training_args=training_args,
21
- ... freeze_layers = 2,
22
- ... num_crossval_splits = 1,
23
- ... forward_batch_size=200,
24
- ... nproc=16)
25
- >>> cc.prepare_data(input_data_file="path/to/input_data",
26
- ... output_directory="path/to/output_directory",
27
- ... output_prefix="output_prefix")
28
- >>> all_metrics = cc.validate(model_directory="path/to/model",
29
- ... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
30
- ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
31
- ... output_directory="path/to/output_directory",
32
- ... output_prefix="output_prefix",
33
- ... predict_eval=True)
34
- >>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
35
- ... output_directory="path/to/output_directory",
36
- ... output_prefix="output_prefix",
37
- ... custom_class_order=["healthy","disease1","disease2"])
38
- >>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
39
- ... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
40
- ... title="disease",
41
- ... output_directory="path/to/output_directory",
42
- ... output_prefix="output_prefix",
43
- ... custom_class_order=["healthy","disease1","disease2"])
44
- """
45
-
46
- import datetime
47
- import logging
48
- import os
49
- import pickle
50
- import subprocess
51
- from pathlib import Path
52
-
53
- import numpy as np
54
- import pandas as pd
55
- import seaborn as sns
56
- from tqdm.auto import tqdm, trange
57
- from transformers import Trainer
58
- from transformers.training_args import TrainingArguments
59
-
60
- from . import (
61
- TOKEN_DICTIONARY_FILE,
62
- DataCollatorForCellClassification,
63
- DataCollatorForGeneClassification,
64
- )
65
- from . import classifier_utils as cu
66
- from . import evaluation_utils as eu
67
- from . import perturber_utils as pu
68
-
69
- sns.set()
70
-
71
-
72
- logger = logging.getLogger(__name__)
73
-
74
-
75
- class Classifier:
76
- valid_option_dict = {
77
- "classifier": {"cell", "gene"},
78
- "quantize": {bool, dict},
79
- "cell_state_dict": {None, dict},
80
- "gene_class_dict": {None, dict},
81
- "filter_data": {None, dict},
82
- "rare_threshold": {int, float},
83
- "max_ncells": {None, int},
84
- "max_ncells_per_class": {None, int},
85
- "training_args": {None, dict},
86
- "freeze_layers": {int},
87
- "num_crossval_splits": {0, 1, 5},
88
- "split_sizes": {None, dict},
89
- "no_eval": {bool},
90
- "stratify_splits_col": {None, str},
91
- "forward_batch_size": {int},
92
- "token_dictionary_file": {None, str},
93
- "nproc": {int},
94
- "ngpu": {int},
95
- }
96
-
97
- def __init__(
98
- self,
99
- classifier=None,
100
- quantize=False,
101
- cell_state_dict=None,
102
- gene_class_dict=None,
103
- filter_data=None,
104
- rare_threshold=0,
105
- max_ncells=None,
106
- max_ncells_per_class=None,
107
- training_args=None,
108
- ray_config=None,
109
- freeze_layers=0,
110
- num_crossval_splits=1,
111
- split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
112
- stratify_splits_col=None,
113
- no_eval=False,
114
- forward_batch_size=100,
115
- token_dictionary_file=None,
116
- nproc=4,
117
- ngpu=1,
118
- ):
119
- """
120
- Initialize Geneformer classifier.
121
-
122
- **Parameters:**
123
-
124
- classifier : {"cell", "gene"}
125
- | Whether to fine-tune a cell state or gene classifier.
126
- quantize : bool, dict
127
- | Whether to fine-tune a quantized model.
128
- | If True and no config provided, will use default.
129
- | Will use custom config if provided.
130
- | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
131
- | For example: {"bnb_config": BitsAndBytesConfig(...),
132
- | "peft_config": LoraConfig(...)}
133
- cell_state_dict : None, dict
134
- | Cell states to fine-tune model to distinguish.
135
- | Two-item dictionary with keys: state_key and states
136
- | state_key: key specifying name of column in .dataset that defines the states to model
137
- | states: list of values in the state_key column that specifies the states to model
138
- | Alternatively, instead of a list of states, can specify "all" to use all states in that state key from input data.
139
- | Of note, if using "all", states will be defined after data is filtered.
140
- | Must have at least 2 states to model.
141
- | For example: {"state_key": "disease",
142
- | "states": ["nf", "hcm", "dcm"]}
143
- | or
144
- | {"state_key": "disease",
145
- | "states": "all"}
146
- gene_class_dict : None, dict
147
- | Gene classes to fine-tune model to distinguish.
148
- | Dictionary in format: {Gene_label_A: list(geneA1, geneA2, ...),
149
- | Gene_label_B: list(geneB1, geneB2, ...)}
150
- | Gene values should be Ensembl IDs.
151
- filter_data : None, dict
152
- | Default is to fine-tune with all input data.
153
- | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
154
- rare_threshold : float
155
- | Threshold below which rare cell states should be removed.
156
- | For example, setting to 0.05 will remove cell states representing
157
- | < 5% of the total cells from the cell state classifier's possible classes.
158
- max_ncells : None, int
159
- | Maximum number of cells to use for fine-tuning.
160
- | Default is to fine-tune with all input data.
161
- max_ncells_per_class : None, int
162
- | Maximum number of cells per cell class to use for fine-tuning.
163
- | Of note, will be applied after max_ncells above.
164
- | (Only valid for cell classification.)
165
- training_args : None, dict
166
- | Training arguments for fine-tuning.
167
- | If None, defaults will be inferred for 6 layer Geneformer.
168
-            | Otherwise, provided arguments will override the inferred defaults; for all options, see:
169
- | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
170
- | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
171
- ray_config : None, dict
172
- | Training argument ranges for tuning hyperparameters with Ray.
173
- freeze_layers : int
174
- | Number of layers to freeze from fine-tuning.
175
- | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
176
- num_crossval_splits : {0, 1, 5}
177
- | 0: train on all data without splitting
178
- | 1: split data into train and eval sets by designated split_sizes["valid"]
179
- | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
180
- split_sizes : None, dict
181
- | Dictionary of proportion of data to hold out for train, validation, and test sets
182
- | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
183
- stratify_splits_col : None, str
184
- | Name of column in .dataset to be used for stratified splitting.
185
- | Proportion of each class in this column will be the same in the splits as in the original dataset.
186
- no_eval : bool
187
- | If True, will skip eval step and use all data for training.
188
- | Otherwise, will perform eval during training.
189
- forward_batch_size : int
190
- | Batch size for forward pass (for evaluation, not training).
191
- token_dictionary_file : None, str
192
- | Default is to use token dictionary file from Geneformer
193
- | Otherwise, will load custom gene token dictionary.
194
- nproc : int
195
- | Number of CPU processes to use.
196
- ngpu : int
197
- | Number of GPUs available.
198
-
199
- """
200
-
201
- self.classifier = classifier
202
- if self.classifier == "cell":
203
- self.model_type = "CellClassifier"
204
- elif self.classifier == "gene":
205
- self.model_type = "GeneClassifier"
206
- self.quantize = quantize
207
- self.cell_state_dict = cell_state_dict
208
- self.gene_class_dict = gene_class_dict
209
- self.filter_data = filter_data
210
- self.rare_threshold = rare_threshold
211
- self.max_ncells = max_ncells
212
- self.max_ncells_per_class = max_ncells_per_class
213
- self.training_args = training_args
214
- self.ray_config = ray_config
215
- self.freeze_layers = freeze_layers
216
- self.num_crossval_splits = num_crossval_splits
217
- self.split_sizes = split_sizes
218
- self.train_size = self.split_sizes["train"]
219
- self.valid_size = self.split_sizes["valid"]
220
- self.oos_test_size = self.split_sizes["test"]
221
- self.eval_size = self.valid_size / (self.train_size + self.valid_size)
222
- self.stratify_splits_col = stratify_splits_col
223
- self.no_eval = no_eval
224
- self.forward_batch_size = forward_batch_size
225
- self.token_dictionary_file = token_dictionary_file
226
- self.nproc = nproc
227
- self.ngpu = ngpu
228
-
229
- if self.training_args is None:
230
- logger.warning(
231
- "Hyperparameter tuning is highly recommended for optimal results. "
232
- "No training_args provided; using default hyperparameters."
233
- )
234
-
235
- self.validate_options()
236
-
237
- if self.filter_data is None:
238
- self.filter_data = dict()
239
-
240
- if self.classifier == "cell":
241
- if self.cell_state_dict["states"] != "all":
242
- self.filter_data[
243
- self.cell_state_dict["state_key"]
244
- ] = self.cell_state_dict["states"]
245
-
246
- # load token dictionary (Ensembl IDs:token)
247
- if self.token_dictionary_file is None:
248
- self.token_dictionary_file = TOKEN_DICTIONARY_FILE
249
- with open(self.token_dictionary_file, "rb") as f:
250
- self.gene_token_dict = pickle.load(f)
251
-
252
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
253
-
254
- # filter genes for gene classification for those in token dictionary
255
- if self.classifier == "gene":
256
- all_gene_class_values = set(pu.flatten_list(self.gene_class_dict.values()))
257
- missing_genes = [
258
- gene
259
- for gene in all_gene_class_values
260
- if gene not in self.gene_token_dict.keys()
261
- ]
262
- if len(missing_genes) == len(all_gene_class_values):
263
- logger.error(
264
- "None of the provided genes to classify are in token dictionary."
265
- )
266
-                raise ValueError("None of the provided genes to classify are in token dictionary.")
267
- elif len(missing_genes) > 0:
268
- logger.warning(
269
- f"Genes to classify {missing_genes} are not in token dictionary."
270
- )
271
- self.gene_class_dict = {
272
- k: list(set([self.gene_token_dict.get(gene) for gene in v]))
273
- for k, v in self.gene_class_dict.items()
274
- }
275
- empty_classes = []
276
- for k, v in self.gene_class_dict.items():
277
- if len(v) == 0:
278
- empty_classes += [k]
279
- if len(empty_classes) > 0:
280
- logger.error(
281
- f"Class(es) {empty_classes} did not contain any genes in the token dictionary."
282
- )
283
-                raise ValueError(f"Class(es) {empty_classes} did not contain any genes in the token dictionary.")
284
-
285
- def validate_options(self):
286
- # confirm arguments are within valid options and compatible with each other
287
- for attr_name, valid_options in self.valid_option_dict.items():
288
- attr_value = self.__dict__[attr_name]
289
- if not isinstance(attr_value, (list, dict)):
290
- if attr_value in valid_options:
291
- continue
292
- valid_type = False
293
- for option in valid_options:
294
- if (option in [int, float, list, dict, bool, str]) and isinstance(
295
- attr_value, option
296
- ):
297
- valid_type = True
298
- break
299
- if valid_type:
300
- continue
301
- logger.error(
302
- f"Invalid option for {attr_name}. "
303
- f"Valid options for {attr_name}: {valid_options}"
304
- )
305
-            raise ValueError(f"Invalid option for {attr_name}.")
306
-
307
- if self.filter_data is not None:
308
- for key, value in self.filter_data.items():
309
- if not isinstance(value, list):
310
- self.filter_data[key] = [value]
311
- logger.warning(
312
- "Values in filter_data dict must be lists. "
313
- f"Changing {key} value to list ([{value}])."
314
- )
315
-
316
- if self.classifier == "cell":
317
- if set(self.cell_state_dict.keys()) != set(["state_key", "states"]):
318
- logger.error(
319
- "Invalid keys for cell_state_dict. "
320
- "The cell_state_dict should have only 2 keys: state_key and states"
321
- )
322
-                raise ValueError("Invalid keys for cell_state_dict.")
323
-
324
- if self.cell_state_dict["states"] != "all":
325
- if not isinstance(self.cell_state_dict["states"], list):
326
- logger.error(
327
- "States in cell_state_dict should be list of states to model."
328
- )
329
-                    raise ValueError("States in cell_state_dict should be list of states to model.")
330
- if len(self.cell_state_dict["states"]) < 2:
331
- logger.error(
332
- "States in cell_state_dict should contain at least 2 states to classify."
333
- )
334
-                    raise ValueError("States in cell_state_dict should contain at least 2 states to classify.")
335
-
336
- if self.classifier == "gene":
337
- if len(self.gene_class_dict.keys()) < 2:
338
- logger.error(
339
- "Gene_class_dict should contain at least 2 gene classes to classify."
340
- )
341
-                raise ValueError("Gene_class_dict should contain at least 2 gene classes to classify.")
342
- if sum(self.split_sizes.values()) != 1:
343
- logger.error("Train, validation, and test proportions should sum to 1.")
344
-            raise ValueError("Train, validation, and test proportions should sum to 1.")
345
-
346
- def prepare_data(
347
- self,
348
- input_data_file,
349
- output_directory,
350
- output_prefix,
351
- split_id_dict=None,
352
- test_size=None,
353
- attr_to_split=None,
354
- attr_to_balance=None,
355
- max_trials=100,
356
- pval_threshold=0.1,
357
- ):
358
- """
359
- Prepare data for cell state or gene classification.
360
-
361
- **Parameters**
362
-
363
- input_data_file : Path
364
- | Path to directory containing .dataset input
365
- output_directory : Path
366
- | Path to directory where prepared data will be saved
367
- output_prefix : str
368
- | Prefix for output file
369
- split_id_dict : None, dict
370
- | Dictionary of IDs for train and test splits
371
- | Three-item dictionary with keys: attr_key, train, test
372
- | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
373
- | train: list of IDs in the attr_key column to include in the train split
374
- | test: list of IDs in the attr_key column to include in the test split
375
- | For example: {"attr_key": "individual",
376
- | "train": ["patient1", "patient2", "patient3", "patient4"],
377
- | "test": ["patient5", "patient6"]}
378
- test_size : None, float
379
- | Proportion of data to be saved separately and held out for test set
380
- | (e.g. 0.2 if intending hold out 20%)
381
- | If None, will inherit from split_sizes["test"] from Classifier
382
- | The training set will be further split to train / validation in self.validate
383
- | Note: only available for CellClassifiers
384
- attr_to_split : None, str
385
- | Key for attribute on which to split data while balancing potential confounders
386
- | e.g. "patient_id" for splitting by patient while balancing other characteristics
387
- | Note: only available for CellClassifiers
388
- attr_to_balance : None, list
389
- | List of attribute keys on which to balance data while splitting on attr_to_split
390
- | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
391
- | Note: only available for CellClassifiers
392
- max_trials : None, int
393
- | Maximum number of trials of random splitting to try to achieve balanced other attributes
394
-            | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
395
- | Note: only available for CellClassifiers
396
- pval_threshold : None, float
397
- | P-value threshold to use for attribute balancing across splits
398
- | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
399
- """
400
-
401
- if test_size is None:
402
- test_size = self.oos_test_size
403
-
404
- # prepare data and labels for classification
405
- data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
406
-
407
- if self.classifier == "cell":
408
- if "label" in data.features:
409
- logger.error(
410
- "Column name 'label' must be reserved for class IDs. Please rename column."
411
- )
412
-                raise ValueError("Column name 'label' must be reserved for class IDs. Please rename column.")
413
- elif self.classifier == "gene":
414
- if "labels" in data.features:
415
- logger.error(
416
- "Column name 'labels' must be reserved for class IDs. Please rename column."
417
- )
418
-                raise ValueError("Column name 'labels' must be reserved for class IDs. Please rename column.")
419
-
420
- if (attr_to_split is not None) and (attr_to_balance is None):
421
- logger.error(
422
- "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined."
423
- )
424
-            raise ValueError("attr_to_split requires attr_to_balance to also be defined.")
425
-
426
- if not isinstance(attr_to_balance, list):
427
- attr_to_balance = [attr_to_balance]
428
-
429
- if self.classifier == "cell":
430
- # remove cell states representing < rare_threshold of cells
431
- data = cu.remove_rare(
432
- data, self.rare_threshold, self.cell_state_dict["state_key"], self.nproc
433
- )
434
- # downsample max cells and max per class
435
- data = cu.downsample_and_shuffle(
436
- data, self.max_ncells, self.max_ncells_per_class, self.cell_state_dict
437
- )
438
- # rename cell state column to "label"
439
- data = cu.rename_cols(data, self.cell_state_dict["state_key"])
440
-
441
- # convert classes to numerical labels and save as id_class_dict
442
- # of note, will label all genes in gene_class_dict
443
- # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
- # at the time of training with Classifier.validate
445
- data, id_class_dict = cu.label_classes(
446
- self.classifier, data, self.gene_class_dict, self.nproc
447
- )
448
-
449
- # save id_class_dict for future reference
450
- id_class_output_path = (
451
- Path(output_directory) / f"{output_prefix}_id_class_dict"
452
- ).with_suffix(".pkl")
453
- with open(id_class_output_path, "wb") as f:
454
- pickle.dump(id_class_dict, f)
455
-
456
- if split_id_dict is not None:
457
- data_dict = dict()
458
- data_dict["train"] = pu.filter_by_dict(
459
- data, {split_id_dict["attr_key"]: split_id_dict["train"]}, self.nproc
460
- )
461
- data_dict["test"] = pu.filter_by_dict(
462
- data, {split_id_dict["attr_key"]: split_id_dict["test"]}, self.nproc
463
- )
464
- train_data_output_path = (
465
- Path(output_directory) / f"{output_prefix}_labeled_train"
466
- ).with_suffix(".dataset")
467
- test_data_output_path = (
468
- Path(output_directory) / f"{output_prefix}_labeled_test"
469
- ).with_suffix(".dataset")
470
- data_dict["train"].save_to_disk(str(train_data_output_path))
471
- data_dict["test"].save_to_disk(str(test_data_output_path))
472
- elif (test_size is not None) and (self.classifier == "cell"):
473
- if 1 > test_size > 0:
474
- if attr_to_split is None:
475
- data_dict = data.train_test_split(
476
- test_size=test_size,
477
- stratify_by_column=self.stratify_splits_col,
478
- seed=42,
479
- )
480
- train_data_output_path = (
481
- Path(output_directory) / f"{output_prefix}_labeled_train"
482
- ).with_suffix(".dataset")
483
- test_data_output_path = (
484
- Path(output_directory) / f"{output_prefix}_labeled_test"
485
- ).with_suffix(".dataset")
486
- data_dict["train"].save_to_disk(str(train_data_output_path))
487
- data_dict["test"].save_to_disk(str(test_data_output_path))
488
- else:
489
- data_dict, balance_df = cu.balance_attr_splits(
490
- data,
491
- attr_to_split,
492
- attr_to_balance,
493
- test_size,
494
- max_trials,
495
- pval_threshold,
496
- self.cell_state_dict["state_key"],
497
- self.nproc,
498
- )
499
- balance_df.to_csv(
500
- f"{output_directory}/{output_prefix}_train_test_balance_df.csv"
501
- )
502
- train_data_output_path = (
503
- Path(output_directory) / f"{output_prefix}_labeled_train"
504
- ).with_suffix(".dataset")
505
- test_data_output_path = (
506
- Path(output_directory) / f"{output_prefix}_labeled_test"
507
- ).with_suffix(".dataset")
508
- data_dict["train"].save_to_disk(str(train_data_output_path))
509
- data_dict["test"].save_to_disk(str(test_data_output_path))
510
- else:
511
- data_output_path = (
512
- Path(output_directory) / f"{output_prefix}_labeled"
513
- ).with_suffix(".dataset")
514
- data.save_to_disk(str(data_output_path))
515
- print(data_output_path)
516
- else:
517
- data_output_path = (
518
- Path(output_directory) / f"{output_prefix}_labeled"
519
- ).with_suffix(".dataset")
520
- data.save_to_disk(str(data_output_path))
521
-
522
- def train_all_data(
523
- self,
524
- model_directory,
525
- prepared_input_data_file,
526
- id_class_dict_file,
527
- output_directory,
528
- output_prefix,
529
- save_eval_output=True,
530
- gene_balance=False,
531
- ):
532
- """
533
- Train cell state or gene classifier using all data.
534
-
535
- **Parameters**
536
-
537
- model_directory : Path
538
- | Path to directory containing model
539
- prepared_input_data_file : Path
540
- | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
541
- id_class_dict_file : Path
542
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
543
- | (dictionary of format: numerical IDs: class_labels)
544
- output_directory : Path
545
- | Path to directory where model and eval data will be saved
546
- output_prefix : str
547
- | Prefix for output files
548
- save_eval_output : bool
549
- | Whether to save cross-fold eval output
550
- | Saves as pickle file of dictionary of eval metrics
551
- gene_balance : None, bool
552
- | Whether to automatically balance genes in training set.
553
- | Only available for binary gene classifications.
554
-
555
- **Output**
556
-
557
- Returns trainer after fine-tuning with all data.
558
-
559
- """
560
-
561
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
562
- logger.error(
563
- "Automatically balancing gene sets for training is only available for binary gene classifications."
564
- )
565
-            raise ValueError("Gene balancing is only available for binary gene classifications.")
566
-
567
- ##### Load data and prepare output directory #####
568
- # load numerical id to class dictionary (id:class)
569
- with open(id_class_dict_file, "rb") as f:
570
- id_class_dict = pickle.load(f)
571
- class_id_dict = {v: k for k, v in id_class_dict.items()}
572
-
573
- # load previously filtered and prepared data
574
- data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
575
- data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
576
-
577
- # define output directory path
578
- current_date = datetime.datetime.now()
579
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
580
- if output_directory[-1:] != "/": # add slash for dir if not present
581
- output_directory = output_directory + "/"
582
- output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
583
- subprocess.call(f"mkdir {output_dir}", shell=True)
584
-
585
- # get number of classes for classifier
586
- num_classes = cu.get_num_classes(id_class_dict)
587
-
588
- if self.classifier == "gene":
589
- targets = pu.flatten_list(self.gene_class_dict.values())
590
- labels = pu.flatten_list(
591
- [
592
- [class_id_dict[label]] * len(targets)
593
- for label, targets in self.gene_class_dict.items()
594
- ]
595
- )
596
- assert len(targets) == len(labels)
597
- data = cu.prep_gene_classifier_all_data(
598
- data, targets, labels, self.max_ncells, self.nproc, gene_balance
599
- )
600
-
601
- trainer = self.train_classifier(
602
- model_directory, num_classes, data, None, output_dir
603
- )
604
-
605
- return trainer
606
-
607
- def validate(
608
- self,
609
- model_directory,
610
- prepared_input_data_file,
611
- id_class_dict_file,
612
- output_directory,
613
- output_prefix,
614
- split_id_dict=None,
615
- attr_to_split=None,
616
- attr_to_balance=None,
617
- gene_balance=False,
618
- max_trials=100,
619
- pval_threshold=0.1,
620
- save_eval_output=True,
621
- predict_eval=True,
622
- predict_trainer=False,
623
- n_hyperopt_trials=0,
624
- save_gene_split_datasets=True,
625
- debug_gene_split_datasets=False,
626
- ):
627
- """
628
- (Cross-)validate cell state or gene classifier.
629
-
630
- **Parameters**
631
-
632
- model_directory : Path
633
- | Path to directory containing model
634
- prepared_input_data_file : Path
635
- | Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
636
- id_class_dict_file : Path
637
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
638
- | (dictionary of format: numerical IDs: class_labels)
639
- output_directory : Path
640
- | Path to directory where model and eval data will be saved
641
- output_prefix : str
642
- | Prefix for output files
643
- split_id_dict : None, dict
644
- | Dictionary of IDs for train and eval splits
645
- | Three-item dictionary with keys: attr_key, train, eval
646
- | attr_key: key specifying name of column in .dataset that contains the IDs for the data splits
647
- | train: list of IDs in the attr_key column to include in the train split
648
- | eval: list of IDs in the attr_key column to include in the eval split
649
- | For example: {"attr_key": "individual",
650
- | "train": ["patient1", "patient2", "patient3", "patient4"],
651
- | "eval": ["patient5", "patient6"]}
652
- | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
653
- attr_to_split : None, str
654
- | Key for attribute on which to split data while balancing potential confounders
655
- | e.g. "patient_id" for splitting by patient while balancing other characteristics
656
- | Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1)
657
- attr_to_balance : None, list
658
- | List of attribute keys on which to balance data while splitting on attr_to_split
659
- | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient
660
- gene_balance : None, bool
661
- | Whether to automatically balance genes in training set.
662
- | Only available for binary gene classifications.
663
- max_trials : None, int
664
-            | Maximum number of trials of random splitting to try to achieve balanced other attributes
665
- | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best
666
- pval_threshold : None, float
667
- | P-value threshold to use for attribute balancing across splits
668
- | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
669
- save_eval_output : bool
670
- | Whether to save cross-fold eval output
671
- | Saves as pickle file of dictionary of eval metrics
672
- predict_eval : bool
673
- | Whether or not to save eval predictions
674
- | Saves as a pickle file of self.evaluate predictions
675
- predict_trainer : bool
676
- | Whether or not to save eval predictions from trainer
677
- | Saves as a pickle file of trainer predictions
678
- n_hyperopt_trials : int
679
- | Number of trials to run for hyperparameter optimization
680
- | If 0, will not optimize hyperparameters
681
- save_gene_split_datasets : bool
682
- | Whether or not to save train, valid, and test gene-labeled datasets
683
- """
684
- if self.num_crossval_splits == 0:
685
- logger.error("num_crossval_splits must be 1 or 5 to validate.")
686
-            raise ValueError("num_crossval_splits must be 1 or 5 to validate.")
687
-
688
- if (gene_balance is True) and (len(self.gene_class_dict.values()) != 2):
689
- logger.error(
690
- "Automatically balancing gene sets for training is only available for binary gene classifications."
691
- )
692
-            raise ValueError("Gene balancing is only available for binary gene classifications.")
693
-
694
- # ensure number of genes in each class is > 5 if validating model
695
- if self.classifier == "gene":
696
- insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5]
697
- if (self.num_crossval_splits > 0) and (len(insuff_classes) > 0):
698
- logger.error(
699
- f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate."
700
- )
701
-                raise ValueError(f"Insufficient # of members in class(es) {insuff_classes} to (cross-)validate.")
702
-
703
- ##### Load data and prepare output directory #####
704
- # load numerical id to class dictionary (id:class)
705
- with open(id_class_dict_file, "rb") as f:
706
- id_class_dict = pickle.load(f)
707
- class_id_dict = {v: k for k, v in id_class_dict.items()}
708
-
709
- # load previously filtered and prepared data
710
- data = pu.load_and_filter(None, self.nproc, prepared_input_data_file)
711
- data = data.shuffle(seed=42) # reshuffle in case users provide unshuffled data
712
-
713
- # define output directory path
714
- current_date = datetime.datetime.now()
715
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
716
- if output_directory[-1:] != "/": # add slash for dir if not present
717
- output_directory = output_directory + "/"
718
- output_dir = f"{output_directory}{datestamp}_geneformer_{self.classifier}Classifier_{output_prefix}/"
719
- subprocess.call(f"mkdir {output_dir}", shell=True)
720
-
721
- # get number of classes for classifier
722
- num_classes = cu.get_num_classes(id_class_dict)
723
-
724
- ##### (Cross-)validate the model #####
725
- results = []
726
- all_conf_mat = np.zeros((num_classes, num_classes))
727
- iteration_num = 1
728
- if self.classifier == "cell":
729
- for i in trange(self.num_crossval_splits):
730
- print(
731
- f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
732
- )
733
- ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
734
- if self.num_crossval_splits == 1:
735
- # single 1-eval_size:eval_size split
736
- if split_id_dict is not None:
737
- data_dict = dict()
738
- data_dict["train"] = pu.filter_by_dict(
739
- data,
740
- {split_id_dict["attr_key"]: split_id_dict["train"]},
741
- self.nproc,
742
- )
743
- data_dict["test"] = pu.filter_by_dict(
744
- data,
745
- {split_id_dict["attr_key"]: split_id_dict["eval"]},
746
- self.nproc,
747
- )
748
- elif attr_to_split is not None:
749
- data_dict, balance_df = cu.balance_attr_splits(
750
- data,
751
- attr_to_split,
752
- attr_to_balance,
753
- self.eval_size,
754
- max_trials,
755
- pval_threshold,
756
- self.cell_state_dict["state_key"],
757
- self.nproc,
758
- )
759
-
760
- balance_df.to_csv(
761
- f"{output_dir}/{output_prefix}_train_valid_balance_df.csv"
762
- )
763
- else:
764
- data_dict = data.train_test_split(
765
- test_size=self.eval_size,
766
- stratify_by_column=self.stratify_splits_col,
767
- seed=42,
768
- )
769
- train_data = data_dict["train"]
770
- eval_data = data_dict["test"]
771
- else:
772
- # 5-fold cross-validate
773
- num_cells = len(data)
774
- fifth_cells = int(np.floor(num_cells * 0.2))
775
-                    num_eval = int(min((self.eval_size * num_cells), fifth_cells))
776
- start = i * fifth_cells
777
- end = start + num_eval
778
- eval_indices = [j for j in range(start, end)]
779
- train_indices = [
780
- j for j in range(num_cells) if j not in eval_indices
781
- ]
782
- eval_data = data.select(eval_indices)
783
- train_data = data.select(train_indices)
784
- if n_hyperopt_trials == 0:
785
- trainer = self.train_classifier(
786
- model_directory,
787
- num_classes,
788
- train_data,
789
- eval_data,
790
- ksplit_output_dir,
791
- predict_trainer,
792
- )
793
- else:
794
- trainer = self.hyperopt_classifier(
795
- model_directory,
796
- num_classes,
797
- train_data,
798
- eval_data,
799
- ksplit_output_dir,
800
- n_trials=n_hyperopt_trials,
801
- )
802
- if iteration_num == self.num_crossval_splits:
803
- return
804
- else:
805
- iteration_num = iteration_num + 1
806
- continue
807
-
808
- result = self.evaluate_model(
809
- trainer.model,
810
- num_classes,
811
- id_class_dict,
812
- eval_data,
813
- predict_eval,
814
- ksplit_output_dir,
815
- output_prefix,
816
- )
817
- results += [result]
818
- all_conf_mat = all_conf_mat + result["conf_mat"]
819
- iteration_num = iteration_num + 1
820
-
821
- elif self.classifier == "gene":
822
- # set up (cross-)validation splits
823
- targets = pu.flatten_list(self.gene_class_dict.values())
824
- labels = pu.flatten_list(
825
- [
826
- [class_id_dict[label]] * len(targets)
827
- for label, targets in self.gene_class_dict.items()
828
- ]
829
- )
830
- assert len(targets) == len(labels)
831
- n_splits = int(1 / (1 - self.train_size))
832
- skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
833
- # (Cross-)validate
834
- test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
835
- for train_index, eval_index, test_index in tqdm(
836
- skf.split(targets, labels, test_ratio)
837
- ):
838
- print(
839
- f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
840
- )
841
- ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
842
- # filter data for examples containing classes for this split
843
- # subsample to max_ncells and relabel data in column "labels"
844
- train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
845
- data,
846
- targets,
847
- labels,
848
- train_index,
849
- eval_index,
850
- self.max_ncells,
851
- iteration_num,
852
- self.nproc,
853
- gene_balance,
854
- )
855
-
856
- if save_gene_split_datasets is True:
857
- for split_name in ["train", "valid"]:
858
- labeled_dataset_output_path = (
859
- Path(output_dir)
860
- / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}"
861
- ).with_suffix(".dataset")
862
- if split_name == "train":
863
- train_data.save_to_disk(str(labeled_dataset_output_path))
864
- elif split_name == "valid":
865
- eval_data.save_to_disk(str(labeled_dataset_output_path))
866
-
867
- if self.oos_test_size > 0:
868
- test_data = cu.prep_gene_classifier_split(
869
- data,
870
- targets,
871
- labels,
872
- test_index,
873
- "test",
874
- self.max_ncells,
875
- iteration_num,
876
- self.nproc,
877
- )
878
- if save_gene_split_datasets is True:
879
- test_labeled_dataset_output_path = (
880
- Path(output_dir)
881
- / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}"
882
- ).with_suffix(".dataset")
883
- test_data.save_to_disk(str(test_labeled_dataset_output_path))
884
- if debug_gene_split_datasets is True:
885
- logger.error(
886
- "Exiting after saving gene split datasets given debug_gene_split_datasets = True."
887
- )
888
-                        raise SystemExit("Exiting after saving gene split datasets (debug_gene_split_datasets=True).")
889
- if n_hyperopt_trials == 0:
890
- trainer = self.train_classifier(
891
- model_directory,
892
- num_classes,
893
- train_data,
894
- eval_data,
895
- ksplit_output_dir,
896
- predict_trainer,
897
- )
898
- result = self.evaluate_model(
899
- trainer.model,
900
- num_classes,
901
- id_class_dict,
902
- eval_data,
903
- predict_eval,
904
- ksplit_output_dir,
905
- output_prefix,
906
- )
907
- else:
908
- trainer = self.hyperopt_classifier(
909
- model_directory,
910
- num_classes,
911
- train_data,
912
- eval_data,
913
- ksplit_output_dir,
914
- n_trials=n_hyperopt_trials,
915
- )
916
-
917
- model = cu.load_best_model(
918
- ksplit_output_dir, self.model_type, num_classes
919
- )
920
-
921
- if self.oos_test_size > 0:
922
- result = self.evaluate_model(
923
- model,
924
- num_classes,
925
- id_class_dict,
926
- test_data,
927
- predict_eval,
928
- ksplit_output_dir,
929
- output_prefix,
930
- )
931
- else:
932
- if iteration_num == self.num_crossval_splits:
933
- return
934
- else:
935
- iteration_num = iteration_num + 1
936
- continue
937
- results += [result]
938
- all_conf_mat = all_conf_mat + result["conf_mat"]
939
- # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
940
- if iteration_num == self.num_crossval_splits:
941
- break
942
- iteration_num = iteration_num + 1
943
-
944
- all_conf_mat_df = pd.DataFrame(
945
- all_conf_mat, columns=id_class_dict.values(), index=id_class_dict.values()
946
- )
947
- all_metrics = {
948
- "conf_matrix": all_conf_mat_df,
949
- "macro_f1": [result["macro_f1"] for result in results],
950
- "acc": [result["acc"] for result in results],
951
- }
952
- all_roc_metrics = None # roc metrics not reported for multiclass
953
- if num_classes == 2:
954
- mean_fpr = np.linspace(0, 1, 100)
955
- all_tpr = [result["roc_metrics"]["interp_tpr"] for result in results]
956
- all_roc_auc = [result["roc_metrics"]["auc"] for result in results]
957
- all_tpr_wt = [result["roc_metrics"]["tpr_wt"] for result in results]
958
- mean_tpr, roc_auc, roc_auc_sd = eu.get_cross_valid_roc_metrics(
959
- all_tpr, all_roc_auc, all_tpr_wt
960
- )
961
- all_roc_metrics = {
962
- "mean_tpr": mean_tpr,
963
- "mean_fpr": mean_fpr,
964
- "all_roc_auc": all_roc_auc,
965
- "roc_auc": roc_auc,
966
- "roc_auc_sd": roc_auc_sd,
967
- }
968
- all_metrics["all_roc_metrics"] = all_roc_metrics
969
- if save_eval_output is True:
970
- eval_metrics_output_path = (
971
- Path(output_dir) / f"{output_prefix}_eval_metrics_dict"
972
- ).with_suffix(".pkl")
973
- with open(eval_metrics_output_path, "wb") as f:
974
- pickle.dump(all_metrics, f)
975
-
976
- return all_metrics
977
-
978
- def hyperopt_classifier(
979
- self,
980
- model_directory,
981
- num_classes,
982
- train_data,
983
- eval_data,
984
- output_directory,
985
- n_trials=100,
986
- ):
987
- """
988
- Fine-tune model for cell state or gene classification.
989
-
990
- **Parameters**
991
-
992
- model_directory : Path
993
- | Path to directory containing model
994
- num_classes : int
995
- | Number of classes for classifier
996
- train_data : Dataset
997
- | Loaded training .dataset input
998
- | For cell classifier, labels in column "label".
999
- | For gene classifier, labels in column "labels".
1000
- eval_data : None, Dataset
1001
- | (Optional) Loaded evaluation .dataset input
1002
- | For cell classifier, labels in column "label".
1003
- | For gene classifier, labels in column "labels".
1004
- output_directory : Path
1005
- | Path to directory where fine-tuned model will be saved
1006
- n_trials : int
1007
- | Number of trials to run for hyperparameter optimization
1008
- """
1009
-
1010
- # initiate runtime environment for raytune
1011
- import ray
1012
- from ray import tune
1013
- from ray.tune.search.hyperopt import HyperOptSearch
1014
-
1015
- ray.shutdown() # engage new ray session
1016
- ray.init()
1017
-
1018
- ##### Validate and prepare data #####
1019
- train_data, eval_data = cu.validate_and_clean_cols(
1020
- train_data, eval_data, self.classifier
1021
- )
1022
-
1023
- if (self.no_eval is True) and (eval_data is not None):
1024
- logger.warning(
1025
-                "no_eval is set to True, but hyperparameter optimization requires an eval set; proceeding with eval."
1026
- )
1027
-
1028
- # ensure not overwriting previously saved model
1029
- saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1030
- if os.path.isfile(saved_model_test) is True:
1031
- logger.error("Model already saved to this designated output directory.")
1032
-            raise FileExistsError("Model already saved to this designated output directory.")
1033
- # make output directory
1034
- subprocess.call(f"mkdir {output_directory}", shell=True)
1035
-
1036
- ##### Load model and training args #####
1037
- model = pu.load_model(
1038
- self.model_type,
1039
- num_classes,
1040
- model_directory,
1041
- "train",
1042
- quantize=self.quantize,
1043
- )
1044
- def_training_args, def_freeze_layers = cu.get_default_train_args(
1045
- model, self.classifier, train_data, output_directory
1046
- )
1047
- del model
1048
-
1049
- if self.training_args is not None:
1050
- def_training_args.update(self.training_args)
1051
- logging_steps = round(
1052
- len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1053
- )
1054
- def_training_args["logging_steps"] = logging_steps
1055
- def_training_args["output_dir"] = output_directory
1056
- if eval_data is None:
1057
- def_training_args["evaluation_strategy"] = "no"
1058
- def_training_args["load_best_model_at_end"] = False
1059
- def_training_args.update(
1060
- {"save_strategy": "epoch", "save_total_limit": 1}
1061
- ) # only save last model for each run
1062
- training_args_init = TrainingArguments(**def_training_args)
1063
-
1064
- ##### Fine-tune the model #####
1065
- # define the data collator
1066
- if self.classifier == "cell":
1067
- data_collator = DataCollatorForCellClassification(
1068
- token_dictionary=self.gene_token_dict
1069
- )
1070
- elif self.classifier == "gene":
1071
- data_collator = DataCollatorForGeneClassification(
1072
- token_dictionary=self.gene_token_dict
1073
- )
1074
-
1075
- # define function to initiate model
1076
- def model_init():
1077
- model = pu.load_model(
1078
- self.model_type,
1079
- num_classes,
1080
- model_directory,
1081
- "train",
1082
- quantize=self.quantize,
1083
- )
1084
-
1085
- if self.freeze_layers is not None:
1086
- def_freeze_layers = self.freeze_layers
1087
-
1088
- if def_freeze_layers > 0:
1089
- modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
1090
- for module in modules_to_freeze:
1091
- for param in module.parameters():
1092
- param.requires_grad = False
1093
-
1094
- if self.quantize is False:
1095
- model = model.to("cuda:0")
1096
- return model
1097
-
1098
- # create the trainer
1099
- trainer = Trainer(
1100
- model_init=model_init,
1101
- args=training_args_init,
1102
- data_collator=data_collator,
1103
- train_dataset=train_data,
1104
- eval_dataset=eval_data,
1105
- compute_metrics=cu.compute_metrics,
1106
- )
1107
-
1108
- # specify raytune hyperparameter search space
1109
- if self.ray_config is None:
1110
- logger.warning(
1111
- "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
1112
- )
1113
- def_ray_config = {
1114
- "num_train_epochs": tune.choice([1]),
1115
- "learning_rate": tune.loguniform(1e-6, 1e-3),
1116
- "weight_decay": tune.uniform(0.0, 0.3),
1117
- "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
1118
- "warmup_steps": tune.uniform(100, 2000),
1119
- "seed": tune.uniform(0, 100),
1120
- "per_device_train_batch_size": tune.choice(
1121
- [def_training_args["per_device_train_batch_size"]]
1122
- ),
1123
- }
1124
-
1125
- hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
1126
-
1127
- # optimize hyperparameters
1128
- trainer.hyperparameter_search(
1129
- direction="maximize",
1130
- backend="ray",
1131
- resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
1132
- hp_space=lambda _: def_ray_config
1133
- if self.ray_config is None
1134
- else self.ray_config,
1135
- search_alg=hyperopt_search,
1136
- n_trials=n_trials, # number of trials
1137
- progress_reporter=tune.CLIReporter(
1138
- max_report_frequency=600,
1139
- sort_by_metric=True,
1140
- max_progress_rows=n_trials,
1141
- mode="max",
1142
- metric="eval_macro_f1",
1143
- metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1144
- ),
1145
- storage_path=output_directory,
1146
- )
1147
-
1148
- return trainer
1149
-
1150
- def train_classifier(
1151
- self,
1152
- model_directory,
1153
- num_classes,
1154
- train_data,
1155
- eval_data,
1156
- output_directory,
1157
- predict=False,
1158
- ):
1159
- """
1160
- Fine-tune model for cell state or gene classification.
1161
-
1162
- **Parameters**
1163
-
1164
- model_directory : Path
1165
- | Path to directory containing model
1166
- num_classes : int
1167
- | Number of classes for classifier
1168
- train_data : Dataset
1169
- | Loaded training .dataset input
1170
- | For cell classifier, labels in column "label".
1171
- | For gene classifier, labels in column "labels".
1172
- eval_data : None, Dataset
1173
- | (Optional) Loaded evaluation .dataset input
1174
- | For cell classifier, labels in column "label".
1175
- | For gene classifier, labels in column "labels".
1176
- output_directory : Path
1177
- | Path to directory where fine-tuned model will be saved
1178
- predict : bool
1179
- | Whether or not to save eval predictions from trainer
1180
- """
1181
-
1182
- ##### Validate and prepare data #####
1183
- train_data, eval_data = cu.validate_and_clean_cols(
1184
- train_data, eval_data, self.classifier
1185
- )
1186
-
1187
- if (self.no_eval is True) and (eval_data is not None):
1188
- logger.warning(
1189
- "no_eval set to True; model will be trained without evaluation."
1190
- )
1191
- eval_data = None
1192
-
1193
- if (self.classifier == "gene") and (predict is True):
1194
- logger.warning(
1195
- "Predictions during training not currently available for gene classifiers; setting predict to False."
1196
- )
1197
- predict = False
1198
-
1199
- # ensure not overwriting previously saved model
1200
- saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
1201
- if os.path.isfile(saved_model_test) is True:
1202
- logger.error("Model already saved to this designated output directory.")
1203
-            raise FileExistsError("Model already saved to this designated output directory.")
1204
- # make output directory
1205
- subprocess.call(f"mkdir {output_directory}", shell=True)
1206
-
1207
- ##### Load model and training args #####
1208
- model = pu.load_model(
1209
- self.model_type,
1210
- num_classes,
1211
- model_directory,
1212
- "train",
1213
- quantize=self.quantize,
1214
- )
1215
-
1216
- def_training_args, def_freeze_layers = cu.get_default_train_args(
1217
- model, self.classifier, train_data, output_directory
1218
- )
1219
-
1220
- if self.training_args is not None:
1221
- def_training_args.update(self.training_args)
1222
- logging_steps = round(
1223
- len(train_data) / def_training_args["per_device_train_batch_size"] / 10
1224
- )
1225
- def_training_args["logging_steps"] = logging_steps
1226
- def_training_args["output_dir"] = output_directory
1227
- if eval_data is None:
1228
- def_training_args["evaluation_strategy"] = "no"
1229
- def_training_args["load_best_model_at_end"] = False
1230
- training_args_init = TrainingArguments(**def_training_args)
1231
-
1232
- if self.freeze_layers is not None:
1233
- def_freeze_layers = self.freeze_layers
1234
-
1235
- if def_freeze_layers > 0:
1236
- modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
1237
- for module in modules_to_freeze:
1238
- for param in module.parameters():
1239
- param.requires_grad = False
1240
-
1241
- ##### Fine-tune the model #####
1242
- # define the data collator
1243
- if self.classifier == "cell":
1244
- data_collator = DataCollatorForCellClassification(
1245
- token_dictionary=self.gene_token_dict
1246
- )
1247
- elif self.classifier == "gene":
1248
- data_collator = DataCollatorForGeneClassification(
1249
- token_dictionary=self.gene_token_dict
1250
- )
1251
-
1252
- # create the trainer
1253
- trainer = Trainer(
1254
- model=model,
1255
- args=training_args_init,
1256
- data_collator=data_collator,
1257
- train_dataset=train_data,
1258
- eval_dataset=eval_data,
1259
- compute_metrics=cu.compute_metrics,
1260
- )
1261
-
1262
- # train the classifier
1263
- trainer.train()
1264
- trainer.save_model(output_directory)
1265
- if predict is True:
1266
- # make eval predictions and save predictions and metrics
1267
- predictions = trainer.predict(eval_data)
1268
- prediction_output_path = f"{output_directory}/predictions.pkl"
1269
- with open(prediction_output_path, "wb") as f:
1270
- pickle.dump(predictions, f)
1271
- trainer.save_metrics("eval", predictions.metrics)
1272
- return trainer
1273
-
1274
- def evaluate_model(
1275
- self,
1276
- model,
1277
- num_classes,
1278
- id_class_dict,
1279
- eval_data,
1280
- predict=False,
1281
- output_directory=None,
1282
- output_prefix=None,
1283
- ):
1284
- """
1285
- Evaluate the fine-tuned model.
1286
-
1287
- **Parameters**
1288
-
1289
- model : nn.Module
1290
- | Loaded fine-tuned model (e.g. trainer.model)
1291
- num_classes : int
1292
- | Number of classes for classifier
1293
- id_class_dict : dict
1294
- | Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
1295
- | (dictionary of format: numerical IDs: class_labels)
1296
- eval_data : Dataset
1297
- | Loaded evaluation .dataset input
1298
- predict : bool
1299
- | Whether or not to save eval predictions
1300
- output_directory : Path
1301
- | Path to directory where eval data will be saved
1302
- output_prefix : str
1303
- | Prefix for output files
1304
- """
1305
-
1306
- ##### Evaluate the model #####
1307
- labels = id_class_dict.keys()
1308
- y_pred, y_true, logits_list = eu.classifier_predict(
1309
- model, self.classifier, eval_data, self.forward_batch_size
1310
- )
1311
- conf_mat, macro_f1, acc, roc_metrics = eu.get_metrics(
1312
- y_pred, y_true, logits_list, num_classes, labels
1313
- )
1314
- if predict is True:
1315
- pred_dict = {
1316
- "pred_ids": y_pred,
1317
- "label_ids": y_true,
1318
- "predictions": logits_list,
1319
- }
1320
- pred_dict_output_path = (
1321
- Path(output_directory) / f"{output_prefix}_pred_dict"
1322
- ).with_suffix(".pkl")
1323
- with open(pred_dict_output_path, "wb") as f:
1324
- pickle.dump(pred_dict, f)
1325
- return {
1326
- "conf_mat": conf_mat,
1327
- "macro_f1": macro_f1,
1328
- "acc": acc,
1329
- "roc_metrics": roc_metrics,
1330
- }
1331
-
1332
- def evaluate_saved_model(
1333
- self,
1334
- model_directory,
1335
- id_class_dict_file,
1336
- test_data_file,
1337
- output_directory,
1338
- output_prefix,
1339
- predict=True,
1340
- ):
1341
- """
1342
- Evaluate the fine-tuned model.
1343
-
1344
- **Parameters**
1345
-
1346
- model_directory : Path
1347
- | Path to directory containing model
1348
- id_class_dict_file : Path
1349
- | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
1350
- | (dictionary of format: numerical IDs: class_labels)
1351
- test_data_file : Path
1352
- | Path to directory containing test .dataset
1353
- output_directory : Path
1354
- | Path to directory where eval data will be saved
1355
-        output_prefix : str
-            | Prefix for output files
-        predict : bool
-            | Whether or not to save eval predictions
-        """
-
-        # load numerical id to class dictionary (id:class)
-        with open(id_class_dict_file, "rb") as f:
-            id_class_dict = pickle.load(f)
-
-        # get number of classes for classifier
-        num_classes = cu.get_num_classes(id_class_dict)
-
-        # load previously filtered and prepared data
-        test_data = pu.load_and_filter(None, self.nproc, test_data_file)
-
-        # load previously fine-tuned model
-        model = pu.load_model(
-            self.model_type,
-            num_classes,
-            model_directory,
-            "eval",
-            quantize=self.quantize,
-        )
-
-        # evaluate the model
-        result = self.evaluate_model(
-            model,
-            num_classes,
-            id_class_dict,
-            test_data,
-            predict=predict,
-            output_directory=output_directory,
-            output_prefix=output_prefix,
-        )
-
-        all_conf_mat_df = pd.DataFrame(
-            result["conf_mat"],
-            columns=id_class_dict.values(),
-            index=id_class_dict.values(),
-        )
-        all_metrics = {
-            "conf_matrix": all_conf_mat_df,
-            "macro_f1": result["macro_f1"],
-            "acc": result["acc"],
-        }
-        all_roc_metrics = None  # roc metrics not reported for multiclass
-
-        if num_classes == 2:
-            mean_fpr = np.linspace(0, 1, 100)
-            mean_tpr = result["roc_metrics"]["interp_tpr"]
-            all_roc_auc = result["roc_metrics"]["auc"]
-            all_roc_metrics = {
-                "mean_tpr": mean_tpr,
-                "mean_fpr": mean_fpr,
-                "all_roc_auc": all_roc_auc,
-            }
-        all_metrics["all_roc_metrics"] = all_roc_metrics
-        test_metrics_output_path = (
-            Path(output_directory) / f"{output_prefix}_test_metrics_dict"
-        ).with_suffix(".pkl")
-        with open(test_metrics_output_path, "wb") as f:
-            pickle.dump(all_metrics, f)
-
-        return all_metrics
-
-    def plot_conf_mat(
-        self,
-        conf_mat_dict,
-        output_directory,
-        output_prefix,
-        custom_class_order=None,
-    ):
-        """
-        Plot confusion matrix results of evaluating the fine-tuned model.
-
-        **Parameters**
-
-        conf_mat_dict : dict
-            | Dictionary of model_name : confusion_matrix_DataFrame
-            | (all_metrics["conf_matrix"] from self.validate)
-        output_directory : Path
-            | Path to directory where plots will be saved
-        output_prefix : str
-            | Prefix for output file
-        custom_class_order : None, list
-            | List of classes in custom order for plots.
-            | Same order will be used for all models.
-        """
-
-        for model_name in conf_mat_dict.keys():
-            eu.plot_confusion_matrix(
-                conf_mat_dict[model_name],
-                model_name,
-                output_directory,
-                output_prefix,
-                custom_class_order,
-            )
-
-    def plot_roc(
-        self,
-        roc_metric_dict,
-        model_style_dict,
-        title,
-        output_directory,
-        output_prefix,
-    ):
-        """
-        Plot ROC curve results of evaluating the fine-tuned model.
-
-        **Parameters**
-
-        roc_metric_dict : dict
-            | Dictionary of model_name : roc_metrics
-            | (all_metrics["all_roc_metrics"] from self.validate)
-        model_style_dict : dict[dict]
-            | Dictionary of model_name : dictionary of style_attribute : style
-            | where style includes color and linestyle
-            | e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': ...}
-        title : str
-            | Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
-        output_directory : Path
-            | Path to directory where plots will be saved
-        output_prefix : str
-            | Prefix for output file
-        """
-
-        eu.plot_ROC(
-            roc_metric_dict, model_style_dict, title, output_directory, output_prefix
-        )
-
-    def plot_predictions(
-        self,
-        predictions_file,
-        id_class_dict_file,
-        title,
-        output_directory,
-        output_prefix,
-        custom_class_order=None,
-        kwargs_dict=None,
-    ):
-        """
-        Plot prediction results of evaluating the fine-tuned model.
-
-        **Parameters**
-
-        predictions_file : path
-            | Path of model predictions output to plot
-            | (saved output from self.validate if predict_eval=True)
-            | (or saved output from self.evaluate_saved_model)
-        id_class_dict_file : Path
-            | Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
-            | (dictionary of format: numerical IDs: class_labels)
-        title : str
-            | Title for legend containing class labels.
-        output_directory : Path
-            | Path to directory where plots will be saved
-        output_prefix : str
-            | Prefix for output file
-        custom_class_order : None, list
-            | List of classes in custom order for plots.
-            | Same order will be used for all models.
-        kwargs_dict : None, dict
-            | Dictionary of kwargs to pass to plotting function.
-        """
-        # load predictions
-        with open(predictions_file, "rb") as f:
-            predictions = pickle.load(f)
-
-        # load numerical id to class dictionary (id:class)
-        with open(id_class_dict_file, "rb") as f:
-            id_class_dict = pickle.load(f)
-
-        if isinstance(predictions, dict):
-            if all(
-                [
-                    key in predictions.keys()
-                    for key in ["pred_ids", "label_ids", "predictions"]
-                ]
-            ):
-                # format is output from self.evaluate_saved_model
-                predictions_logits = np.array(predictions["predictions"])
-                true_ids = predictions["label_ids"]
-        else:
-            # format is output from self.validate if predict_eval=True
-            predictions_logits = predictions.predictions
-            true_ids = predictions.label_ids
-
-        num_classes = len(id_class_dict.keys())
-        num_predict_classes = predictions_logits.shape[1]
-        assert num_classes == num_predict_classes
-        classes = id_class_dict.values()
-        true_labels = [id_class_dict[idx] for idx in true_ids]
-        predictions_df = pd.DataFrame(predictions_logits, columns=classes)
-        if custom_class_order is not None:
-            predictions_df = predictions_df.reindex(columns=custom_class_order)
-        predictions_df["true"] = true_labels
-        custom_dict = dict(zip(classes, [i for i in range(len(classes))]))
-        if custom_class_order is not None:
-            custom_dict = dict(
-                zip(custom_class_order, [i for i in range(len(custom_class_order))])
-            )
-        predictions_df = predictions_df.sort_values(
-            by=["true"], key=lambda x: x.map(custom_dict)
-        )
-
-        eu.plot_predictions(
-            predictions_df, title, output_directory, output_prefix, kwargs_dict
-        )
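
(For reference, a minimal sketch of how the removed evaluation and plotting methods chain together, assuming the pre-cleanup geneformer package is installed; all paths and the "prefix" name below are hypothetical placeholders.)

from geneformer import Classifier

cc = Classifier(
    classifier="cell",
    cell_state_dict={"state_key": "disease", "states": "all"},  # hypothetical task
    forward_batch_size=100,
    nproc=8,
)

# evaluate_saved_model loads the id:class dictionary, test data, and fine-tuned
# model, then writes {prefix}_test_metrics_dict.pkl as in the diff above
all_metrics = cc.evaluate_saved_model(
    model_directory="/path/to/fine_tuned_model",            # hypothetical path
    id_class_dict_file="/path/to/prefix_id_class_dict.pkl",
    test_data_file="/path/to/prefix_labeled_test.dataset",
    output_directory="/path/to/output",
    output_prefix="prefix",
)

# conf_mat_dict maps a display name to the confusion-matrix DataFrame
cc.plot_conf_mat(
    conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
    output_directory="/path/to/output",
    output_prefix="prefix",
)
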
geneformer/classifier_utils.py DELETED
@@ -1,648 +0,0 @@
-import json
-import logging
-import os
-import random
-from collections import Counter, defaultdict
-
-import numpy as np
-import pandas as pd
-from scipy.stats import chisquare, ranksums
-from sklearn.metrics import accuracy_score, f1_score
-from sklearn.model_selection import StratifiedKFold, train_test_split
-
-from . import perturber_utils as pu
-
-logger = logging.getLogger(__name__)
-
-
-def downsample_and_shuffle(data, max_ncells, max_ncells_per_class, cell_state_dict):
-    data = data.shuffle(seed=42)
-    num_cells = len(data)
-    # if max number of cells is defined, then subsample to this max number
-    if max_ncells is not None:
-        if num_cells > max_ncells:
-            data = data.select([i for i in range(max_ncells)])
-    if max_ncells_per_class is not None:
-        class_labels = data[cell_state_dict["state_key"]]
-        random.seed(42)
-        subsample_indices = subsample_by_class(class_labels, max_ncells_per_class)
-        data = data.select(subsample_indices)
-    return data
-
-
-# subsample labels to maximum number N per class and return indices
-def subsample_by_class(labels, N):
-    label_indices = defaultdict(list)
-    # Gather indices for each label
-    for idx, label in enumerate(labels):
-        label_indices[label].append(idx)
-    selected_indices = []
-    # Select up to N indices for each label
-    for label, indices in label_indices.items():
-        if len(indices) > N:
-            selected_indices.extend(random.sample(indices, N))
-        else:
-            selected_indices.extend(indices)
-    return selected_indices
-
-
-def rename_cols(data, state_key):
-    data = data.rename_column(state_key, "label")
-    return data
-
-
-def validate_and_clean_cols(train_data, eval_data, classifier):
-    # validate that data has expected label column and remove others
-    if classifier == "cell":
-        label_col = "label"
-    elif classifier == "gene":
-        label_col = "labels"
-
-    cols_to_keep = [label_col] + ["input_ids", "length"]
-    if label_col not in train_data.column_names:
-        logger.error(f"train_data must contain column {label_col} with class labels.")
-        raise
-    else:
-        train_data = remove_cols(train_data, cols_to_keep)
-
-    if eval_data is not None:
-        if label_col not in eval_data.column_names:
-            logger.error(
-                f"eval_data must contain column {label_col} with class labels."
-            )
-            raise
-        else:
-            eval_data = remove_cols(eval_data, cols_to_keep)
-    return train_data, eval_data
-
-
-def remove_cols(data, cols_to_keep):
-    other_cols = list(data.features.keys())
-    other_cols = [ele for ele in other_cols if ele not in cols_to_keep]
-    data = data.remove_columns(other_cols)
-    return data
-
-
-def remove_rare(data, rare_threshold, label, nproc):
-    if rare_threshold > 0:
-        total_cells = len(data)
-        label_counter = Counter(data[label])
-        # iterate over Counter items to obtain (label, count) pairs
-        nonrare_label_dict = {
-            label: [
-                k
-                for k, v in label_counter.items()
-                if (v / total_cells) > rare_threshold
-            ]
-        }
-        data = pu.filter_by_dict(data, nonrare_label_dict, nproc)
-    return data
-
-
-def label_classes(classifier, data, gene_class_dict, nproc):
-    if classifier == "cell":
-        label_set = set(data["label"])
-    elif classifier == "gene":
-        # remove cells without any of the target genes
-        def if_contains_label(example):
-            a = pu.flatten_list(gene_class_dict.values())
-            b = example["input_ids"]
-            return not set(a).isdisjoint(b)
-
-        data = data.filter(if_contains_label, num_proc=nproc)
-        label_set = gene_class_dict.keys()
-
-    if len(data) == 0:
-        logger.error(
-            "No cells remain after filtering for target genes. Check target gene list."
-        )
-        raise
-
-    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))
-    id_class_dict = {v: k for k, v in class_id_dict.items()}
-
-    def classes_to_ids(example):
-        if classifier == "cell":
-            example["label"] = class_id_dict[example["label"]]
-        elif classifier == "gene":
-            example["labels"] = label_gene_classes(
-                example, class_id_dict, gene_class_dict
-            )
-        return example
-
-    data = data.map(classes_to_ids, num_proc=nproc)
-    return data, id_class_dict
-
-
-def label_gene_classes(example, class_id_dict, gene_class_dict):
-    return [
-        class_id_dict.get(gene_class_dict.get(token_id, -100), -100)
-        for token_id in example["input_ids"]
-    ]
-
-
-def prep_gene_classifier_train_eval_split(
-    data,
-    targets,
-    labels,
-    train_index,
-    eval_index,
-    max_ncells,
-    iteration_num,
-    num_proc,
-    balance=False,
-):
-    # generate cross-validation splits
-    train_data = prep_gene_classifier_split(
-        data,
-        targets,
-        labels,
-        train_index,
-        "train",
-        max_ncells,
-        iteration_num,
-        num_proc,
-        balance,
-    )
-    eval_data = prep_gene_classifier_split(
-        data,
-        targets,
-        labels,
-        eval_index,
-        "eval",
-        max_ncells,
-        iteration_num,
-        num_proc,
-        balance,
-    )
-    return train_data, eval_data
-
-
-def prep_gene_classifier_split(
-    data,
-    targets,
-    labels,
-    index,
-    subset_name,
-    max_ncells,
-    iteration_num,
-    num_proc,
-    balance=False,
-):
-    # generate cross-validation splits
-    targets = np.array(targets)
-    labels = np.array(labels)
-    targets_subset = targets[index]
-    labels_subset = labels[index]
-    label_dict_subset = dict(zip(targets_subset, labels_subset))
-
-    # function to filter by whether contains train or eval labels
-    def if_contains_subset_label(example):
-        a = targets_subset
-        b = example["input_ids"]
-        return not set(a).isdisjoint(b)
-
-    # filter dataset for examples containing classes for this split
-    logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
-    subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
-    logger.info(
-        f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
-    )
-
-    # balance gene subsets if train
-    if (subset_name == "train") and (balance is True):
-        subset_data, label_dict_subset = balance_gene_split(
-            subset_data, label_dict_subset, num_proc
-        )
-
-    # subsample to max_ncells
-    subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
-
-    # relabel genes for this split
-    def subset_classes_to_ids(example):
-        example["labels"] = [
-            label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
-        ]
-        return example
-
-    subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
-
-    return subset_data
-
-
-def prep_gene_classifier_all_data(
-    data, targets, labels, max_ncells, num_proc, balance=False
-):
-    targets = np.array(targets)
-    labels = np.array(labels)
-    label_dict_train = dict(zip(targets, labels))
-
-    # function to filter by whether contains train labels
-    def if_contains_train_label(example):
-        a = targets
-        b = example["input_ids"]
-        return not set(a).isdisjoint(b)
-
-    # filter dataset for examples containing classes for this split
-    logger.info("Filtering training data for genes to classify.")
-    train_data = data.filter(if_contains_train_label, num_proc=num_proc)
-    logger.info(
-        f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
-    )
-
-    if balance is True:
-        train_data, label_dict_train = balance_gene_split(
-            train_data, label_dict_train, num_proc
-        )
-
-    # subsample to max_ncells
-    train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
-
-    # relabel genes for this split
-    def train_classes_to_ids(example):
-        example["labels"] = [
-            label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
-        ]
-        return example
-
-    train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
-
-    return train_data
-
-
-def balance_gene_split(subset_data, label_dict_subset, num_proc):
-    # count occurrence of genes in each label category
-    label0_counts, label1_counts = count_genes_for_balancing(
-        subset_data, label_dict_subset, num_proc
-    )
-    label_ratio_0to1 = label0_counts / label1_counts
-
-    if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
-        # gene sets already balanced
-        logger.info(
-            "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n"
-        )
-        return subset_data, label_dict_subset
-    else:
-        label_ratio_0to1_orig = label_ratio_0to1 + 0
-        label_dict_subset_orig = label_dict_subset.copy()
-        # balance gene sets
-        max_ntrials = 25
-        boost = 1
-        if label_ratio_0to1 > 10 / 8:
-            # downsample label 0
-            for i in range(max_ntrials):
-                label0 = 0
-                label0_genes = [k for k, v in label_dict_subset.items() if v == label0]
-                label0_ngenes = len(label0_genes)
-                label0_nremove = max(
-                    1,
-                    int(
-                        np.floor(
-                            label0_ngenes - label0_ngenes / (label_ratio_0to1 * boost)
-                        )
-                    ),
-                )
-                random.seed(i)
-                label0_remove_genes = random.sample(label0_genes, label0_nremove)
-                label_dict_subset_new = {
-                    k: v
-                    for k, v in label_dict_subset.items()
-                    if k not in label0_remove_genes
-                }
-                label0_counts, label1_counts = count_genes_for_balancing(
-                    subset_data, label_dict_subset_new, num_proc
-                )
-                label_ratio_0to1 = label0_counts / label1_counts
-                if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
-                    # if gene sets now balanced, return new filtered data and new label_dict_subset
-                    return filter_data_balanced_genes(
-                        subset_data, label_dict_subset_new, num_proc
-                    )
-                elif label_ratio_0to1 > 10 / 8:
-                    boost = boost * 1.1
-                elif label_ratio_0to1 < 8 / 10:
-                    boost = boost * 0.9
-        else:
-            # downsample label 1
-            for i in range(max_ntrials):
-                label1 = 1
-                label1_genes = [k for k, v in label_dict_subset.items() if v == label1]
-                label1_ngenes = len(label1_genes)
-                label1_nremove = max(
-                    1,
-                    int(
-                        np.floor(
-                            label1_ngenes
-                            - label1_ngenes / ((1 / label_ratio_0to1) * boost)
-                        )
-                    ),
-                )
-                random.seed(i)
-                label1_remove_genes = random.sample(label1_genes, label1_nremove)
-                label_dict_subset_new = {
-                    k: v
-                    for k, v in label_dict_subset.items()
-                    if k not in label1_remove_genes
-                }
-                label0_counts, label1_counts = count_genes_for_balancing(
-                    subset_data, label_dict_subset_new, num_proc
-                )
-                label_ratio_0to1 = label0_counts / label1_counts
-                if 8 / 10 <= label_ratio_0to1 <= 10 / 8:
-                    # if gene sets now balanced, return new filtered data and new label_dict_subset
-                    return filter_data_balanced_genes(
-                        subset_data, label_dict_subset_new, num_proc
-                    )
-                elif label_ratio_0to1 < 8 / 10:
-                    boost = boost * 1.1
-                elif label_ratio_0to1 > 10 / 8:
-                    boost = boost * 0.9
-
-        assert i + 1 == max_ntrials
-        if (label_ratio_0to1 <= label_ratio_0to1_orig < 8 / 10) or (
-            10 / 8 > label_ratio_0to1_orig >= label_ratio_0to1
-        ):
-            label_ratio_0to1 = label_ratio_0to1_orig
-            label_dict_subset_new = label_dict_subset_orig
-        logger.warning(
-            f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n"
-        )
-        return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc)
-
-
-def count_genes_for_balancing(subset_data, label_dict_subset, num_proc):
-    def count_targets(example):
-        labels = [
-            label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
-        ]
-        counter_labels = Counter(labels)
-        # get count of labels 0 or 1, or if absent, return 0
-        example["labels_counts"] = [counter_labels.get(0, 0), counter_labels.get(1, 0)]
-        return example
-
-    subset_data = subset_data.map(count_targets, num_proc=num_proc)
-
-    label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]])
-    label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]])
-
-    subset_data = subset_data.remove_columns("labels_counts")
-
-    return label0_counts, label1_counts
-
-
-def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc):
-    # function to filter by whether contains labels
-    def if_contains_subset_label(example):
-        a = list(label_dict_subset.keys())
-        b = example["input_ids"]
-        return not set(a).isdisjoint(b)
-
-    # filter dataset for examples containing classes for this split
-    logger.info("Filtering data for balanced genes")
-    subset_data_len_orig = len(subset_data)
-    subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc)
-    logger.info(
-        f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n"
-    )
-
-    return subset_data, label_dict_subset
-
-
-def balance_attr_splits(
-    data,
-    attr_to_split,
-    attr_to_balance,
-    eval_size,
-    max_trials,
-    pval_threshold,
-    state_key,
-    nproc,
-):
-    metadata_df = pd.DataFrame({"split_attr_ids": data[attr_to_split]})
-    for attr in attr_to_balance:
-        if attr == state_key:
-            metadata_df[attr] = data["label"]
-        else:
-            metadata_df[attr] = data[attr]
-    metadata_df = metadata_df.drop_duplicates()
-
-    split_attr_ids = list(metadata_df["split_attr_ids"])
-    assert len(split_attr_ids) == len(set(split_attr_ids))
-    eval_num = round(len(split_attr_ids) * eval_size)
-    colnames = (
-        ["trial_num", "train_ids", "eval_ids"]
-        + pu.flatten_list(
-            [
-                [
-                    f"{attr}_train_mean_or_counts",
-                    f"{attr}_eval_mean_or_counts",
-                    f"{attr}_pval",
-                ]
-                for attr in attr_to_balance
-            ]
-        )
-        + ["mean_pval"]
-    )
-    balance_df = pd.DataFrame(columns=colnames)
-    data_dict = dict()
-    trial_num = 1
-    for i in range(max_trials):
-        if not all(
-            count > 1 for count in list(Counter(metadata_df[state_key]).values())
-        ):
-            logger.error(
-                f"Cannot balance by {attr_to_split} while retaining at least 1 occurrence of each {state_key} class in both data splits. "
-            )
-            raise
-        eval_base = []
-        for state in set(metadata_df[state_key]):
-            eval_base += list(
-                metadata_df.loc[
-                    metadata_df[state_key][metadata_df[state_key].eq(state)]
-                    .sample(1, random_state=i)
-                    .index
-                ]["split_attr_ids"]
-            )
-        non_eval_base = [idx for idx in split_attr_ids if idx not in eval_base]
-        random.seed(i)
-        eval_ids = random.sample(non_eval_base, eval_num - len(eval_base)) + eval_base
-        train_ids = [idx for idx in split_attr_ids if idx not in eval_ids]
-        df_vals = [trial_num, train_ids, eval_ids]
-        pvals = []
-        for attr in attr_to_balance:
-            train_attr = list(
-                metadata_df[metadata_df["split_attr_ids"].isin(train_ids)][attr]
-            )
-            eval_attr = list(
-                metadata_df[metadata_df["split_attr_ids"].isin(eval_ids)][attr]
-            )
-            if attr == state_key:
-                # ensure IDs are interpreted as categorical
-                train_attr = [str(item) for item in train_attr]
-                eval_attr = [str(item) for item in eval_attr]
-            if all(isinstance(item, (int, float)) for item in train_attr + eval_attr):
-                train_attr_mean = np.nanmean(train_attr)
-                eval_attr_mean = np.nanmean(eval_attr)
-                pval = ranksums(train_attr, eval_attr, nan_policy="omit").pvalue
-                df_vals += [train_attr_mean, eval_attr_mean, pval]
-            elif all(isinstance(item, (str)) for item in train_attr + eval_attr):
-                obs_counts = Counter(train_attr)
-                exp_counts = Counter(eval_attr)
-                all_categ = set(obs_counts.keys()).union(set(exp_counts.keys()))
-                obs = [obs_counts[cat] for cat in all_categ]
-                exp = [
-                    exp_counts[cat] * sum(obs) / sum(exp_counts.values())
-                    for cat in all_categ
-                ]
-                pval = chisquare(f_obs=obs, f_exp=exp).pvalue
-                train_attr_counts = str(obs_counts).strip("Counter(").strip(")")
-                eval_attr_counts = str(exp_counts).strip("Counter(").strip(")")
-                df_vals += [train_attr_counts, eval_attr_counts, pval]
-            else:
-                logger.error(
-                    f"Inconsistent data types in attribute {attr}. "
-                    "Cannot infer if continuous or categorical. "
-                    "Must be all numeric (continuous) or all strings (categorical) to balance."
-                )
-                raise
-            pvals += [pval]
-
-        df_vals += [np.nanmean(pvals)]
-        balance_df_i = pd.DataFrame(df_vals, index=colnames).T
-        balance_df = pd.concat([balance_df, balance_df_i], ignore_index=True)
-        valid_pvals = [
-            pval_i
-            for pval_i in pvals
-            if isinstance(pval_i, (int, float)) and not np.isnan(pval_i)
-        ]
-        if all(i >= pval_threshold for i in valid_pvals):
-            data_dict["train"] = pu.filter_by_dict(
-                data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
-            )
-            data_dict["test"] = pu.filter_by_dict(
-                data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
-            )
-            return data_dict, balance_df
-        trial_num = trial_num + 1
-    balance_max_df = balance_df.iloc[balance_df["mean_pval"].idxmax(), :]
-    data_dict["train"] = pu.filter_by_dict(
-        data, {attr_to_split: balance_df_i["train_ids"][0]}, nproc
-    )
-    data_dict["test"] = pu.filter_by_dict(
-        data, {attr_to_split: balance_df_i["eval_ids"][0]}, nproc
-    )
-    logger.warning(
-        f"No splits found without significant difference in attr_to_balance among {max_trials} trials. "
-        f"Selecting optimal split (trial #{balance_max_df['trial_num']}) from completed trials."
-    )
-    return data_dict, balance_df
-
-
-def get_num_classes(id_class_dict):
-    return len(set(id_class_dict.values()))
-
-
-def compute_metrics(pred):
-    labels = pred.label_ids
-    preds = pred.predictions.argmax(-1)
-
-    # calculate accuracy and macro f1 using sklearn's function
-    if len(labels.shape) == 1:
-        acc = accuracy_score(labels, preds)
-        macro_f1 = f1_score(labels, preds, average="macro")
-    else:
-        flat_labels = labels.flatten().tolist()
-        flat_preds = preds.flatten().tolist()
-        logit_label_paired = [
-            item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
-        ]
-        y_pred = [item[0] for item in logit_label_paired]
-        y_true = [item[1] for item in logit_label_paired]
-
-        acc = accuracy_score(y_true, y_pred)
-        macro_f1 = f1_score(y_true, y_pred, average="macro")
-
-    return {"accuracy": acc, "macro_f1": macro_f1}
-
-
-def get_default_train_args(model, classifier, data, output_dir):
-    num_layers = pu.quant_layers(model)
-    freeze_layers = 0
-    batch_size = 12
-    if classifier == "cell":
-        epochs = 10
-        evaluation_strategy = "epoch"
-        load_best_model_at_end = True
-    else:
-        epochs = 1
-        evaluation_strategy = "no"
-        load_best_model_at_end = False
-
-    if num_layers == 6:
-        default_training_args = {
-            "learning_rate": 5e-5,
-            "lr_scheduler_type": "linear",
-            "warmup_steps": 500,
-            "per_device_train_batch_size": batch_size,
-            "per_device_eval_batch_size": batch_size,
-        }
-    else:
-        default_training_args = {
-            "per_device_train_batch_size": batch_size,
-            "per_device_eval_batch_size": batch_size,
-        }
-
-    training_args = {
-        "num_train_epochs": epochs,
-        "do_train": True,
-        "do_eval": True,
-        "evaluation_strategy": evaluation_strategy,
-        "logging_steps": np.floor(len(data) / batch_size / 8),  # 8 evals per epoch
-        "save_strategy": "epoch",
-        "group_by_length": False,
-        "length_column_name": "length",
-        "disable_tqdm": False,
-        "weight_decay": 0.001,
-        "load_best_model_at_end": load_best_model_at_end,
-    }
-    training_args.update(default_training_args)
-
-    return training_args, freeze_layers
-
-
-def load_best_model(directory, model_type, num_classes, mode="eval"):
-    file_dict = dict()
-    for subdir, dirs, files in os.walk(directory):
-        for file in files:
-            if file.endswith("result.json"):
-                with open(f"{subdir}/{file}", "rb") as fp:
-                    result_json = json.load(fp)
-                file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
-    file_df = pd.DataFrame(
-        {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
-    )
-    model_superdir = (
-        "run-"
-        + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
-        .split("_objective_")[2]
-        .split("_")[0]
-    )
-
-    for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
-        for file in files:
-            if file.endswith("model.safetensors"):
-                model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
-    return model
-
-
-class StratifiedKFold3(StratifiedKFold):
-    def split(self, targets, labels, test_ratio=0.5, groups=None):
-        s = super().split(targets, labels, groups)
-        for train_indxs, test_indxs in s:
-            if test_ratio == 0:
-                yield train_indxs, test_indxs, None
-            else:
-                labels_test = np.array(labels)[test_indxs]
-                valid_indxs, test_indxs = train_test_split(
-                    test_indxs,
-                    stratify=labels_test,
-                    test_size=test_ratio,
-                    random_state=0,
-                )
-                yield train_indxs, valid_indxs, test_indxs
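
(A minimal sketch of the three-way splitter defined at the end of this file: each stratified eval fold is further split into validation and test by test_ratio. The toy targets and labels below are invented for illustration, and the module path assumes the pre-cleanup package layout.)

import numpy as np
from geneformer.classifier_utils import StratifiedKFold3  # pre-cleanup module path

targets = np.arange(100)          # toy gene token IDs
labels = np.array([0, 1] * 50)    # toy binary class labels

skf = StratifiedKFold3(n_splits=5, shuffle=True, random_state=42)
for train_idx, valid_idx, test_idx in skf.split(targets, labels, test_ratio=0.5):
    # 80 train / 10 validation / 10 test per fold, stratified by label
    print(len(train_idx), len(valid_idx), len(test_idx))
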
geneformer/collator_for_classification.py DELETED
@@ -1,667 +0,0 @@
1
- """
2
- Geneformer collator for gene and cell classification.
3
- Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
4
- """
5
-
6
- import warnings
7
- from enum import Enum
8
- from typing import Dict, List, Optional, Union
9
-
10
- import numpy as np
11
- import torch
12
- from transformers import (
13
- BatchEncoding,
14
- DataCollatorForTokenClassification,
15
- SpecialTokensMixin,
16
- )
17
- from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
18
- from transformers.utils.generic import _is_tensorflow, _is_torch
19
-
20
- EncodedInput = List[int]
21
- logger = logging.get_logger(__name__)
22
- VERY_LARGE_INTEGER = int(
23
- 1e30
24
- ) # This is used to set the max input length for a model with infinite size input
25
- LARGE_INTEGER = int(
26
- 1e20
27
- ) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
28
-
29
- # precollator functions
30
-
31
-
32
- class ExplicitEnum(Enum):
33
- """
34
- Enum with more explicit error message for missing values.
35
- """
36
-
37
- @classmethod
38
- def _missing_(cls, value):
39
- raise ValueError(
40
- "%r is not a valid %s, please select one of %s"
41
- % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
42
- )
43
-
44
-
45
- class TruncationStrategy(ExplicitEnum):
46
- """
47
- Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
48
- tab-completion in an IDE.
49
- """
50
-
51
- ONLY_FIRST = "only_first"
52
- ONLY_SECOND = "only_second"
53
- LONGEST_FIRST = "longest_first"
54
- DO_NOT_TRUNCATE = "do_not_truncate"
55
-
56
-
57
- class PaddingStrategy(ExplicitEnum):
58
- """
59
- Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
60
- in an IDE.
61
- """
62
-
63
- LONGEST = "longest"
64
- MAX_LENGTH = "max_length"
65
- DO_NOT_PAD = "do_not_pad"
66
-
67
-
68
- class TensorType(ExplicitEnum):
69
- """
70
- Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
71
- tab-completion in an IDE.
72
- """
73
-
74
- PYTORCH = "pt"
75
- TENSORFLOW = "tf"
76
- NUMPY = "np"
77
- JAX = "jax"
78
-
79
-
80
- class PrecollatorForGeneAndCellClassification(SpecialTokensMixin):
81
- def __init__(self, *args, **kwargs) -> None:
82
- super().__init__(mask_token="<mask>", pad_token="<pad>")
83
-
84
- self.token_dictionary = kwargs.get("token_dictionary")
85
- self.padding_side = "right"
86
- self.model_input_names = ["input_ids"]
87
- self._mask_token_id = self.token_dictionary.get("<mask>")
88
- self._pad_token_id = self.token_dictionary.get("<pad>")
89
- self._all_special_ids = [
90
- self.token_dictionary.get("<mask>"),
91
- self.token_dictionary.get("<pad>"),
92
- ]
93
-
94
- @property
95
- def all_special_ids(self):
96
- return self._all_special_ids
97
-
98
- @property
99
- def mask_token_id(self):
100
- return self._mask_token_id
101
-
102
- @property
103
- def pad_token_id(self):
104
- return self._pad_token_id
105
-
106
- def _get_padding_truncation_strategies(
107
- self,
108
- padding=True,
109
- truncation=False,
110
- max_length=None,
111
- pad_to_multiple_of=None,
112
- verbose=True,
113
- **kwargs,
114
- ):
115
- """
116
- Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
117
- and pad_to_max_length) and behaviors.
118
- """
119
- old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
120
- old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
121
-
122
- # Backward compatibility for previous behavior, maybe we should deprecate it:
123
- # If you only set max_length, it activates truncation for max_length
124
- if max_length is not None and padding is False and truncation is False:
125
- if verbose:
126
- if not self.deprecation_warnings.get(
127
- "Truncation-not-explicitly-activated", False
128
- ):
129
- logger.warning(
130
- "Truncation was not explicitly activated but `max_length` is provided a specific value, "
131
- "please use `truncation=True` to explicitly truncate examples to max length. "
132
- "Defaulting to 'longest_first' truncation strategy. "
133
- "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
134
- "more precisely by providing a specific strategy to `truncation`."
135
- )
136
- self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
137
- truncation = "longest_first"
138
-
139
- # Get padding strategy
140
- if padding is False and old_pad_to_max_length:
141
- if verbose:
142
- warnings.warn(
143
- "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
144
- "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
145
- "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
146
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
147
- "maximal input size of the model (e.g. 512 for Bert).",
148
- FutureWarning,
149
- )
150
- if max_length is None:
151
- padding_strategy = PaddingStrategy.LONGEST
152
- else:
153
- padding_strategy = PaddingStrategy.MAX_LENGTH
154
- elif padding is not False:
155
- if padding is True:
156
- padding_strategy = (
157
- PaddingStrategy.LONGEST
158
- ) # Default to pad to the longest sequence in the batch
159
- elif not isinstance(padding, PaddingStrategy):
160
- padding_strategy = PaddingStrategy(padding)
161
- elif isinstance(padding, PaddingStrategy):
162
- padding_strategy = padding
163
- else:
164
- padding_strategy = PaddingStrategy.DO_NOT_PAD
165
-
166
- # Get truncation strategy
167
- if truncation is False and old_truncation_strategy != "do_not_truncate":
168
- if verbose:
169
- warnings.warn(
170
- "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
171
- "use `truncation=True` to truncate examples to a max length. You can give a specific "
172
- "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
173
- "maximal input size of the model (e.g. 512 for Bert). "
174
- " If you have pairs of inputs, you can give a specific truncation strategy selected among "
175
- "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
176
- "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
177
- "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
178
- FutureWarning,
179
- )
180
- truncation_strategy = TruncationStrategy(old_truncation_strategy)
181
- elif truncation is not False:
182
- if truncation is True:
183
- truncation_strategy = (
184
- TruncationStrategy.LONGEST_FIRST
185
- ) # Default to truncate the longest sequences in pairs of inputs
186
- elif not isinstance(truncation, TruncationStrategy):
187
- truncation_strategy = TruncationStrategy(truncation)
188
- elif isinstance(truncation, TruncationStrategy):
189
- truncation_strategy = truncation
190
- else:
191
- truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
192
-
193
- # Set max length if needed
194
- if max_length is None:
195
- if padding_strategy == PaddingStrategy.MAX_LENGTH:
196
- if self.model_max_length > LARGE_INTEGER:
197
- if verbose:
198
- if not self.deprecation_warnings.get(
199
- "Asking-to-pad-to-max_length", False
200
- ):
201
- logger.warning(
202
- "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
203
- "Default to no padding."
204
- )
205
- self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
206
- padding_strategy = PaddingStrategy.DO_NOT_PAD
207
- else:
208
- max_length = self.model_max_length
209
-
210
- if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
211
- if self.model_max_length > LARGE_INTEGER:
212
- if verbose:
213
- if not self.deprecation_warnings.get(
214
- "Asking-to-truncate-to-max_length", False
215
- ):
216
- logger.warning(
217
- "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
218
- "Default to no truncation."
219
- )
220
- self.deprecation_warnings[
221
- "Asking-to-truncate-to-max_length"
222
- ] = True
223
- truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
224
- else:
225
- max_length = self.model_max_length
226
-
227
- # Test if we have a padding token
228
- if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
229
- not self.pad_token or self.pad_token_id < 0
230
- ):
231
- raise ValueError(
232
- "Asking to pad but the tokenizer does not have a padding token. "
233
- "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
234
- "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
235
- )
236
-
237
- # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
238
- if (
239
- truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
240
- and padding_strategy != PaddingStrategy.DO_NOT_PAD
241
- and pad_to_multiple_of is not None
242
- and max_length is not None
243
- and (max_length % pad_to_multiple_of != 0)
244
- ):
245
- raise ValueError(
246
- f"Truncation and padding are both activated but "
247
- f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
248
- )
249
-
250
- return padding_strategy, truncation_strategy, max_length, kwargs
251
-
252
- def pad(
253
- self,
254
- encoded_inputs: Union[
255
- BatchEncoding,
256
- List[BatchEncoding],
257
- Dict[str, EncodedInput],
258
- Dict[str, List[EncodedInput]],
259
- List[Dict[str, EncodedInput]],
260
- ],
261
- class_type, # options: "gene" or "cell"
262
- padding: Union[bool, str, PaddingStrategy] = True,
263
- max_length: Optional[int] = None,
264
- pad_to_multiple_of: Optional[int] = None,
265
- return_attention_mask: Optional[bool] = True,
266
- return_tensors: Optional[Union[str, TensorType]] = None,
267
- verbose: bool = True,
268
- ) -> BatchEncoding:
269
- """
270
- Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
271
- in the batch.
272
- Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
273
- ``self.pad_token_id`` and ``self.pad_token_type_id``)
274
- .. note::
275
- If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
276
- result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
277
- case of PyTorch tensors, you will lose the specific device of your tensors however.
278
- Args:
279
- encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
280
- Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
281
- List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
282
- List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
283
- well as in a PyTorch Dataloader collate function.
284
- Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
285
- see the note above for the return type.
286
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
287
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
288
- index) among:
289
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
290
- single sequence if provided).
291
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
292
- maximum acceptable input length for the model if that argument is not provided.
293
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
294
- different lengths).
295
- max_length (:obj:`int`, `optional`):
296
- Maximum length of the returned list and optionally padding length (see above).
297
- pad_to_multiple_of (:obj:`int`, `optional`):
298
- If set will pad the sequence to a multiple of the provided value.
299
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
300
- >= 7.5 (Volta).
301
- return_attention_mask (:obj:`bool`, `optional`):
302
- Whether to return the attention mask. If left to the default, will return the attention mask according
303
- to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
304
- `What are attention masks? <../glossary.html#attention-mask>`__
305
- return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
306
- If set, will return tensors instead of list of python integers. Acceptable values are:
307
- * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
308
- * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
309
- * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
310
- verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
311
- Whether or not to print more information and warnings.
312
- """
313
- # If we have a list of dicts, let's convert it in a dict of lists
314
- # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
315
- if isinstance(encoded_inputs, (list, tuple)) and isinstance(
316
- encoded_inputs[0], (dict, BatchEncoding)
317
- ):
318
- encoded_inputs = {
319
- key: [example[key] for example in encoded_inputs]
320
- for key in encoded_inputs[0].keys()
321
- }
322
-
323
- # The model's main input name, usually `input_ids`, has be passed for padding
324
- if self.model_input_names[0] not in encoded_inputs:
325
- raise ValueError(
326
- "You should supply an encoding or a list of encodings to this method"
327
- f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
328
- )
329
-
330
- required_input = encoded_inputs[self.model_input_names[0]]
331
-
332
- if not required_input:
333
- if return_attention_mask:
334
- encoded_inputs["attention_mask"] = []
335
- return encoded_inputs
336
-
337
- # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
338
- # and rebuild them afterwards if no return_tensors is specified
339
- # Note that we lose the specific device the tensor may be on for PyTorch
340
-
341
- first_element = required_input[0]
342
- if isinstance(first_element, (list, tuple)):
343
- # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
344
- index = 0
345
- while len(required_input[index]) == 0:
346
- index += 1
347
- if index < len(required_input):
348
- first_element = required_input[index][0]
349
- # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
350
- if not isinstance(first_element, (int, list, tuple)):
351
- if is_tf_available() and _is_tensorflow(first_element):
352
- return_tensors = "tf" if return_tensors is None else return_tensors
353
- elif is_torch_available() and _is_torch(first_element):
354
- return_tensors = "pt" if return_tensors is None else return_tensors
355
- elif isinstance(first_element, np.ndarray):
356
- return_tensors = "np" if return_tensors is None else return_tensors
357
- else:
358
- raise ValueError(
359
- f"type of {first_element} unknown: {type(first_element)}. "
360
- f"Should be one of a python, numpy, pytorch or tensorflow object."
361
- )
362
-
363
- for key, value in encoded_inputs.items():
364
- encoded_inputs[key] = to_py_obj(value)
365
-
366
- # Convert padding_strategy in PaddingStrategy
367
- padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
368
- padding=padding, max_length=max_length, verbose=verbose
369
- )
370
-
371
- required_input = encoded_inputs[self.model_input_names[0]]
372
- if required_input and not isinstance(required_input[0], (list, tuple)):
373
- encoded_inputs = self._pad(
374
- encoded_inputs,
375
- class_type=class_type,
376
- max_length=max_length,
377
- padding_strategy=padding_strategy,
378
- pad_to_multiple_of=pad_to_multiple_of,
379
- return_attention_mask=return_attention_mask,
380
- )
381
- return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
382
-
383
- batch_size = len(required_input)
384
- assert all(
385
- len(v) == batch_size for v in encoded_inputs.values()
386
- ), "Some items in the output dictionary have a different batch size than others."
387
-
388
- if padding_strategy == PaddingStrategy.LONGEST:
389
- max_length = max(len(inputs) for inputs in required_input)
390
- padding_strategy = PaddingStrategy.MAX_LENGTH
391
-
392
- batch_outputs = {}
393
- for i in range(batch_size):
394
- inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
395
- outputs = self._pad(
396
- inputs,
397
- class_type=class_type,
398
- max_length=max_length,
399
- padding_strategy=padding_strategy,
400
- pad_to_multiple_of=pad_to_multiple_of,
401
- return_attention_mask=return_attention_mask,
402
- )
403
-
404
- for key, value in outputs.items():
405
- if key not in batch_outputs:
406
- batch_outputs[key] = []
407
- batch_outputs[key].append(value)
408
- if class_type == "cell":
409
- del batch_outputs["label"]
410
- return BatchEncoding(batch_outputs, tensor_type=return_tensors)
411
-
412
- def _pad(
413
- self,
414
- encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
415
- class_type, # options: "gene" or "cell"
416
- max_length: Optional[int] = None,
417
- padding_strategy: PaddingStrategy = PaddingStrategy.LONGEST,
418
- pad_to_multiple_of: Optional[int] = None,
419
- return_attention_mask: Optional[bool] = True,
420
- ) -> dict:
421
- """
422
- Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
423
- Args:
424
- encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
425
- max_length: maximum length of the returned list and optionally padding length (see below).
426
- Will truncate by taking into account the special tokens.
427
- padding_strategy: PaddingStrategy to use for padding.
428
- - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
429
- - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
430
- - PaddingStrategy.DO_NOT_PAD: Do not pad
431
- The tokenizer padding sides are defined in self.padding_side:
432
- - 'left': pads on the left of the sequences
433
- - 'right': pads on the right of the sequences
434
- pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
435
- This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
436
- >= 7.5 (Volta).
437
- return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
438
- """
439
- # Load from model defaults
440
- if return_attention_mask is None:
441
- return_attention_mask = "attention_mask" in self.model_input_names
442
-
443
- required_input = encoded_inputs[self.model_input_names[0]]
444
-
445
- if padding_strategy == PaddingStrategy.LONGEST:
446
- max_length = len(required_input)
447
-
448
- if (
449
- max_length is not None
450
- and pad_to_multiple_of is not None
451
- and (max_length % pad_to_multiple_of != 0)
452
- ):
453
- max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
454
-
455
- needs_to_be_padded = (
456
- padding_strategy != PaddingStrategy.DO_NOT_PAD
457
- and len(required_input) != max_length
458
- )
459
-
460
- if needs_to_be_padded:
461
- difference = max_length - len(required_input)
462
- if self.padding_side == "right":
463
- if return_attention_mask:
464
- encoded_inputs["attention_mask"] = [1] * len(required_input) + [
465
- 0
466
- ] * difference
467
- if "token_type_ids" in encoded_inputs:
468
- encoded_inputs["token_type_ids"] = (
469
- encoded_inputs["token_type_ids"]
470
- + [self.pad_token_type_id] * difference
471
- )
472
- if "special_tokens_mask" in encoded_inputs:
473
- encoded_inputs["special_tokens_mask"] = (
474
- encoded_inputs["special_tokens_mask"] + [1] * difference
475
- )
476
- encoded_inputs[self.model_input_names[0]] = (
477
- required_input + [self.pad_token_id] * difference
478
- )
479
- if class_type == "gene":
480
- encoded_inputs["labels"] = (
481
- encoded_inputs["labels"] + [-100] * difference
482
- )
483
- elif self.padding_side == "left":
484
- if return_attention_mask:
485
- encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
486
- required_input
487
- )
488
- if "token_type_ids" in encoded_inputs:
489
- encoded_inputs["token_type_ids"] = [
490
- self.pad_token_type_id
491
- ] * difference + encoded_inputs["token_type_ids"]
492
- if "special_tokens_mask" in encoded_inputs:
493
- encoded_inputs["special_tokens_mask"] = [
494
- 1
495
- ] * difference + encoded_inputs["special_tokens_mask"]
496
- encoded_inputs[self.model_input_names[0]] = [
497
- self.pad_token_id
498
- ] * difference + required_input
499
- if class_type == "gene":
500
- encoded_inputs["labels"] = [-100] * difference + encoded_inputs[
501
- "labels"
502
- ]
503
- else:
504
- raise ValueError("Invalid padding strategy:" + str(self.padding_side))
505
- elif return_attention_mask and "attention_mask" not in encoded_inputs:
506
- encoded_inputs["attention_mask"] = [1] * len(required_input)
507
-
508
- return encoded_inputs
509
-
510
- def get_special_tokens_mask(
511
- self,
512
- token_ids_0: List[int],
513
- token_ids_1: Optional[List[int]] = None,
514
- already_has_special_tokens: bool = False,
515
- ) -> List[int]:
516
- """
517
- Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
518
- special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
519
- Args:
520
- token_ids_0 (:obj:`List[int]`):
521
- List of ids of the first sequence.
522
- token_ids_1 (:obj:`List[int]`, `optional`):
523
- List of ids of the second sequence.
524
- already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
525
- Whether or not the token list is already formatted with special tokens for the model.
526
- Returns:
527
- A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
528
- """
529
- assert already_has_special_tokens and token_ids_1 is None, (
530
- "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
531
- "Please use a slow (full python) tokenizer to activate this argument."
532
- "Or set `return_special_tokens_mask=True` when calling the encoding method "
533
- "to get the special tokens mask in any tokenizer. "
534
- )
535
-
536
- all_special_ids = self.all_special_ids # cache the property
537
-
538
- special_tokens_mask = [
539
- 1 if token in all_special_ids else 0 for token in token_ids_0
540
- ]
541
-
542
- return special_tokens_mask
543
-
544
- def convert_tokens_to_ids(
545
- self, tokens: Union[str, List[str]]
546
- ) -> Union[int, List[int]]:
547
- """
548
- Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
549
- vocabulary.
550
- Args:
551
- tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
552
- Returns:
553
- :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
554
- """
555
- if tokens is None:
556
- return None
557
-
558
- if isinstance(tokens, str):
559
- return self._convert_token_to_id_with_added_voc(tokens)
560
-
561
- ids = []
562
- for token in tokens:
563
- ids.append(self._convert_token_to_id_with_added_voc(token))
564
- return ids
565
-
566
- def _convert_token_to_id_with_added_voc(self, token):
567
- if token is None:
568
- return None
569
-
570
- return self.token_dictionary.get(token)
571
-
572
- def __len__(self):
573
- return len(self.token_dictionary)
574
-
575
-
576
- # collator functions
577
-
578
-
579
- class DataCollatorForGeneClassification(DataCollatorForTokenClassification):
580
- """
581
- Data collator that will dynamically pad the inputs received, as well as the labels.
582
- Args:
583
- tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
584
- The tokenizer used for encoding the data.
585
- padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
586
- Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
587
- among:
588
- * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
589
- sequence if provided).
590
- * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
591
- maximum acceptable input length for the model if that argument is not provided.
592
- * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
593
- different lengths).
594
- max_length (:obj:`int`, `optional`):
595
- Maximum length of the returned list and optionally padding length (see above).
596
- pad_to_multiple_of (:obj:`int`, `optional`):
597
- If set will pad the sequence to a multiple of the provided value.
598
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
599
- 7.5 (Volta).
600
- label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
601
- The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
602
- """
-
-     class_type = "gene"
-     padding: Union[bool, str, PaddingStrategy] = True
-     max_length: Optional[int] = None
-     pad_to_multiple_of: Optional[int] = None
-     label_pad_token_id: int = -100
-
-     def __init__(self, *args, **kwargs) -> None:
-         self.token_dictionary = kwargs.pop("token_dictionary")
-         super().__init__(
-             tokenizer=PrecollatorForGeneAndCellClassification(
-                 token_dictionary=self.token_dictionary
-             ),
-             padding=self.padding,
-             max_length=self.max_length,
-             pad_to_multiple_of=self.pad_to_multiple_of,
-             label_pad_token_id=self.label_pad_token_id,
-             *args,
-             **kwargs,
-         )
-
-     def _prepare_batch(self, features):
-         label_name = "label" if "label" in features[0].keys() else "labels"
-         labels = (
-             [feature[label_name] for feature in features]
-             if label_name in features[0].keys()
-             else None
-         )
-         batch = self.tokenizer.pad(
-             features,
-             class_type=self.class_type,
-             padding=self.padding,
-             max_length=self.max_length,
-             pad_to_multiple_of=self.pad_to_multiple_of,
-             return_tensors="pt",
-         )
-         return batch
-
-     def __call__(self, features):
-         batch = self._prepare_batch(features)
-
-         batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
-         return batch
-
-
- class DataCollatorForCellClassification(DataCollatorForGeneClassification):
-     class_type = "cell"
-
-     def _prepare_batch(self, features):
-         batch = super()._prepare_batch(features)
-
-         # Special handling for labels.
-         # Ensure that tensor is created with the correct type
-         # (it should be automatically the case, but let's make sure of it.)
-         first = features[0]
-         if "label" in first and first["label"] is not None:
-             label = (
-                 first["label"].item()
-                 if isinstance(first["label"], torch.Tensor)
-                 else first["label"]
-             )
-             dtype = torch.long if isinstance(label, int) else torch.float
-             batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
-
-         return batch
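
A minimal usage sketch of the collator deleted above (the import path and the token_dictionary.pkl file below are placeholders, not part of this commit):

    import pickle
    from geneformer.collator_for_classification import DataCollatorForCellClassification

    with open("token_dictionary.pkl", "rb") as f:  # placeholder path
        token_dictionary = pickle.load(f)

    collator = DataCollatorForCellClassification(token_dictionary=token_dictionary)
    features = [
        {"input_ids": [2, 15, 9, 4], "label": 0},  # ragged examples, as in the tokenized .dataset
        {"input_ids": [2, 7], "label": 1},
    ]
    batch = collator(features)  # pads to the longest example and returns int64 tensors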
geneformer/emb_extractor.py DELETED
@@ -1,863 +0,0 @@
- """
- Geneformer embedding extractor.
-
- **Description:**
-
- | Extracts gene or cell embeddings.
- | Plots cell embeddings as heatmaps or UMAPs.
- | Generates cell state embedding dictionary for use with InSilicoPerturber.
-
- """
-
- # imports
- import logging
- import pickle
- from collections import Counter
- from pathlib import Path
-
- import anndata
- import matplotlib.pyplot as plt
- import pandas as pd
- import scanpy as sc
- import seaborn as sns
- import torch
- from tdigest import TDigest
- from tqdm.auto import trange
-
- from . import TOKEN_DICTIONARY_FILE
- from . import perturber_utils as pu
-
- logger = logging.getLogger(__name__)
-
-
- # extract embeddings
- def get_embs(
-     model,
-     filtered_input_data,
-     emb_mode,
-     layer_to_quant,
-     pad_token_id,
-     forward_batch_size,
-     token_gene_dict,
-     special_token=False,
-     summary_stat=None,
-     silent=False,
- ):
-     model_input_size = pu.get_model_input_size(model)
-     total_batch_length = len(filtered_input_data)
-
-     if summary_stat is None:
-         embs_list = []
-     elif summary_stat is not None:
-         # get # of emb dims
-         emb_dims = pu.get_model_emb_dims(model)
-         if emb_mode == "cell":
-             # initiate tdigests for # of emb dims
-             embs_tdigests = [TDigest() for _ in range(emb_dims)]
-         if emb_mode == "gene":
-             gene_set = list(
-                 {
-                     element
-                     for sublist in filtered_input_data["input_ids"]
-                     for element in sublist
-                 }
-             )
-             # initiate dict with genes as keys and tdigests for # of emb dims as values
-             embs_tdigests_dict = {
-                 k: [TDigest() for _ in range(emb_dims)] for k in gene_set
-             }
-
-     # Check if CLS and EOS tokens are present in the token dictionary
-     cls_present = any("<cls>" in value for value in token_gene_dict.values())
-     eos_present = any("<eos>" in value for value in token_gene_dict.values())
-     if emb_mode == "cls":
-         assert cls_present, "<cls> token missing in token dictionary"
-         # Check to make sure that the first token of the filtered input data is cls token
-         gene_token_dict = {v: k for k, v in token_gene_dict.items()}
-         cls_token_id = gene_token_dict["<cls>"]
-         assert (
-             filtered_input_data["input_ids"][0][0] == cls_token_id
-         ), "First token is not <cls> token value"
-     elif emb_mode == "cell":
-         if cls_present:
-             logger.warning(
-                 "CLS token present in token dictionary, excluding from average."
-             )
-         if eos_present:
-             logger.warning(
-                 "EOS token present in token dictionary, excluding from average."
-             )
-
-     overall_max_len = 0
-
-     for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
-         max_range = min(i + forward_batch_size, total_batch_length)
-
-         minibatch = filtered_input_data.select([i for i in range(i, max_range)])
-
-         max_len = int(max(minibatch["length"]))
-         original_lens = torch.tensor(minibatch["length"], device="cuda")
-         minibatch.set_format(type="torch")
-
-         input_data_minibatch = minibatch["input_ids"]
-         input_data_minibatch = pu.pad_tensor_list(
-             input_data_minibatch, max_len, pad_token_id, model_input_size
-         )
-
-         with torch.no_grad():
-             outputs = model(
-                 input_ids=input_data_minibatch.to("cuda"),
-                 attention_mask=pu.gen_attention_mask(minibatch),
-             )
-
-         embs_i = outputs.hidden_states[layer_to_quant]
-
-         if emb_mode == "cell":
-             if cls_present:
-                 non_cls_embs = embs_i[:, 1:, :]  # exclude the CLS token embedding at position 0
-                 if eos_present:
-                     mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
-                 else:
-                     mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
-             else:
-                 mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
-             if summary_stat is None:
-                 embs_list.append(mean_embs)
-             elif summary_stat is not None:
-                 # update tdigests with current batch for each emb dim
-                 accumulate_tdigests(embs_tdigests, mean_embs, emb_dims)
-                 del mean_embs
-         elif emb_mode == "gene":
-             if summary_stat is None:
-                 embs_list.append(embs_i)
-             elif summary_stat is not None:
-                 for h in trange(len(minibatch)):
-                     length_h = minibatch[h]["length"]
-                     input_ids_h = minibatch[h]["input_ids"][0:length_h]
-
-                     # double check dimensions before unsqueezing
-                     embs_i_dim = embs_i.dim()
-                     if embs_i_dim != 3:
-                         logger.error(
-                             f"Embedding tensor should have 3 dimensions, not {embs_i_dim}"
-                         )
-                         raise
-
-                     embs_h = embs_i[h, :, :].unsqueeze(dim=1)
-                     dict_h = dict(zip(input_ids_h, embs_h))
-                     for k in dict_h.keys():
-                         accumulate_tdigests(
-                             embs_tdigests_dict[int(k)], dict_h[k], emb_dims
-                         )
-                     del embs_h
-                     del dict_h
-         elif emb_mode == "cls":
-             cls_embs = embs_i[:, 0, :].clone().detach()  # CLS token layer
-             embs_list.append(cls_embs)
-             del cls_embs
-
-         overall_max_len = max(overall_max_len, max_len)
-         del outputs
-         del minibatch
-         del input_data_minibatch
-         del embs_i
-
-         torch.cuda.empty_cache()
-
-     if summary_stat is None:
-         if (emb_mode == "cell") or (emb_mode == "cls"):
-             embs_stack = torch.cat(embs_list, dim=0)
-         elif emb_mode == "gene":
-             embs_stack = pu.pad_tensor_list(
-                 embs_list,
-                 overall_max_len,
-                 pad_token_id,
-                 model_input_size,
-                 1,
-                 pu.pad_3d_tensor,
-             )
-
-     # calculate summary stat embs from approximated tdigests
-     elif summary_stat is not None:
-         if emb_mode == "cell":
-             if summary_stat == "mean":
-                 summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
-             elif summary_stat == "median":
-                 summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
-             embs_stack = torch.tensor(summary_emb_list)
-         elif emb_mode == "gene":
-             if summary_stat == "mean":
-                 [
-                     update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
-                     for gene in embs_tdigests_dict.keys()
-                 ]
-             elif summary_stat == "median":
-                 [
-                     update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims)
-                     for gene in embs_tdigests_dict.keys()
-                 ]
-             return embs_tdigests_dict
-
-     return embs_stack
-
-
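The "cell" branch above mean-pools each cell's gene embeddings over only the non-padded positions via pu.mean_nonpadding_embs, which lives in perturber_utils and is not shown in this diff. A rough sketch of that masked mean, for orientation only (not the repository's implementation):

    import torch

    def mean_nonpadding_embs(embs: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
        # embs: (batch, seq_len, emb_dim); lens: (batch,) true lengths before padding
        mask = torch.arange(embs.size(1), device=embs.device)[None, :] < lens[:, None]
        summed = (embs * mask.unsqueeze(-1)).sum(dim=1)
        return summed / lens.unsqueeze(-1)

    embs = torch.randn(2, 5, 3)
    lens = torch.tensor([5, 3])
    print(mean_nonpadding_embs(embs, lens).shape)  # torch.Size([2, 3])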
- def accumulate_tdigests(embs_tdigests, mean_embs, emb_dims):
-     # note: tdigest batch update known to be slow so updating serially
-     [
-         embs_tdigests[j].update(mean_embs[i, j].item())
-         for i in range(mean_embs.size(0))
-         for j in range(emb_dims)
-     ]
-
-
- def update_tdigest_dict(embs_tdigests_dict, gene, gene_embs, emb_dims):
-     embs_tdigests_dict[gene] = accumulate_tdigests(
-         embs_tdigests_dict[gene], gene_embs, emb_dims
-     )
-
-
- def update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims):
-     embs_tdigests_dict[gene] = tdigest_mean(embs_tdigests_dict[gene], emb_dims)
-
-
- def update_tdigest_dict_median(embs_tdigests_dict, gene, emb_dims):
-     embs_tdigests_dict[gene] = tdigest_median(embs_tdigests_dict[gene], emb_dims)
-
-
- def summarize_gene_embs(h, minibatch, embs_i, embs_tdigests_dict, emb_dims):
-     length_h = minibatch[h]["length"]
-     input_ids_h = minibatch[h]["input_ids"][0:length_h]
-     embs_h = embs_i[h, :, :].unsqueeze(dim=1)
-     dict_h = dict(zip(input_ids_h, embs_h))
-     [
-         update_tdigest_dict(embs_tdigests_dict, k, dict_h[k], emb_dims)
-         for k in dict_h.keys()
-     ]
-
-
- def tdigest_mean(embs_tdigests, emb_dims):
-     return [embs_tdigests[i].trimmed_mean(0, 100) for i in range(emb_dims)]
-
-
- def tdigest_median(embs_tdigests, emb_dims):
-     return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
-
-
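The helpers above keep one t-digest per embedding dimension so approximate means and medians can be accumulated in a streaming fashion without holding every cell embedding in memory. A self-contained sketch of the same pattern on synthetic data (assumes the tdigest package this module imports):

    import torch
    from tdigest import TDigest

    emb_dims = 4
    digests = [TDigest() for _ in range(emb_dims)]

    # stream minibatches of (n_cells x emb_dims) embeddings, updating serially
    for batch in [torch.randn(8, emb_dims) for _ in range(3)]:
        for row in batch:
            for j in range(emb_dims):
                digests[j].update(row[j].item())

    approx_median = [d.percentile(50) for d in digests]
    approx_mean = [d.trimmed_mean(0, 100) for d in digests]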
- def label_cell_embs(embs, downsampled_data, emb_labels):
-     embs_df = pd.DataFrame(embs.cpu().numpy())
-     if emb_labels is not None:
-         for label in emb_labels:
-             emb_label = downsampled_data[label]
-             embs_df[label] = emb_label
-     return embs_df
-
-
- def label_gene_embs(embs, downsampled_data, token_gene_dict):
-     gene_set = {
-         element for sublist in downsampled_data["input_ids"] for element in sublist
-     }
-     gene_emb_dict = {k: [] for k in gene_set}
-     for i in range(embs.size()[0]):
-         length = downsampled_data[i]["length"]
-         dict_i = dict(
-             zip(
-                 downsampled_data[i]["input_ids"][0:length],
-                 embs[i, :, :].unsqueeze(dim=1),
-             )
-         )
-         for k in dict_i.keys():
-             gene_emb_dict[k].append(dict_i[k])
-     for k in gene_emb_dict.keys():
-         gene_emb_dict[k] = (
-             torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
-             .cpu()
-             .numpy()
-         )
-     embs_df = pd.DataFrame(gene_emb_dict).T
-     embs_df.index = [token_gene_dict[token] for token in embs_df.index]
-     return embs_df
-
-
- def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0):
-     only_embs_df = embs_df.iloc[:, :emb_dims]
-     only_embs_df.index = pd.RangeIndex(0, only_embs_df.shape[0], name=None).astype(str)
-     only_embs_df.columns = pd.RangeIndex(0, only_embs_df.shape[1], name=None).astype(
-         str
-     )
-     vars_dict = {"embs": only_embs_df.columns}
-     obs_dict = {"cell_id": list(only_embs_df.index), f"{label}": list(embs_df[label])}
-     adata = anndata.AnnData(X=only_embs_df, obs=obs_dict, var=vars_dict)
-     sc.tl.pca(adata, svd_solver="arpack")
-     sc.pp.neighbors(adata, random_state=seed)
-     sc.tl.umap(adata, random_state=seed)
-     sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3)
-     sns.set_style("white")
-     default_kwargs_dict = {"size": 200}
-     if kwargs_dict is not None:
-         default_kwargs_dict.update(kwargs_dict)
-
-     cats = set(embs_df[label])
-
-     with plt.rc_context():
-         ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict)
-         ax.legend(
-             markerscale=2,
-             frameon=False,
-             loc="center left",
-             bbox_to_anchor=(1, 0.5),
-             ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3),
-         )
-         plt.show()
-         plt.savefig(output_file, bbox_inches="tight")
-
-
- def gen_heatmap_class_colors(labels, df):
-     pal = sns.cubehelix_palette(
-         len(Counter(labels).keys()),
-         light=0.9,
-         dark=0.1,
-         hue=1,
-         reverse=True,
-         start=1,
-         rot=-2,
-     )
-     lut = dict(zip(map(str, Counter(labels).keys()), pal))
-     colors = pd.Series(labels, index=df.index).map(lut)
-     return colors
-
-
- def gen_heatmap_class_dict(classes, label_colors_series):
-     class_color_dict_df = pd.DataFrame(
-         {"classes": classes, "color": label_colors_series}
-     )
-     class_color_dict_df = class_color_dict_df.drop_duplicates(subset=["classes"])
-     return dict(zip(class_color_dict_df["classes"], class_color_dict_df["color"]))
-
-
- def make_colorbar(embs_df, label):
-     labels = list(embs_df[label])
-
-     cell_type_colors = gen_heatmap_class_colors(labels, embs_df)
-     label_colors = pd.DataFrame(cell_type_colors, columns=[label])
-
-     # create dictionary for colors and classes
-     label_color_dict = gen_heatmap_class_dict(labels, label_colors[label])
-     return label_colors, label_color_dict
-
-
- def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
-     sns.set_style("white")
-     sns.set(font_scale=2)
-     plt.figure(figsize=(15, 15), dpi=150)
-     label_colors, label_color_dict = make_colorbar(embs_df, label)
-
-     default_kwargs_dict = {
-         "row_cluster": True,
-         "col_cluster": True,
-         "row_colors": label_colors,
-         "standard_scale": 1,
-         "linewidths": 0,
-         "xticklabels": False,
-         "yticklabels": False,
-         "figsize": (15, 15),
-         "center": 0,
-         "cmap": "magma",
-     }
-
-     if kwargs_dict is not None:
-         default_kwargs_dict.update(kwargs_dict)
-     g = sns.clustermap(
-         embs_df.iloc[:, 0:emb_dims].apply(pd.to_numeric), **default_kwargs_dict
-     )
-
-     plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
-
-     for label_color in list(label_color_dict.keys()):
-         g.ax_col_dendrogram.bar(
-             0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
-         )
-
-     g.ax_col_dendrogram.legend(
-         title=f"{label}",
-         loc="lower center",
-         ncol=4,
-         bbox_to_anchor=(0.5, 1),
-         facecolor="white",
-     )
-     plt.show()
-     logger.info(f"Output file: {output_file}")
-     plt.savefig(output_file, bbox_inches="tight")
-
-
- class EmbExtractor:
-     valid_option_dict = {
-         "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
-         "num_classes": {int},
-         "emb_mode": {"cls", "cell", "gene"},
-         "cell_emb_style": {"mean_pool"},
-         "gene_emb_style": {"mean_pool"},
-         "filter_data": {None, dict},
-         "max_ncells": {None, int},
-         "emb_layer": {-1, 0},
-         "emb_label": {None, list},
-         "labels_to_plot": {None, list},
-         "forward_batch_size": {int},
-         "token_dictionary_file": {None, str},
-         "nproc": {int},
-         "summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
-     }
-
-     def __init__(
-         self,
-         model_type="Pretrained",
-         num_classes=0,
-         emb_mode="cls",
-         cell_emb_style="mean_pool",
-         gene_emb_style="mean_pool",
-         filter_data=None,
-         max_ncells=1000,
-         emb_layer=-1,
-         emb_label=None,
-         labels_to_plot=None,
-         forward_batch_size=100,
-         nproc=4,
-         summary_stat=None,
-         token_dictionary_file=None,
-     ):
-         """
-         Initialize embedding extractor.
-
-         **Parameters:**
-
-         model_type : {"Pretrained", "GeneClassifier", "CellClassifier"}
-             | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier.
-         num_classes : int
-             | If model is a gene or cell classifier, specify number of classes it was trained to classify.
-             | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
-         emb_mode : {"cls", "cell", "gene"}
-             | Whether to output CLS, cell, or gene embeddings.
-             | CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
-         cell_emb_style : {"mean_pool"}
-             | Method for summarizing cell embeddings if not using CLS token.
-             | Currently only option is mean pooling of gene embeddings for given cell.
-         gene_emb_style : "mean_pool"
-             | Method for summarizing gene embeddings.
-             | Currently only option is mean pooling of contextual gene embeddings for given gene.
-         filter_data : None, dict
-             | Default is to extract embeddings from all input data.
-             | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
-         max_ncells : None, int
-             | Maximum number of cells to extract embeddings from.
-             | Default is 1000 cells randomly sampled from input data.
-             | If None, will extract embeddings from all cells.
-         emb_layer : {-1, 0}
-             | Embedding layer to extract.
-             | The last layer is most specifically weighted to optimize the given learning objective.
-             | Generally, it is best to extract the 2nd to last layer to get a more general representation.
-             | -1: 2nd to last layer
-             | 0: last layer
-         emb_label : None, list
-             | List of column name(s) in .dataset to add as labels to embedding output.
-         labels_to_plot : None, list
-             | Cell labels to plot.
-             | Shown as color bar in heatmap.
-             | Shown as cell color in umap.
-             | Plotting umap requires labels to plot.
-         forward_batch_size : int
-             | Batch size for forward pass.
-         nproc : int
-             | Number of CPU processes to use.
-         summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
-             | If exact_mean or exact_median, outputs only exact mean or median embedding of input data.
-             | If mean or median, outputs only approximated mean or median embedding of input data.
-             | Non-exact recommended if encountering memory constraints while generating goal embedding positions.
-             | Non-exact is slower but more memory-efficient.
-         token_dictionary_file : Path
-             | Default is the Geneformer token dictionary
-             | Path to pickle file containing token dictionary (Ensembl ID:token).
-
-         **Examples:**
-
-         .. code-block :: python
-
-             >>> from geneformer import EmbExtractor
-             >>> embex = EmbExtractor(model_type="CellClassifier",
-             ...                      num_classes=3,
-             ...                      emb_mode="cell",
-             ...                      filter_data={"cell_type":["cardiomyocyte"]},
-             ...                      max_ncells=1000,
-             ...                      emb_layer=-1,
-             ...                      emb_label=["disease", "cell_type"],
-             ...                      labels_to_plot=["disease", "cell_type"])
-
-         """
-
-         self.model_type = model_type
-         self.num_classes = num_classes
-         self.emb_mode = emb_mode
-         self.cell_emb_style = cell_emb_style
-         self.gene_emb_style = gene_emb_style
-         self.filter_data = filter_data
-         self.max_ncells = max_ncells
-         self.emb_layer = emb_layer
-         self.emb_label = emb_label
-         self.labels_to_plot = labels_to_plot
-         self.token_dictionary_file = token_dictionary_file
-         self.forward_batch_size = forward_batch_size
-         self.nproc = nproc
-         if (summary_stat is not None) and ("exact" in summary_stat):
-             self.summary_stat = None
-             self.exact_summary_stat = summary_stat
-         else:
-             self.summary_stat = summary_stat
-             self.exact_summary_stat = None
-
-         self.validate_options()
-
-         # load token dictionary (Ensembl IDs:token)
-         if self.token_dictionary_file is None:
-             token_dictionary_file = TOKEN_DICTIONARY_FILE
-         with open(token_dictionary_file, "rb") as f:
-             self.gene_token_dict = pickle.load(f)
-
-         self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
-         self.pad_token_id = self.gene_token_dict.get("<pad>")
-
-     def validate_options(self):
-         # confirm arguments are within valid options and compatible with each other
-         for attr_name, valid_options in self.valid_option_dict.items():
-             attr_value = self.__dict__[attr_name]
-             if not isinstance(attr_value, (list, dict)):
-                 if attr_value in valid_options:
-                     continue
-             valid_type = False
-             for option in valid_options:
-                 if (option in [int, list, dict, bool, str]) and isinstance(
-                     attr_value, option
-                 ):
-                     valid_type = True
-                     break
-             if valid_type:
-                 continue
-             logger.error(
-                 f"Invalid option for {attr_name}. "
-                 f"Valid options for {attr_name}: {valid_options}"
-             )
-             raise
-
-         if self.filter_data is not None:
-             for key, value in self.filter_data.items():
-                 if not isinstance(value, list):
-                     self.filter_data[key] = [value]
-                     logger.warning(
-                         "Values in filter_data dict must be lists. "
-                         f"Changing {key} value to list ([{value}])."
-                     )
-
-     def extract_embs(
-         self,
-         model_directory,
-         input_data_file,
-         output_directory,
-         output_prefix,
-         output_torch_embs=False,
-         cell_state=None,
-     ):
-         """
-         Extract embeddings from input data and save as results in output_directory.
-
-         **Parameters:**
-
-         model_directory : Path
-             | Path to directory containing model
-         input_data_file : Path
-             | Path to directory containing .dataset inputs
-         output_directory : Path
-             | Path to directory where embedding data will be saved as csv
-         output_prefix : str
-             | Prefix for output file
-         output_torch_embs : bool
-             | Whether or not to also output the embeddings as a tensor.
-             | Note, if true, will output embeddings as both dataframe and tensor.
-         cell_state : dict
-             | Cell state key and value for state embedding extraction.
-
-         **Examples:**
-
-         .. code-block :: python
-
-             >>> embs = embex.extract_embs("path/to/model",
-             ...                           "path/to/input_data",
-             ...                           "path/to/output_directory",
-             ...                           "output_prefix")
-
-         """
-
-         filtered_input_data = pu.load_and_filter(
-             self.filter_data, self.nproc, input_data_file
-         )
-
-         # Check to make sure that all the labels exist in the tokenized data:
-         if self.emb_label is not None:
-             for label in self.emb_label:
-                 assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
-
-         if cell_state is not None:
-             filtered_input_data = pu.filter_by_dict(
-                 filtered_input_data, cell_state, self.nproc
-             )
-         downsampled_data = pu.downsample_and_sort(filtered_input_data, self.max_ncells)
-         model = pu.load_model(
-             self.model_type, self.num_classes, model_directory, mode="eval"
-         )
-         layer_to_quant = pu.quant_layers(model) + self.emb_layer
-         embs = get_embs(
-             model=model,
-             filtered_input_data=downsampled_data,
-             emb_mode=self.emb_mode,
-             layer_to_quant=layer_to_quant,
-             pad_token_id=self.pad_token_id,
-             forward_batch_size=self.forward_batch_size,
-             token_gene_dict=self.token_gene_dict,
-             summary_stat=self.summary_stat,
-         )
-
-         if self.emb_mode == "cell":
-             if self.summary_stat is None:
-                 embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
-             elif self.summary_stat is not None:
-                 embs_df = pd.DataFrame(embs.cpu().numpy()).T
-         elif self.emb_mode == "gene":
-             if self.summary_stat is None:
-                 embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
-             elif self.summary_stat is not None:
-                 embs_df = pd.DataFrame(embs).T
-                 embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
-         elif self.emb_mode == "cls":
-             embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
-
-         # save embeddings to output_path
-         if cell_state is None:
-             output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
-             embs_df.to_csv(output_path)
-
-         if self.exact_summary_stat == "exact_mean":
-             embs = embs.mean(dim=0)
-             emb_dims = pu.get_model_emb_dims(model)
-             embs_df = pd.DataFrame(
-                 embs_df[0 : emb_dims - 1].mean(axis="rows"),
-                 columns=[self.exact_summary_stat],
-             ).T
-         elif self.exact_summary_stat == "exact_median":
-             embs = torch.median(embs, dim=0)[0]
-             emb_dims = pu.get_model_emb_dims(model)
-             embs_df = pd.DataFrame(
-                 embs_df[0 : emb_dims - 1].median(axis="rows"),
-                 columns=[self.exact_summary_stat],
-             ).T
-
-         if cell_state is not None:
-             return embs
-         else:
-             if output_torch_embs:
-                 return embs_df, embs
-             else:
-                 return embs_df
-
-     def get_state_embs(
-         self,
-         cell_states_to_model,
-         model_directory,
-         input_data_file,
-         output_directory,
-         output_prefix,
-         output_torch_embs=True,
-     ):
-         """
-         Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
-
-         **Parameters:**
-
-         cell_states_to_model : None, dict
-             | Cell states to model if testing perturbations that achieve goal state change.
-             | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
-             | state_key: key specifying name of column in .dataset that defines the start/goal states
-             | start_state: value in the state_key column that specifies the start state
- | start_state: value in the state_key column that specifies the start state
686
- | goal_state: value in the state_key column taht specifies the goal end state
687
- | alt_states: list of values in the state_key column that specify the alternate end states
688
- | For example:
689
- | {"state_key": "disease",
690
- | "start_state": "dcm",
691
- | "goal_state": "nf",
692
- | "alt_states": ["hcm", "other1", "other2"]}
693
- model_directory : Path
694
- | Path to directory containing model
695
- input_data_file : Path
696
- | Path to directory containing .dataset inputs
697
- output_directory : Path
698
- | Path to directory where embedding data will be saved as csv
699
- output_prefix : str
700
- | Prefix for output file
701
- output_torch_embs : bool
702
- | Whether or not to also output the embeddings as a tensor.
703
- | Note, if true, will output embeddings as both dataframe and tensor.
704
-
705
- **Outputs**
706
-
707
- | Outputs state_embs_dict for use with in silico perturber.
708
- | Format is dictionary of embedding positions of each cell state to model shifts from/towards.
709
- | Keys specify each possible cell state to model.
710
- | Values are target embedding positions as torch.tensor.
711
- | For example:
712
- | {"nf": emb_nf,
713
- | "hcm": emb_hcm,
714
- | "dcm": emb_dcm,
715
- | "other1": emb_other1,
716
- | "other2": emb_other2}
717
- """
718
-
719
- pu.validate_cell_states_to_model(cell_states_to_model)
720
- valid_summary_stats = ["exact_mean", "exact_median"]
721
- if self.exact_summary_stat not in valid_summary_stats:
722
- logger.error(
723
- "For extracting state embs, summary_stat in EmbExtractor "
724
- f"must be set to option in {valid_summary_stats}"
725
- )
726
- raise
727
-
728
- if self.emb_label is not None:
729
- logger.error(
730
- "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
731
- )
732
- raise
733
-
734
- state_embs_dict = dict()
735
- state_key = cell_states_to_model["state_key"]
736
- for k, v in cell_states_to_model.items():
737
- if k == "state_key":
738
- continue
739
- elif (k == "start_state") or (k == "goal_state"):
740
- state_embs_dict[v] = self.extract_embs(
741
- model_directory,
742
- input_data_file,
743
- output_directory,
744
- output_prefix,
745
- output_torch_embs,
746
- cell_state={state_key: v},
747
- )
748
- else: # k == "alt_states"
749
- for alt_state in v:
750
- state_embs_dict[alt_state] = self.extract_embs(
751
- model_directory,
752
- input_data_file,
753
- output_directory,
754
- output_prefix,
755
- output_torch_embs,
756
- cell_state={state_key: alt_state},
757
- )
758
-
759
- output_path = (Path(output_directory) / output_prefix).with_suffix(".pkl")
760
- with open(output_path, "wb") as fp:
761
- pickle.dump(state_embs_dict, fp)
762
-
763
- return state_embs_dict
764
-
765
- def plot_embs(
766
- self,
767
- embs,
768
- plot_style,
769
- output_directory,
770
- output_prefix,
771
- max_ncells_to_plot=1000,
772
- kwargs_dict=None,
773
- ):
774
- """
775
- Plot embeddings, coloring by provided labels.
776
-
777
- **Parameters:**
778
-
779
- embs : pandas.core.frame.DataFrame
780
- | Pandas dataframe containing embeddings output from extract_embs
781
- plot_style : str
782
- | Style of plot: "heatmap" or "umap"
783
- output_directory : Path
784
- | Path to directory where plots will be saved as pdf
785
- output_prefix : str
786
- | Prefix for output file
787
- max_ncells_to_plot : None, int
788
- | Maximum number of cells to plot.
789
- | Default is 1000 cells randomly sampled from embeddings.
790
- | If None, will plot embeddings from all cells.
791
- kwargs_dict : dict
792
- | Dictionary of kwargs to pass to plotting function.
793
-
794
- **Examples:**
795
-
796
- .. code-block :: python
797
-
798
- >>> embex.plot_embs(embs=embs,
799
- ... plot_style="heatmap",
800
- ... output_directory="path/to/output_directory",
801
- ... output_prefix="output_prefix")
802
-
803
- """
804
-
805
- if plot_style not in ["heatmap", "umap"]:
806
- logger.error(
807
- "Invalid option for 'plot_style'. " "Valid options: {'heatmap','umap'}"
808
- )
809
- raise
810
-
811
- if (plot_style == "umap") and (self.labels_to_plot is None):
812
- logger.error("Plotting UMAP requires 'labels_to_plot'. ")
813
- raise
814
-
815
- if max_ncells_to_plot is not None:
816
- if max_ncells_to_plot > self.max_ncells:
817
- max_ncells_to_plot = self.max_ncells
818
- logger.warning(
819
- "max_ncells_to_plot must be <= max_ncells. "
820
- f"Changing max_ncells_to_plot to {self.max_ncells}."
821
- )
822
- elif max_ncells_to_plot < self.max_ncells:
823
- embs = embs.sample(max_ncells_to_plot, axis=0)
824
-
825
- if self.emb_label is None:
826
- label_len = 0
827
- else:
828
- label_len = len(self.emb_label)
829
-
830
- emb_dims = embs.shape[1] - label_len
831
-
832
- if self.emb_label is None:
833
- emb_labels = None
834
- else:
835
- emb_labels = embs.columns[emb_dims:]
836
-
837
- if plot_style == "umap":
838
- for label in self.labels_to_plot:
839
- if label not in emb_labels:
840
- logger.warning(
841
- f"Label {label} from labels_to_plot "
842
- f"not present in provided embeddings dataframe."
843
- )
844
- continue
845
- output_prefix_label = output_prefix + f"_umap_{label}"
846
- output_file = (
847
- Path(output_directory) / output_prefix_label
848
- ).with_suffix(".pdf")
849
- plot_umap(embs, emb_dims, label, output_file, kwargs_dict)
850
-
851
- if plot_style == "heatmap":
852
- for label in self.labels_to_plot:
853
- if label not in emb_labels:
854
- logger.warning(
855
- f"Label {label} from labels_to_plot "
856
- f"not present in provided embeddings dataframe."
857
- )
858
- continue
859
- output_prefix_label = output_prefix + f"_heatmap_{label}"
860
- output_file = (
861
- Path(output_directory) / output_prefix_label
862
- ).with_suffix(".pdf")
863
- plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
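
For orientation, the intended hand-off from this module to the perturber, assembled from the docstrings above (paths and state values are placeholders, and the package-level imports are assumed):

    from geneformer import EmbExtractor, InSilicoPerturber

    cell_states = {"state_key": "disease", "start_state": "dcm",
                   "goal_state": "nf", "alt_states": ["hcm"]}

    embex = EmbExtractor(model_type="Pretrained", emb_mode="cell",
                         summary_stat="exact_mean", max_ncells=1000)
    state_embs_dict = embex.get_state_embs(
        cell_states,
        "path/to/model",                # placeholder paths
        "path/to/input_data.dataset",
        "path/to/output_directory",
        "state_embs",
    )
    isp = InSilicoPerturber(perturb_type="delete", emb_mode="cell",
                            cell_states_to_model=cell_states,
                            state_embs_dict=state_embs_dict)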
geneformer/evaluation_utils.py DELETED
@@ -1,287 +0,0 @@
- import logging
- import math
- import pickle
- from pathlib import Path
-
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import seaborn as sns
- import torch
- from datasets.utils.logging import disable_progress_bar, enable_progress_bar
- from sklearn import preprocessing
- from sklearn.metrics import (
-     ConfusionMatrixDisplay,
-     accuracy_score,
-     auc,
-     confusion_matrix,
-     f1_score,
-     roc_curve,
- )
- from tqdm.auto import trange
-
- from . import TOKEN_DICTIONARY_FILE
- from .emb_extractor import make_colorbar
-
- logger = logging.getLogger(__name__)
-
-
- def preprocess_classifier_batch(cell_batch, max_len, label_name):
-     if max_len is None:
-         max_len = max([len(i) for i in cell_batch["input_ids"]])
-
-     # load token dictionary (Ensembl IDs:token)
-     with open(TOKEN_DICTIONARY_FILE, "rb") as f:
-         gene_token_dict = pickle.load(f)
-
-     def pad_label_example(example):
-         example[label_name] = np.pad(
-             example[label_name],
-             (0, max_len - len(example["input_ids"])),
-             mode="constant",
-             constant_values=-100,
-         )
-         example["input_ids"] = np.pad(
-             example["input_ids"],
-             (0, max_len - len(example["input_ids"])),
-             mode="constant",
-             constant_values=gene_token_dict.get("<pad>"),
-         )
-         example["attention_mask"] = (
-             example["input_ids"] != gene_token_dict.get("<pad>")
-         ).astype(int)
-         return example
-
-     padded_batch = cell_batch.map(pad_label_example)
-     return padded_batch
-
-
- # Function to find the largest number smaller
- # than or equal to N that is divisible by K
- def find_largest_div(N, K):
-     rem = N % K
-     if rem == 0:
-         return N
-     else:
-         return N - rem
-
-
- def vote(logit_list):
-     m = max(logit_list)
-     logit_list.index(m)
-     indices = [i for i, x in enumerate(logit_list) if x == m]
-     if len(indices) > 1:
-         return "tie"
-     else:
-         return indices[0]
-
-
- def py_softmax(vector):
-     e = np.exp(vector)
-     return e / e.sum()
-
-
- def classifier_predict(model, classifier_type, evalset, forward_batch_size):
-     if classifier_type == "gene":
-         label_name = "labels"
-     elif classifier_type == "cell":
-         label_name = "label"
-
-     predict_logits = []
-     predict_labels = []
-     model.eval()
-
-     # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
-     evalset_len = len(evalset)
-     max_divisible = find_largest_div(evalset_len, forward_batch_size)
-     if len(evalset) - max_divisible == 1:
-         evalset_len = max_divisible
-
-     max_evalset_len = max(evalset.select([i for i in range(evalset_len)])["length"])
-
-     disable_progress_bar()  # disable progress bar for preprocess_classifier_batch mapping
-     for i in trange(0, evalset_len, forward_batch_size):
-         max_range = min(i + forward_batch_size, evalset_len)
-         batch_evalset = evalset.select([i for i in range(i, max_range)])
-         padded_batch = preprocess_classifier_batch(
-             batch_evalset, max_evalset_len, label_name
-         )
-         padded_batch.set_format(type="torch")
-
-         input_data_batch = padded_batch["input_ids"]
-         attn_msk_batch = padded_batch["attention_mask"]
-         label_batch = padded_batch[label_name]
-         with torch.no_grad():
-             outputs = model(
-                 input_ids=input_data_batch.to("cuda"),
-                 attention_mask=attn_msk_batch.to("cuda"),
-                 labels=label_batch.to("cuda"),
-             )
-             predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
-             predict_labels += [torch.squeeze(label_batch.to("cpu"))]
-
-     enable_progress_bar()
-     logits_by_cell = torch.cat(predict_logits)
-     last_dim = len(logits_by_cell.shape) - 1
-     all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[last_dim])
-     labels_by_cell = torch.cat(predict_labels)
-     all_labels = torch.flatten(labels_by_cell)
-     logit_label_paired = [
-         item
-         for item in list(zip(all_logits.tolist(), all_labels.tolist()))
-         if item[1] != -100
-     ]
-     y_pred = [vote(item[0]) for item in logit_label_paired]
-     y_true = [item[1] for item in logit_label_paired]
-     logits_list = [item[0] for item in logit_label_paired]
-     return y_pred, y_true, logits_list
-
-
- def get_metrics(y_pred, y_true, logits_list, num_classes, labels):
-     conf_mat = confusion_matrix(y_true, y_pred, labels=list(labels))
-     macro_f1 = f1_score(y_true, y_pred, average="macro")
-     acc = accuracy_score(y_true, y_pred)
-     roc_metrics = None  # roc metrics not reported for multiclass
-     if num_classes == 2:
-         y_score = [py_softmax(item)[1] for item in logits_list]
-         fpr, tpr, _ = roc_curve(y_true, y_score)
-         mean_fpr = np.linspace(0, 1, 100)
-         interp_tpr = np.interp(mean_fpr, fpr, tpr)
-         interp_tpr[0] = 0.0
-         tpr_wt = len(tpr)
-         roc_auc = auc(fpr, tpr)
-         roc_metrics = {
-             "fpr": fpr,
-             "tpr": tpr,
-             "interp_tpr": interp_tpr,
-             "auc": roc_auc,
-             "tpr_wt": tpr_wt,
-         }
-     return conf_mat, macro_f1, acc, roc_metrics
-
-
- # get cross-validated mean and sd metrics
- def get_cross_valid_roc_metrics(all_tpr, all_roc_auc, all_tpr_wt):
-     wts = [count / sum(all_tpr_wt) for count in all_tpr_wt]
-     all_weighted_tpr = [a * b for a, b in zip(all_tpr, wts)]
-     mean_tpr = np.sum(all_weighted_tpr, axis=0)
-     mean_tpr[-1] = 1.0
-     all_weighted_roc_auc = [a * b for a, b in zip(all_roc_auc, wts)]
-     roc_auc = np.sum(all_weighted_roc_auc)
-     roc_auc_sd = math.sqrt(np.average((all_roc_auc - roc_auc) ** 2, weights=wts))
-     return mean_tpr, roc_auc, roc_auc_sd
-
-
- # plot ROC curve
- def plot_ROC(roc_metric_dict, model_style_dict, title, output_dir, output_prefix):
-     fig = plt.figure()
-     fig.set_size_inches(10, 8)
-     sns.set(font_scale=2)
-     sns.set_style("white")
-     lw = 3
-     for model_name in roc_metric_dict.keys():
-         mean_fpr = roc_metric_dict[model_name]["mean_fpr"]
-         mean_tpr = roc_metric_dict[model_name]["mean_tpr"]
-         roc_auc = roc_metric_dict[model_name]["roc_auc"]
-         roc_auc_sd = roc_metric_dict[model_name]["roc_auc_sd"]
-         color = model_style_dict[model_name]["color"]
-         linestyle = model_style_dict[model_name]["linestyle"]
-         if len(roc_metric_dict[model_name]["all_roc_auc"]) > 1:
-             label = rf"{model_name} (AUC {roc_auc:0.2f} $\pm$ {roc_auc_sd:0.2f})"
-         else:
-             label = f"{model_name} (AUC {roc_auc:0.2f})"
-         plt.plot(
-             mean_fpr, mean_tpr, color=color, linestyle=linestyle, lw=lw, label=label
-         )
-
-     plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
-     plt.xlim([0.0, 1.0])
-     plt.ylim([0.0, 1.05])
-     plt.xlabel("False Positive Rate")
-     plt.ylabel("True Positive Rate")
-     plt.title(title)
-     plt.legend(loc="lower right")
-
-     output_file = (Path(output_dir) / f"{output_prefix}_roc").with_suffix(".pdf")
-     plt.savefig(output_file, bbox_inches="tight")
-     plt.show()
-
-
- # plot confusion matrix
- def plot_confusion_matrix(
-     conf_mat_df, title, output_dir, output_prefix, custom_class_order
- ):
-     fig = plt.figure()
-     fig.set_size_inches(10, 10)
-     sns.set(font_scale=1)
-     sns.set_style("whitegrid", {"axes.grid": False})
-     if custom_class_order is not None:
-         conf_mat_df = conf_mat_df.reindex(
-             index=custom_class_order, columns=custom_class_order
-         )
-     display_labels = generate_display_labels(conf_mat_df)
-     conf_mat = preprocessing.normalize(conf_mat_df.to_numpy(), norm="l1")
-     display = ConfusionMatrixDisplay(
-         confusion_matrix=conf_mat, display_labels=display_labels
-     )
-     display.plot(cmap="Blues", values_format=".2g")
-     plt.title(title)
-     plt.show()
-
-     output_file = (Path(output_dir) / f"{output_prefix}_conf_mat").with_suffix(".pdf")
-     display.figure_.savefig(output_file, bbox_inches="tight")
-
-
- def generate_display_labels(conf_mat_df):
-     display_labels = []
-     i = 0
-     for label in conf_mat_df.index:
-         display_labels += [f"{label}\nn={conf_mat_df.iloc[i,:].sum():.0f}"]
-         i = i + 1
-     return display_labels
-
-
- def plot_predictions(predictions_df, title, output_dir, output_prefix, kwargs_dict):
-     sns.set(font_scale=2)
-     plt.figure(figsize=(10, 10), dpi=150)
-     label_colors, label_color_dict = make_colorbar(predictions_df, "true")
-     predictions_df = predictions_df.drop(columns=["true"])
-     predict_colors_list = [label_color_dict[label] for label in predictions_df.columns]
-     predict_label_list = [label for label in predictions_df.columns]
-     predict_colors = pd.DataFrame(
-         pd.Series(predict_colors_list, index=predict_label_list), columns=["predicted"]
-     )
-
-     default_kwargs_dict = {
-         "row_cluster": False,
-         "col_cluster": False,
-         "row_colors": label_colors,
-         "col_colors": predict_colors,
-         "linewidths": 0,
-         "xticklabels": False,
-         "yticklabels": False,
-         "center": 0,
-         "cmap": "vlag",
-     }
-
-     if kwargs_dict is not None:
-         default_kwargs_dict.update(kwargs_dict)
-     g = sns.clustermap(predictions_df, **default_kwargs_dict)
-
-     plt.setp(g.ax_row_colors.get_xmajorticklabels(), rotation=45, ha="right")
-
-     for label_color in list(label_color_dict.keys()):
-         g.ax_col_dendrogram.bar(
-             0, 0, color=label_color_dict[label_color], label=label_color, linewidth=0
-         )
-
-     g.ax_col_dendrogram.legend(
-         title=f"{title}",
-         loc="lower center",
-         ncol=4,
-         bbox_to_anchor=(0.5, 1),
-         facecolor="white",
-     )
-
-     output_file = (Path(output_dir) / f"{output_prefix}_pred").with_suffix(".pdf")
-     plt.savefig(output_file, bbox_inches="tight")
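
A self-contained sketch of the metrics flow above on synthetic binary logits (assumes the module removed in this commit is still importable as geneformer.evaluation_utils):

    from geneformer.evaluation_utils import get_metrics, vote

    logits_list = [[2.0, -1.0], [0.1, 0.3], [1.5, 0.5], [-0.2, 0.9]]
    y_true = [0, 1, 0, 1]
    y_pred = [vote(l) for l in logits_list]  # per-example argmax; returns "tie" on equal logits

    conf_mat, macro_f1, acc, roc_metrics = get_metrics(
        y_pred, y_true, logits_list, num_classes=2, labels=[0, 1]
    )
    print(acc, macro_f1, roc_metrics["auc"])  # ROC metrics are reported only for the binary case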
geneformer/in_silico_perturber.py DELETED
@@ -1,1579 +0,0 @@
- """
- Geneformer in silico perturber.
-
- **Usage:**
-
- .. code-block :: python
-
-     >>> from geneformer import InSilicoPerturber
-     >>> isp = InSilicoPerturber(perturb_type="delete",
-     ...                         perturb_rank_shift=None,
-     ...                         genes_to_perturb="all",
-     ...                         model_type="CellClassifier",
-     ...                         num_classes=0,
-     ...                         emb_mode="cell",
-     ...                         filter_data={"cell_type":["cardiomyocyte"]},
-     ...                         cell_states_to_model={"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]},
-     ...                         state_embs_dict={"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2},
-     ...                         max_ncells=None,
-     ...                         emb_layer=0,
-     ...                         forward_batch_size=100,
-     ...                         nproc=16)
-     >>> isp.perturb_data("path/to/model",
-     ...                  "path/to/input_data",
-     ...                  "path/to/output_directory",
-     ...                  "output_prefix")
-
- **Description:**
-
- | Performs in silico perturbation (e.g. deletion or overexpression) of defined set of genes or all genes in sample of cells.
- | Outputs impact of perturbation on cell or gene embeddings.
- | Output files are analyzed with ``in_silico_perturber_stats``.
-
- """
-
- import logging
-
- # imports
- import os
- import pickle
- from collections import defaultdict
-
- import torch
- from datasets import Dataset
- from multiprocess import set_start_method
- from tqdm.auto import trange
-
- from . import TOKEN_DICTIONARY_FILE
- from . import perturber_utils as pu
- from .emb_extractor import get_embs
-
- import datasets
- datasets.logging.disable_progress_bar()
-
-
- logger = logging.getLogger(__name__)
-
-
- class InSilicoPerturber:
-     valid_option_dict = {
-         "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
-         "perturb_rank_shift": {None, 1, 2, 3},
-         "genes_to_perturb": {"all", list},
-         "combos": {0, 1},
-         "anchor_gene": {None, str},
-         "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
-         "num_classes": {int},
-         "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
-         "cell_emb_style": {"mean_pool"},
-         "filter_data": {None, dict},
-         "cell_states_to_model": {None, dict},
-         "state_embs_dict": {None, dict},
-         "max_ncells": {None, int},
-         "cell_inds_to_perturb": {"all", dict},
-         "emb_layer": {-1, 0},
-         "token_dictionary_file": {None, str},
-         "forward_batch_size": {int},
-         "nproc": {int},
-     }
-
-     def __init__(
-         self,
-         perturb_type="delete",
-         perturb_rank_shift=None,
-         genes_to_perturb="all",
-         combos=0,
-         anchor_gene=None,
-         model_type="Pretrained",
-         num_classes=0,
-         emb_mode="cls",
-         cell_emb_style="mean_pool",
-         filter_data=None,
-         cell_states_to_model=None,
-         state_embs_dict=None,
-         max_ncells=None,
-         cell_inds_to_perturb="all",
-         emb_layer=-1,
-         forward_batch_size=100,
-         nproc=4,
-         token_dictionary_file=None,
-         clear_mem_ncells=1000,
-     ):
-         """
-         Initialize in silico perturber.
-
-         **Parameters:**
-
-         perturb_type : {"delete", "overexpress", "inhibit", "activate"}
-             | Type of perturbation.
-             | "delete": delete gene from rank value encoding
-             | "overexpress": move gene to front of rank value encoding
-             | *(TBA)* "inhibit": move gene to lower quartile of rank value encoding
-             | *(TBA)* "activate": move gene to higher quartile of rank value encoding
-         *(TBA)* perturb_rank_shift : None, {1,2,3}
-             | Number of quartiles by which to shift rank of gene.
-             | For example, if perturb_type="activate" and perturb_rank_shift=1:
-             | genes in 4th quartile will move to middle of 3rd quartile.
-             | genes in 3rd quartile will move to middle of 2nd quartile.
-             | genes in 2nd quartile will move to middle of 1st quartile.
-             | genes in 1st quartile will move to front of rank value encoding.
-             | For example, if perturb_type="inhibit" and perturb_rank_shift=2:
-             | genes in 1st quartile will move to middle of 3rd quartile.
-             | genes in 2nd quartile will move to middle of 4th quartile.
-             | genes in 3rd or 4th quartile will move to bottom of rank value encoding.
-         genes_to_perturb : "all", list
-             | Default is perturbing each gene detected in each cell in the dataset.
-             | Otherwise, may provide a list of ENSEMBL IDs of genes to perturb.
-             | If gene list is provided, then perturber will only test perturbing them all together
-             | (rather than testing each possible combination of the provided genes).
-         combos : {0,1}
-             | Whether to perturb genes individually (0) or in pairs (1).
-         anchor_gene : None, str
-             | ENSEMBL ID of gene to use as anchor in combination perturbations.
-             | For example, if combos=1 and anchor_gene="ENSG00000148400":
-             | anchor gene will be perturbed in combination with each other gene.
-         model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
-             | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
-         num_classes : int
-             | If model is a gene or cell classifier, specify number of classes it was trained to classify.
-             | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
-         emb_mode : {"cls", "cell", "cls_and_gene", "cell_and_gene"}
-             | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
-             | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
-         cell_emb_style : "mean_pool"
-             | Method for summarizing cell embeddings if not using CLS token.
-             | Currently only option is mean pooling of gene embeddings for given cell.
-         filter_data : None, dict
-             | Default is to use all input data for in silico perturbation study.
-             | Otherwise, dictionary specifying .dataset column name and list of values to filter by.
-         cell_states_to_model : None, dict
-             | Cell states to model if testing perturbations that achieve goal state change.
-             | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
-             | state_key: key specifying name of column in .dataset that defines the start/goal states
-             | start_state: value in the state_key column that specifies the start state
-             | goal_state: value in the state_key column that specifies the goal end state
155
- | alt_states: list of values in the state_key column that specify the alternate end states
156
- | For example: {"state_key": "disease",
157
- | "start_state": "dcm",
158
- | "goal_state": "nf",
159
- | "alt_states": ["hcm", "other1", "other2"]}
160
- state_embs_dict : None, dict
161
- | Embedding positions of each cell state to model shifts from/towards (e.g. mean or median).
162
- | Dictionary with keys specifying each possible cell state to model.
163
- | Values are target embedding positions as torch.tensor.
164
- | For example: {"nf": emb_nf,
165
- | "hcm": emb_hcm,
166
- | "dcm": emb_dcm,
167
- | "other1": emb_other1,
168
- | "other2": emb_other2}
169
- max_ncells : None, int
170
- | Maximum number of cells to test.
171
- | If None, will test all cells.
172
- cell_inds_to_perturb : "all", list
173
- | Default is perturbing each cell in the dataset.
174
- | Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
175
- | start_ind: the first index to perturb.
176
- | end_ind: the last index to perturb (exclusive).
177
- | Indices will be selected *after* the filter_data criteria and sorting.
178
- | Useful for splitting extremely large datasets across separate GPUs.
179
- emb_layer : {-1, 0}
180
- | Embedding layer to use for quantification.
181
- | 0: last layer (recommended for questions closely tied to model's training objective)
182
- | -1: 2nd to last layer (recommended for questions requiring more general representations)
183
- forward_batch_size : int
184
- | Batch size for forward pass.
185
- nproc : int
186
- | Number of CPU processes to use.
187
- token_dictionary_file : Path
188
- | Path to pickle file containing token dictionary (Ensembl ID:token).
189
- clear_mem_ncells : int
190
- | Clear memory every n cells.
191
- """
192
- try:
193
- set_start_method("spawn")
194
- except RuntimeError:
195
- pass
196
-
197
- self.perturb_type = perturb_type
198
- self.perturb_rank_shift = perturb_rank_shift
199
- self.genes_to_perturb = genes_to_perturb
200
- self.combos = combos
201
- self.anchor_gene = anchor_gene
202
- if self.genes_to_perturb == "all":
203
- self.perturb_group = False
204
- else:
205
- self.perturb_group = True
206
- if (self.anchor_gene is not None) or (self.combos != 0):
207
- self.anchor_gene = None
208
- self.combos = 0
209
- logger.warning(
210
- "anchor_gene set to None and combos set to 0. "
211
- "If providing list of genes to perturb, "
212
- "list of genes_to_perturb will be perturbed together, "
213
- "without anchor gene or combinations."
214
- )
215
- self.model_type = model_type
216
- self.num_classes = num_classes
217
- self.emb_mode = emb_mode
218
- self.cell_emb_style = cell_emb_style
219
- self.filter_data = filter_data
220
- self.cell_states_to_model = cell_states_to_model
221
- self.state_embs_dict = state_embs_dict
222
- self.max_ncells = max_ncells
223
- self.cell_inds_to_perturb = cell_inds_to_perturb
224
- self.emb_layer = emb_layer
225
- self.forward_batch_size = forward_batch_size
226
- self.nproc = nproc
227
- self.token_dictionary_file = token_dictionary_file
228
- self.clear_mem_ncells = clear_mem_ncells
229
-
230
- self.validate_options()
231
-
232
- # load token dictionary (Ensembl IDs:token)
233
- if self.token_dictionary_file is None:
234
- token_dictionary_file = TOKEN_DICTIONARY_FILE
235
- with open(token_dictionary_file, "rb") as f:
236
- self.gene_token_dict = pickle.load(f)
237
- self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
238
-
239
- self.pad_token_id = self.gene_token_dict.get("<pad>")
240
- self.cls_token_id = self.gene_token_dict.get("<cls>")
241
- self.eos_token_id = self.gene_token_dict.get("<eos>")
242
-
243
- # Identify if special token is present in the token dictionary
244
- if (self.cls_token_id is not None) and (self.eos_token_id is not None):
245
- self.special_token = True
246
- else:
247
- if "cls" in self.emb_mode:
248
- logger.error(
249
- f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary."
250
- )
251
- raise
252
- self.special_token = False
253
-
254
- if self.anchor_gene is None:
255
- self.anchor_token = None
256
- else:
257
- try:
258
- self.anchor_token = [self.gene_token_dict[self.anchor_gene]]
259
- except KeyError:
260
- logger.error(f"Anchor gene {self.anchor_gene} not in token dictionary.")
261
- raise
262
-
263
- if self.genes_to_perturb == "all":
264
- self.tokens_to_perturb = "all"
265
- else:
266
- missing_genes = [
267
- gene
268
- for gene in self.genes_to_perturb
269
- if gene not in self.gene_token_dict.keys()
270
- ]
271
- if len(missing_genes) == len(self.genes_to_perturb):
272
- logger.error(
273
- "None of the provided genes to perturb are in token dictionary."
274
- )
275
- raise
276
- elif len(missing_genes) > 0:
277
- logger.warning(
278
- f"Genes to perturb {missing_genes} are not in token dictionary."
279
- )
280
- self.tokens_to_perturb = [
281
- self.gene_token_dict.get(gene) for gene in self.genes_to_perturb
282
- ]
283
-
284
- def validate_options(self):
285
- # first disallow options under development
286
- if self.perturb_type in ["inhibit", "activate"]:
287
- logger.error(
288
- "In silico inhibition and activation currently under development. "
289
- "Current valid options for 'perturb_type': 'delete' or 'overexpress'"
290
- )
291
- raise
292
- if (self.combos > 0) and (self.anchor_gene is None):
293
- logger.error(
294
- "Combination perturbation without anchor gene is currently under development. "
295
- "Currently, must provide anchor gene for combination perturbation."
296
- )
297
- raise
298
-
299
- # confirm arguments are within valid options and compatible with each other
300
- for attr_name, valid_options in self.valid_option_dict.items():
301
- attr_value = self.__dict__[attr_name]
302
- if type(attr_value) not in {list, dict}:
303
- if attr_value in valid_options:
304
- continue
305
- if attr_name in ["anchor_gene"]:
306
- if isinstance(attr_value, str):
307
- continue
308
- valid_type = False
309
- for option in valid_options:
310
- if (option in [bool, int, list, dict, str]) and isinstance(
311
- attr_value, option
312
- ):
313
- valid_type = True
314
- break
315
- if valid_type:
316
- continue
317
- logger.error(
318
- f"Invalid option for {attr_name}. "
319
- f"Valid options for {attr_name}: {valid_options}"
320
- )
321
- raise
322
-
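The loop above assumes self.valid_option_dict (defined with the class attributes outside this hunk) maps each argument name to a set that may mix literal values with allowed types; a hypothetical sketch of its shape:

    # Assumed shape of valid_option_dict; literal options and types can coexist
    # in one set, which is why the loop falls back to isinstance checks.
    valid_option_dict = {
        "perturb_type": {"delete", "overexpress", "inhibit", "activate"},
        "perturb_rank_shift": {None, 1, 2, 3},
        "genes_to_perturb": {"all", list},
        "anchor_gene": {None, str},
        "emb_mode": {"cell", "cell_and_gene", "cls", "cls_and_gene"},
        "forward_batch_size": {int},
    }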
323
- if self.perturb_type in ["delete", "overexpress"]:
324
- if self.perturb_rank_shift is not None:
325
- if self.perturb_type == "delete":
326
- logger.warning(
327
- "perturb_rank_shift set to None. "
328
- "If perturb type is delete then gene is deleted entirely "
329
- "rather than shifted by quartile"
330
- )
331
- elif self.perturb_type == "overexpress":
332
- logger.warning(
333
- "perturb_rank_shift set to None. "
334
- "If perturb type is overexpress then gene is moved to front "
335
- "of rank value encoding rather than shifted by quartile"
336
- )
337
- self.perturb_rank_shift = None
338
-
339
- if (self.anchor_gene is not None) and (self.emb_mode == "cell_and_gene"):
340
- self.emb_mode = "cell"
341
- logger.warning(
342
- "emb_mode set to 'cell'. "
343
- "Currently, analysis with anchor gene "
344
- "only outputs effect on cell embeddings."
345
- )
346
-
347
- if self.cell_states_to_model is not None:
348
- pu.validate_cell_states_to_model(self.cell_states_to_model)
349
-
350
- if self.anchor_gene is not None:
351
- self.anchor_gene = None
352
- logger.warning(
353
- "anchor_gene set to None. "
354
- "Currently, anchor gene not available "
355
- "when modeling multiple cell states."
356
- )
357
-
358
- if self.state_embs_dict is None:
359
- logger.error(
360
- "state_embs_dict must be provided for mode with cell_states_to_model. "
361
- "Format is dictionary with keys specifying each possible cell state to model. "
362
- "Values are target embedding positions as torch.tensor."
363
- )
364
- raise
365
-
366
- for state_emb in self.state_embs_dict.values():
367
- if not torch.is_tensor(state_emb):
368
- logger.error(
369
- "state_embs_dict must be dictionary with values being torch.tensor."
370
- )
371
- raise
372
-
373
- keys_absent = []
374
- for k, v in self.cell_states_to_model.items():
375
- if (k == "start_state") or (k == "goal_state"):
376
- if v not in self.state_embs_dict.keys():
377
- keys_absent.append(v)
378
- if k == "alt_states":
379
- for state in v:
380
- if state not in self.state_embs_dict.keys():
381
- keys_absent.append(state)
382
- if len(keys_absent) > 0:
383
- logger.error(
384
- "Each start_state, goal_state, and alt_states in cell_states_to_model "
385
- "must be a key in state_embs_dict with the value being "
386
- "the state's embedding position as torch.tensor. "
387
- f"Missing keys: {keys_absent}"
388
- )
389
- raise
390
-
391
- if self.perturb_type in ["inhibit", "activate"]:
392
- if self.perturb_rank_shift is None:
393
- logger.error(
394
- "If perturb_type is inhibit or activate then "
395
- "quartile to shift by must be specified."
396
- )
397
- raise
398
-
399
- if self.filter_data is not None:
400
- for key, value in self.filter_data.items():
401
- if not isinstance(value, list):
402
- self.filter_data[key] = [value]
403
- logger.warning(
404
- "Values in filter_data dict must be lists. "
405
- f"Changing {key} value to list ([{value}])."
406
- )
407
-
408
- if self.cell_inds_to_perturb != "all":
409
- if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
410
- logger.error(
411
- "If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
412
- )
413
- raise
414
- if (
415
- self.cell_inds_to_perturb["start"] < 0
416
- or self.cell_inds_to_perturb["end"] < 0
417
- ):
418
- logger.error("cell_inds_to_perturb must be positive.")
419
- raise
420
-
421
- def perturb_data(
422
- self, model_directory, input_data_file, output_directory, output_prefix
423
- ):
424
- """
425
- Perturb genes in input data and save as results in output_directory.
426
-
427
- **Parameters:**
428
-
429
- model_directory : Path
430
- | Path to directory containing model
431
- input_data_file : Path
432
- | Path to directory containing .dataset inputs
433
- output_directory : Path
434
- | Path to directory where perturbation data will be saved as batched pickle files
435
- output_prefix : str
436
- | Prefix for output files
437
- """
438
-
439
- ### format output path ###
440
- output_path_prefix = os.path.join(
441
- output_directory, f"in_silico_{self.perturb_type}_{output_prefix}"
442
- )
443
-
444
- ### load model and define parameters ###
445
- model = pu.load_model(
446
- self.model_type, self.num_classes, model_directory, mode="eval"
447
- )
448
- self.max_len = pu.get_model_input_size(model)
449
- layer_to_quant = pu.quant_layers(model) + self.emb_layer
450
-
451
- ### filter input data ###
452
- # general filtering of input data based on filter_data argument
453
- filtered_input_data = pu.load_and_filter(
454
- self.filter_data, self.nproc, input_data_file
455
- )
456
-
457
- # Require a "cls" emb_mode when the first token of the filtered input data is the <cls> token
458
- if self.special_token:
459
- if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and (
460
- "cls" not in self.emb_mode
461
- ):
462
- logger.error(
463
- "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
464
- )
465
- raise
466
- if "cls" in self.emb_mode:
467
- if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (
468
- filtered_input_data["input_ids"][0][-1] != self.eos_token_id
469
- ):
470
- logger.error(
471
- "Emb mode 'cls' and 'cls_and_gene' require that first token is <cls> and last token is <eos>."
472
- )
473
- raise
474
-
475
- filtered_input_data = self.apply_additional_filters(filtered_input_data)
476
-
477
- if self.perturb_group is True:
478
- if (self.special_token) and ("cls" in self.emb_mode):
479
- self.isp_perturb_set_special(
480
- model, filtered_input_data, layer_to_quant, output_path_prefix
481
- )
482
- else:
483
- self.isp_perturb_set(
484
- model, filtered_input_data, layer_to_quant, output_path_prefix
485
- )
486
- else:
487
- if (self.special_token) and ("cls" in self.emb_mode):
488
- self.isp_perturb_all_special(
489
- model, filtered_input_data, layer_to_quant, output_path_prefix
490
- )
491
- else:
492
- self.isp_perturb_all(
493
- model, filtered_input_data, layer_to_quant, output_path_prefix
494
- )
495
-
496
- def apply_additional_filters(self, filtered_input_data):
497
- # additional filtering of input data dependent on isp mode
498
- if self.cell_states_to_model is not None:
499
- # filter for cells with start_state and log result
500
- filtered_input_data = pu.filter_data_by_start_state(
501
- filtered_input_data, self.cell_states_to_model, self.nproc
502
- )
503
-
504
- if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
505
- # filter for cells with tokens_to_perturb and log result
506
- filtered_input_data = pu.filter_data_by_tokens_and_log(
507
- filtered_input_data,
508
- self.tokens_to_perturb,
509
- self.nproc,
510
- "genes_to_perturb",
511
- )
512
-
513
- if self.anchor_token is not None:
514
- # filter for cells with anchor gene and log result
515
- filtered_input_data = pu.filter_data_by_tokens_and_log(
516
- filtered_input_data, self.anchor_token, self.nproc, "anchor_gene"
517
- )
518
-
519
- # downsample and sort largest to smallest to encounter memory constraints earlier
520
- filtered_input_data = pu.downsample_and_sort(
521
- filtered_input_data, self.max_ncells
522
- )
523
-
524
- # slice dataset if cells_inds_to_perturb is not "all"
525
- if self.cell_inds_to_perturb != "all":
526
- filtered_input_data = pu.slice_by_inds_to_perturb(
527
- filtered_input_data, self.cell_inds_to_perturb
528
- )
529
-
530
- return filtered_input_data
531
-
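Illustrative argument values for the filtering and slicing performed above; the column name and the end-exclusive slice semantics are assumptions:

    # Keep only cells whose "cell_type" column matches, then perturb cells 0-999.
    filter_data = {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2"]}
    cell_inds_to_perturb = {"start": 0, "end": 1000}  # assumed end-exclusive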
532
- def isp_perturb_set(
533
- self,
534
- model,
535
- filtered_input_data: Dataset,
536
- layer_to_quant: int,
537
- output_path_prefix: str,
538
- ):
539
- def make_group_perturbation_batch(example):
540
- example_input_ids = example["input_ids"]
541
- example["tokens_to_perturb"] = self.tokens_to_perturb
542
- indices_to_perturb = [
543
- example_input_ids.index(token) if token in example_input_ids else None
544
- for token in self.tokens_to_perturb
545
- ]
546
- indices_to_perturb = [
547
- item for item in indices_to_perturb if item is not None
548
- ]
549
- if len(indices_to_perturb) > 0:
550
- example["perturb_index"] = indices_to_perturb
551
- else:
552
- # -100 indicates the tokens to perturb are not present in the rank value encoding
553
- example["perturb_index"] = [-100]
554
- if self.perturb_type == "delete":
555
- example = pu.delete_indices(example)
556
- elif self.perturb_type == "overexpress":
557
- example = pu.overexpress_tokens(
558
- example, self.max_len, self.special_token
559
- )
560
- example["n_overflow"] = pu.calc_n_overflow(
561
- self.max_len,
562
- example["length"],
563
- self.tokens_to_perturb,
564
- indices_to_perturb,
565
- )
566
- return example
567
-
568
- total_batch_length = len(filtered_input_data)
569
- if self.cell_states_to_model is None:
570
- cos_sims_dict = defaultdict(list)
571
- else:
572
- cos_sims_dict = {
573
- state: defaultdict(list)
574
- for state in pu.get_possible_states(self.cell_states_to_model)
575
- }
576
-
577
- perturbed_data = filtered_input_data.map(
578
- make_group_perturbation_batch, num_proc=self.nproc
579
- )
580
-
581
- if self.perturb_type == "overexpress":
582
- filtered_input_data = filtered_input_data.add_column(
583
- "n_overflow", perturbed_data["n_overflow"]
584
- )
585
- # remove overflow genes from original data so that embeddings are comparable
586
- # i.e. if original cell has genes 0:2047 and you want to overexpress new gene 2048,
587
- # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
588
- # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
589
- # rather than only adding 2048)
590
- filtered_input_data = filtered_input_data.map(
591
- pu.truncate_by_n_overflow, num_proc=self.nproc
592
- )
593
-
594
- if self.emb_mode == "cell_and_gene":
595
- stored_gene_embs_dict = defaultdict(list)
596
-
597
- # iterate through batches
598
- for i in trange(0, total_batch_length, self.forward_batch_size):
599
- max_range = min(i + self.forward_batch_size, total_batch_length)
600
- inds_select = list(range(i, max_range))
601
-
602
- minibatch = filtered_input_data.select(inds_select)
603
- perturbation_batch = perturbed_data.select(inds_select)
604
-
605
- if self.cell_emb_style == "mean_pool":
606
- full_original_emb = get_embs(
607
- model,
608
- minibatch,
609
- "gene",
610
- layer_to_quant,
611
- self.pad_token_id,
612
- self.forward_batch_size,
613
- token_gene_dict=self.token_gene_dict,
614
- summary_stat=None,
615
- silent=True,
616
- )
617
- indices_to_perturb = perturbation_batch["perturb_index"]
618
- # remove indices that were perturbed
619
- original_emb = pu.remove_perturbed_indices_set(
620
- full_original_emb,
621
- self.perturb_type,
622
- indices_to_perturb,
623
- self.tokens_to_perturb,
624
- minibatch["length"],
625
- )
626
- full_perturbation_emb = get_embs(
627
- model,
628
- perturbation_batch,
629
- "gene",
630
- layer_to_quant,
631
- self.pad_token_id,
632
- self.forward_batch_size,
633
- token_gene_dict=self.token_gene_dict,
634
- summary_stat=None,
635
- silent=True,
636
- )
637
-
638
- # remove overexpressed genes
639
- if self.perturb_type == "overexpress":
640
- perturbation_emb = full_perturbation_emb[
641
- :, len(self.tokens_to_perturb) :, :
642
- ]
643
-
644
- elif self.perturb_type == "delete":
645
- perturbation_emb = full_perturbation_emb[
646
- :, : max(perturbation_batch["length"]), :
647
- ]
648
-
649
- n_perturbation_genes = perturbation_emb.size()[1]
650
-
651
- # if no goal states, the cosine similarities are the mean of gene cosine similarities
652
- if (
653
- self.cell_states_to_model is None
654
- or self.emb_mode == "cell_and_gene"
655
- ):
656
- gene_cos_sims = pu.quant_cos_sims(
657
- perturbation_emb,
658
- original_emb,
659
- self.cell_states_to_model,
660
- self.state_embs_dict,
661
- emb_mode="gene",
662
- )
663
-
664
- # if there are goal states, the cosine similarities are the cell cosine similarities
665
- if self.cell_states_to_model is not None:
666
- original_cell_emb = pu.mean_nonpadding_embs(
667
- full_original_emb,
668
- torch.tensor(minibatch["length"], device="cuda"),
669
- dim=1,
670
- )
671
- perturbation_cell_emb = pu.mean_nonpadding_embs(
672
- full_perturbation_emb,
673
- torch.tensor(perturbation_batch["length"], device="cuda"),
674
- dim=1,
675
- )
676
- cell_cos_sims = pu.quant_cos_sims(
677
- perturbation_cell_emb,
678
- original_cell_emb,
679
- self.cell_states_to_model,
680
- self.state_embs_dict,
681
- emb_mode="cell",
682
- )
683
-
684
- # get cosine similarities in gene embeddings
685
- # if getting gene embeddings, need gene names
686
- if self.emb_mode == "cell_and_gene":
687
- gene_list = minibatch["input_ids"]
688
- # need to truncate gene_list
689
- gene_list = [
690
- [g for g in genes if g not in self.tokens_to_perturb][
691
- :n_perturbation_genes
692
- ]
693
- for genes in gene_list
694
- ]
695
-
696
- for cell_i, genes in enumerate(gene_list):
697
- for gene_j, affected_gene in enumerate(genes):
698
- if len(self.genes_to_perturb) > 1:
699
- tokens_to_perturb = tuple(self.tokens_to_perturb)
700
- else:
701
- tokens_to_perturb = self.tokens_to_perturb[0]
702
-
703
- # fill in the gene cosine similarities
704
- try:
705
- stored_gene_embs_dict[
706
- (tokens_to_perturb, affected_gene)
707
- ].append(gene_cos_sims[cell_i, gene_j].item())
708
- except KeyError:
709
- stored_gene_embs_dict[
710
- (tokens_to_perturb, affected_gene)
711
- ] = gene_cos_sims[cell_i, gene_j].item()
712
- else:
713
- gene_list = None
714
-
715
- if self.cell_states_to_model is None:
716
- # calculate the mean of the gene cosine similarities for cell shift
717
- # tensor of nonpadding lengths for each cell
718
- if self.perturb_type == "overexpress":
719
- # subtract number of genes that were overexpressed
720
- # since they are removed before getting cos sims
721
- n_overexpressed = len(self.tokens_to_perturb)
722
- nonpadding_lens = [
723
- x - n_overexpressed for x in perturbation_batch["length"]
724
- ]
725
- else:
726
- nonpadding_lens = perturbation_batch["length"]
727
- cos_sims_data = pu.mean_nonpadding_embs(
728
- gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
729
- )
730
- cos_sims_dict = self.update_perturbation_dictionary(
731
- cos_sims_dict,
732
- cos_sims_data,
733
- gene_list,
734
- )
735
- else:
736
- cos_sims_data = cell_cos_sims
737
- for state in cos_sims_dict.keys():
738
- cos_sims_dict[state] = self.update_perturbation_dictionary(
739
- cos_sims_dict[state],
740
- cos_sims_data[state],
741
- gene_list,
742
- )
743
- del minibatch
744
- del perturbation_batch
745
- del original_emb
746
- del perturbation_emb
747
- del cos_sims_data
748
-
749
- torch.cuda.empty_cache()
750
-
751
- pu.write_perturbation_dictionary(
752
- cos_sims_dict,
753
- f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
754
- )
755
-
756
- if self.emb_mode == "cell_and_gene":
757
- pu.write_perturbation_dictionary(
758
- stored_gene_embs_dict,
759
- f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
760
- )
761
-
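For intuition, a self-contained sketch of the per-gene cosine shift that pu.quant_cos_sims is assumed to compute in "gene" mode when no cell states are modeled: a token-wise cosine similarity between aligned perturbed and original embeddings.

    import torch
    import torch.nn.functional as F

    def gene_cosine_shift(perturbed: torch.Tensor, original: torch.Tensor) -> torch.Tensor:
        # perturbed, original: (batch, n_genes, hidden), aligned gene-wise
        # returns: (batch, n_genes) cosine similarity per remaining gene
        return F.cosine_similarity(perturbed, original, dim=-1)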
762
- def isp_perturb_set_special(
763
- self,
764
- model,
765
- filtered_input_data: Dataset,
766
- layer_to_quant: int,
767
- output_path_prefix: str,
768
- ):
769
- def make_group_perturbation_batch(example):
770
- example_input_ids = example["input_ids"]
771
- example["tokens_to_perturb"] = self.tokens_to_perturb
772
- indices_to_perturb = [
773
- example_input_ids.index(token) if token in example_input_ids else None
774
- for token in self.tokens_to_perturb
775
- ]
776
- indices_to_perturb = [
777
- item for item in indices_to_perturb if item is not None
778
- ]
779
- if len(indices_to_perturb) > 0:
780
- example["perturb_index"] = indices_to_perturb
781
- else:
782
- # -100 indicates the tokens to perturb are not present in the rank value encoding
783
- example["perturb_index"] = [-100]
784
- if self.perturb_type == "delete":
785
- example = pu.delete_indices(example)
786
- elif self.perturb_type == "overexpress":
787
- example = pu.overexpress_tokens(
788
- example, self.max_len, self.special_token
789
- )
790
- example["n_overflow"] = pu.calc_n_overflow(
791
- self.max_len,
792
- example["length"],
793
- self.tokens_to_perturb,
794
- indices_to_perturb,
795
- )
796
- return example
797
-
798
- total_batch_length = len(filtered_input_data)
799
-
800
-
801
- if self.cell_states_to_model is None:
802
- cos_sims_dict = defaultdict(list)
803
- else:
804
- cos_sims_dict = {
805
- state: defaultdict(list)
806
- for state in pu.get_possible_states(self.cell_states_to_model)
807
- }
808
-
809
- perturbed_data = filtered_input_data.map(
810
- make_group_perturbation_batch, num_proc=self.nproc
811
- )
812
-
813
- if self.perturb_type == "overexpress":
814
- filtered_input_data = filtered_input_data.add_column(
815
- "n_overflow", perturbed_data["n_overflow"]
816
- )
817
- filtered_input_data = filtered_input_data.map(
818
- pu.truncate_by_n_overflow_special, num_proc=self.nproc
819
- )
820
-
821
- if self.emb_mode == "cls_and_gene":
822
- stored_gene_embs_dict = defaultdict(list)
823
-
824
- # iterate through batches
825
- for i in trange(0, total_batch_length, self.forward_batch_size):
826
- max_range = min(i + self.forward_batch_size, total_batch_length)
827
- inds_select = list(range(i, max_range))
828
-
829
- minibatch = filtered_input_data.select(inds_select)
830
- perturbation_batch = perturbed_data.select(inds_select)
831
-
832
- ##### CLS Embedding Mode #####
833
- if self.emb_mode == "cls":
834
- indices_to_perturb = perturbation_batch["perturb_index"]
835
-
836
- original_cls_emb = get_embs(
837
- model,
838
- minibatch,
839
- "cls",
840
- layer_to_quant,
841
- self.pad_token_id,
842
- self.forward_batch_size,
843
- token_gene_dict=self.token_gene_dict,
844
- summary_stat=None,
845
- silent=True,
846
- )
847
-
848
- perturbation_cls_emb = get_embs(
849
- model,
850
- perturbation_batch,
851
- "cls",
852
- layer_to_quant,
853
- self.pad_token_id,
854
- self.forward_batch_size,
855
- token_gene_dict=self.token_gene_dict,
856
- summary_stat=None,
857
- silent=True,
858
- )
859
-
860
- # Calculate the cosine similarities
861
- cls_cos_sims = pu.quant_cos_sims(
862
- perturbation_cls_emb,
863
- original_cls_emb,
864
- self.cell_states_to_model,
865
- self.state_embs_dict,
866
- emb_mode="cell",
867
- )
868
-
869
- # Update perturbation dictionary
870
- if self.cell_states_to_model is None:
871
- cos_sims_dict = self.update_perturbation_dictionary(
872
- cos_sims_dict,
873
- cls_cos_sims,
874
- gene_list=None,
875
- )
876
- else:
877
- for state in cos_sims_dict.keys():
878
- cos_sims_dict[state] = self.update_perturbation_dictionary(
879
- cos_sims_dict[state],
880
- cls_cos_sims[state],
881
- gene_list=None,
882
- )
883
-
884
- ##### CLS and Gene Embedding Mode #####
885
- elif self.emb_mode == "cls_and_gene":
886
- full_original_emb = get_embs(
887
- model,
888
- minibatch,
889
- "gene",
890
- layer_to_quant,
891
- self.pad_token_id,
892
- self.forward_batch_size,
893
- self.token_gene_dict,
894
- summary_stat=None,
895
- silent=True,
896
- )
897
- indices_to_perturb = perturbation_batch["perturb_index"]
898
-
899
- # remove indices that were perturbed
900
- original_emb = pu.remove_perturbed_indices_set(
901
- full_original_emb,
902
- self.perturb_type,
903
- indices_to_perturb,
904
- self.tokens_to_perturb,
905
- minibatch["length"],
906
- )
907
-
908
- full_perturbation_emb = get_embs(
909
- model,
910
- perturbation_batch,
911
- "gene",
912
- layer_to_quant,
913
- self.pad_token_id,
914
- self.forward_batch_size,
915
- self.token_gene_dict,
916
- summary_stat=None,
917
- silent=True,
918
- )
919
-
920
- # remove special tokens and padding
921
- original_emb = original_emb[:, 1:-1, :]
922
- if self.perturb_type == "overexpress":
923
- perturbation_emb = full_perturbation_emb[
924
- :, 1 + len(self.tokens_to_perturb) : -1, :
925
- ]
926
- elif self.perturb_type == "delete":
927
- perturbation_emb = full_perturbation_emb[
928
- :, 1 : max(perturbation_batch["length"]) - 1, :
929
- ]
930
-
931
- n_perturbation_genes = perturbation_emb.size()[1]
932
-
933
- # truncate the original embedding as necessary
934
- if self.perturb_type == "overexpress":
935
- def calc_perturbation_length(ids):
936
- if ids == [-100]:
937
- return 0
938
- else:
939
- return len(ids)
940
-
941
- max_tensor_size = max(
- length - calc_perturbation_length(ids) - 2
- for length, ids in zip(minibatch["length"], indices_to_perturb)
- )
942
-
943
- max_n_overflow = max(minibatch["n_overflow"])
944
- if max_n_overflow > 0 and perturbation_emb.size()[1] < original_emb.size()[1]:
945
- original_emb = original_emb[:, 0 : perturbation_emb.size()[1], :]
946
- elif perturbation_emb.size()[1] < original_emb.size()[1]:
947
- original_emb = original_emb[:, 0:max_tensor_size, :]
948
-
949
- gene_cos_sims = pu.quant_cos_sims(
950
- perturbation_emb,
951
- original_emb,
952
- self.cell_states_to_model,
953
- self.state_embs_dict,
954
- emb_mode="gene",
955
- )
956
-
957
- # get cls emb
958
- original_cls_emb = full_original_emb[:, 0, :]
959
- perturbation_cls_emb = full_perturbation_emb[:, 0, :]
960
-
961
- cls_cos_sims = pu.quant_cos_sims(
962
- perturbation_cls_emb,
963
- original_cls_emb,
964
- self.cell_states_to_model,
965
- self.state_embs_dict,
966
- emb_mode="cell",
967
- )
968
-
969
- # get cosine similarities in gene embeddings
970
- # since getting gene embeddings, need gene names
971
-
972
- gene_list = minibatch["input_ids"]
973
- # need to truncate gene_list
974
- genes_to_exclude = self.tokens_to_perturb + [
975
- self.cls_token_id,
976
- self.eos_token_id,
977
- ]
978
- gene_list = [
979
- [g for g in genes if g not in genes_to_exclude][
980
- :n_perturbation_genes
981
- ]
982
- for genes in gene_list
983
- ]
984
-
985
- for cell_i, genes in enumerate(gene_list):
986
- for gene_j, affected_gene in enumerate(genes):
987
- if len(self.genes_to_perturb) > 1:
988
- tokens_to_perturb = tuple(self.tokens_to_perturb)
989
- else:
990
- tokens_to_perturb = self.tokens_to_perturb[0]
991
-
992
- # fill in the gene cosine similarities
993
- try:
994
- stored_gene_embs_dict[
995
- (tokens_to_perturb, affected_gene)
996
- ].append(gene_cos_sims[cell_i, gene_j].item())
997
- except KeyError:
998
- stored_gene_embs_dict[
999
- (tokens_to_perturb, affected_gene)
1000
- ] = gene_cos_sims[cell_i, gene_j].item()
1001
-
1002
- if self.cell_states_to_model is None:
1003
- cos_sims_dict = self.update_perturbation_dictionary(
1004
- cos_sims_dict,
1005
- cls_cos_sims,
1006
- gene_list=None,
1007
- )
1008
- else:
1009
- for state in cos_sims_dict.keys():
1010
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1011
- cos_sims_dict[state],
1012
- cls_cos_sims[state],
1013
- gene_list=None,
1014
- )
1015
- del full_original_emb
1016
- del original_emb
1017
- del full_perturbation_emb
1018
- del perturbation_emb
1019
- del gene_cos_sims
1020
-
1021
- del original_cls_emb
1022
- del perturbation_cls_emb
1023
- del cls_cos_sims
1024
- del minibatch
1025
- del perturbation_batch
1026
-
1027
- torch.cuda.empty_cache()
1028
-
1029
- pu.write_perturbation_dictionary(
1030
- cos_sims_dict,
1031
- f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}",
1032
- )
1033
-
1034
- if self.emb_mode == "cls_and_gene":
1035
- pu.write_perturbation_dictionary(
1036
- stored_gene_embs_dict,
1037
- f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}",
1038
- )
1039
-
1040
- def isp_perturb_all(
1041
- self,
1042
- model,
1043
- filtered_input_data: Dataset,
1044
- layer_to_quant: int,
1045
- output_path_prefix: str,
1046
- ):
1047
- pickle_batch = -1
1048
- if self.cell_states_to_model is None:
1049
- cos_sims_dict = defaultdict(list)
1050
- else:
1051
- cos_sims_dict = {
1052
- state: defaultdict(list)
1053
- for state in pu.get_possible_states(self.cell_states_to_model)
1054
- }
1055
-
1056
- if self.emb_mode == "cell_and_gene":
1057
- stored_gene_embs_dict = defaultdict(list)
1058
-
1059
- num_inds_perturbed = 1 + self.combos
1060
- for h in trange(len(filtered_input_data)):
1061
- example_cell = filtered_input_data.select([h])
1062
- full_original_emb = get_embs(
1063
- model,
1064
- example_cell,
1065
- "gene",
1066
- layer_to_quant,
1067
- self.pad_token_id,
1068
- self.forward_batch_size,
1069
- self.token_gene_dict,
1070
- summary_stat=None,
1071
- silent=True,
1072
- )
1073
-
1074
- if self.cell_states_to_model is not None:
1075
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
1076
- full_original_emb, "mean_pool"
1077
- )
1078
-
1079
- # gene_list is used to assign cos sims back to genes
1080
- gene_list = example_cell["input_ids"][0][:]
1081
- # need to remove the anchor gene
1082
- if self.anchor_token is not None:
1083
- for token in self.anchor_token:
1084
- gene_list.remove(token)
1085
- # index 0 is not overexpressed so remove
1086
- if self.perturb_type == "overexpress":
1087
- gene_list = gene_list[num_inds_perturbed:]
1088
- # remove perturbed index for gene list dict
1089
- perturbed_gene_dict = {
1090
- gene: gene_list[:i] + gene_list[i + 1 :]
1091
- for i, gene in enumerate(gene_list)
1092
- }
1093
-
1094
- perturbation_batch, indices_to_perturb = pu.make_perturbation_batch(
1095
- example_cell,
1096
- self.perturb_type,
1097
- self.tokens_to_perturb,
1098
- self.anchor_token,
1099
- self.combos,
1100
- self.nproc,
1101
- )
1102
-
1103
- ispall_total_batch_length = len(perturbation_batch)
1104
- for i in trange(
1105
- 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1106
- ):
1107
- ispall_max_range = min(
1108
- i + self.forward_batch_size, ispall_total_batch_length
1109
- )
1110
- perturbation_minibatch = perturbation_batch.select(
1111
- list(range(i, ispall_max_range))
1112
- )
1113
- indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1114
- gene_list_mini = gene_list[
1115
- i:ispall_max_range
1116
- ] # only perturbed genes from this minibatch
1117
-
1118
- full_perturbation_emb = get_embs(
1119
- model,
1120
- perturbation_minibatch,
1121
- "gene",
1122
- layer_to_quant,
1123
- self.pad_token_id,
1124
- self.forward_batch_size,
1125
- self.token_gene_dict,
1126
- summary_stat=None,
1127
- silent=True,
1128
- )
1129
-
1130
- del perturbation_minibatch
1131
-
1132
- # need to remove overexpressed gene to quantify cosine shifts
1133
- if self.perturb_type == "overexpress":
1134
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
1135
-
1136
- elif self.perturb_type == "delete":
1137
- perturbation_emb = full_perturbation_emb
1138
-
1139
- if (
1140
- self.cell_states_to_model is None
1141
- or self.emb_mode == "cell_and_gene"
1142
- ):
1143
- original_emb_minibatch = pu.make_comparison_batch(
1144
- full_original_emb, indices_to_perturb_mini, perturb_group=False
1145
- )
1146
- gene_cos_sims = pu.quant_cos_sims(
1147
- perturbation_emb,
1148
- original_emb_minibatch,
1149
- self.cell_states_to_model,
1150
- self.state_embs_dict,
1151
- emb_mode="gene",
1152
- )
1153
- del original_emb_minibatch
1154
-
1155
- if self.cell_states_to_model is not None:
1156
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
1157
- full_perturbation_emb, "mean_pool"
1158
- )
1159
-
1160
- cell_cos_sims = pu.quant_cos_sims(
1161
- perturbation_cell_emb,
1162
- original_cell_emb,
1163
- self.cell_states_to_model,
1164
- self.state_embs_dict,
1165
- emb_mode="cell",
1166
- )
1167
- del perturbation_cell_emb
1168
-
1169
- if self.emb_mode == "cell_and_gene":
1170
- for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1171
- for gene_j, affected_gene in enumerate(
1172
- perturbed_gene_dict[perturbed_gene]
1173
- ):
1174
- try:
1175
- stored_gene_embs_dict[
1176
- (perturbed_gene, affected_gene)
1177
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
1178
- except KeyError:
1179
- stored_gene_embs_dict[
1180
- (perturbed_gene, affected_gene)
1181
- ] = gene_cos_sims[perturbation_i, gene_j].item()
1182
-
1183
- del full_perturbation_emb
1184
-
1185
- if self.cell_states_to_model is None:
1186
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
1187
- cos_sims_dict = self.update_perturbation_dictionary(
1188
- cos_sims_dict,
1189
- cos_sims_data,
1190
- gene_list_mini,
1191
- )
1192
- else:
1193
- cos_sims_data = cell_cos_sims
1194
- for state in cos_sims_dict.keys():
1195
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1196
- cos_sims_dict[state],
1197
- cos_sims_data[state],
1198
- gene_list_mini,
1199
- )
1200
-
1201
- # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1202
- if i % max(1, self.clear_mem_ncells // 10) == 0:
1203
- pu.write_perturbation_dictionary(
1204
- cos_sims_dict,
1205
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1206
- )
1207
- if self.emb_mode == "cell_and_gene":
1208
- pu.write_perturbation_dictionary(
1209
- stored_gene_embs_dict,
1210
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1211
- )
1212
-
1213
- # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1214
- if i % self.clear_mem_ncells == 0:
1215
- pickle_batch += 1
1216
- if self.cell_states_to_model is None:
1217
- cos_sims_dict = defaultdict(list)
1218
- else:
1219
- cos_sims_dict = {
1220
- state: defaultdict(list)
1221
- for state in pu.get_possible_states(
1222
- self.cell_states_to_model
1223
- )
1224
- }
1225
-
1226
- if self.emb_mode == "cell_and_gene":
1227
- stored_gene_embs_dict = defaultdict(list)
1228
-
1229
- torch.cuda.empty_cache()
1230
-
1231
- pu.write_perturbation_dictionary(
1232
- cos_sims_dict,
1233
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1234
- )
1235
-
1236
- if self.emb_mode == "cell_and_gene":
1237
- pu.write_perturbation_dictionary(
1238
- stored_gene_embs_dict,
1239
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1240
- )
1241
-
1242
- pickle_batch = -1
1243
- if self.cell_states_to_model is None:
1244
- cos_sims_dict = defaultdict(list)
1245
- else:
1246
- cos_sims_dict = {
1247
- state: defaultdict(list)
1248
- for state in pu.get_possible_states(self.cell_states_to_model)
1249
- }
1250
-
1251
- if self.emb_mode == "cell_and_gene":
1252
- stored_gene_embs_dict = defaultdict(list)
1253
-
1254
- # clear memory between cells
1255
- del perturbation_batch
1256
- del full_original_emb
1257
- if self.cell_states_to_model is not None:
1258
- del original_cell_emb
1259
- torch.cuda.empty_cache()
1260
-
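To make the per-cell expansion above concrete: for perturb_type="delete", pu.make_perturbation_batch is assumed to emit one perturbed copy of the cell per gene; a toy illustration with hypothetical token ids:

    cell_input_ids = [10, 42, 7]          # rank value encoding of one cell
    perturbation_batch = [
        [42, 7],                          # gene at index 0 deleted
        [10, 7],                          # gene at index 1 deleted
        [10, 42],                         # gene at index 2 deleted
    ]
    indices_to_perturb = [[0], [1], [2]]  # matches the loop's bookkeeping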
1261
- def isp_perturb_all_special(
1262
- self,
1263
- model,
1264
- filtered_input_data: Dataset,
1265
- layer_to_quant: int,
1266
- output_path_prefix: str,
1267
- ):
1268
- pickle_batch = -1
1269
- if self.cell_states_to_model is None:
1270
- cos_sims_dict = defaultdict(list)
1271
- else:
1272
- cos_sims_dict = {
1273
- state: defaultdict(list)
1274
- for state in pu.get_possible_states(self.cell_states_to_model)
1275
- }
1276
-
1277
- if self.emb_mode == "cls_and_gene":
1278
- stored_gene_embs_dict = defaultdict(list)
1279
-
1280
- num_inds_perturbed = 1 + self.combos
1281
- for h in trange(len(filtered_input_data)):
1282
- example_cell = filtered_input_data.select([h])
1283
-
1284
- # get original example cell cls and/or gene embs for comparison
1285
- if self.emb_mode == "cls":
1286
- original_cls_emb = get_embs(
1287
- model,
1288
- example_cell,
1289
- "cls",
1290
- layer_to_quant,
1291
- self.pad_token_id,
1292
- self.forward_batch_size,
1293
- self.token_gene_dict,
1294
- summary_stat=None,
1295
- silent=True,
1296
- )
1297
- elif self.emb_mode == "cls_and_gene":
1298
- full_original_emb = get_embs(
1299
- model,
1300
- example_cell,
1301
- "gene",
1302
- layer_to_quant,
1303
- self.pad_token_id,
1304
- self.forward_batch_size,
1305
- self.token_gene_dict,
1306
- summary_stat=None,
1307
- silent=True,
1308
- )
1309
- original_cls_emb = full_original_emb[:, 0, :].clone().detach()
1310
-
1311
- # gene_list is used to assign cos sims back to genes
1312
- gene_list = example_cell["input_ids"][0][:]
1313
-
1314
- # need to remove special tokens
1315
- for token in [self.cls_token_id, self.eos_token_id]:
1316
- gene_list.remove(token)
1317
- # need to remove the anchor gene
1318
- if self.anchor_token is not None:
1319
- for token in self.anchor_token:
1320
- gene_list.remove(token)
1321
- # index 0 is not overexpressed so remove
1322
- if self.perturb_type == "overexpress":
1323
- gene_list = gene_list[num_inds_perturbed:]
1324
- # remove perturbed index for gene list dict
1325
- perturbed_gene_dict = {
1326
- gene: gene_list[:i] + gene_list[i + 1 :]
1327
- for i, gene in enumerate(gene_list)
1328
- }
1329
-
1330
- perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special(
1331
- example_cell,
1332
- self.perturb_type,
1333
- self.tokens_to_perturb,
1334
- self.anchor_token,
1335
- self.combos,
1336
- self.nproc,
1337
- )
1338
-
1339
- ispall_total_batch_length = len(perturbation_batch)
1340
- for i in trange(
1341
- 0, ispall_total_batch_length, self.forward_batch_size, leave=False
1342
- ):
1343
- ispall_max_range = min(
1344
- i + self.forward_batch_size, ispall_total_batch_length
1345
- )
1346
- perturbation_minibatch = perturbation_batch.select(
1347
- list(range(i, ispall_max_range))
1348
- )
1349
- indices_to_perturb_mini = indices_to_perturb[i:ispall_max_range]
1350
- gene_list_mini = gene_list[
1351
- i:ispall_max_range
1352
- ] # only perturbed genes from this minibatch
1353
-
1354
- ##### CLS Embedding Mode #####
1355
- if self.emb_mode == "cls":
1356
- # Extract cls embeddings from perturbed cells
1357
- perturbation_cls_emb = get_embs(
1358
- model,
1359
- perturbation_minibatch,
1360
- "cls",
1361
- layer_to_quant,
1362
- self.pad_token_id,
1363
- self.forward_batch_size,
1364
- self.token_gene_dict,
1365
- summary_stat=None,
1366
- silent=True,
1367
- )
1368
-
1369
- # Calculate cosine similarities
1370
- cls_cos_sims = pu.quant_cos_sims(
1371
- perturbation_cls_emb,
1372
- original_cls_emb,
1373
- self.cell_states_to_model,
1374
- self.state_embs_dict,
1375
- emb_mode="cell",
1376
- )
1377
-
1378
- if self.cell_states_to_model is None:
1379
- cos_sims_dict = self.update_perturbation_dictionary(
1380
- cos_sims_dict,
1381
- cls_cos_sims,
1382
- gene_list_mini,
1383
- )
1384
- else:
1385
- for state in cos_sims_dict.keys():
1386
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1387
- cos_sims_dict[state],
1388
- cls_cos_sims[state],
1389
- gene_list_mini,
1390
- )
1391
-
1392
- del perturbation_minibatch
1393
- del perturbation_cls_emb
1394
- del cls_cos_sims
1395
-
1396
- ##### CLS and Gene Embedding Mode #####
1397
- elif self.emb_mode == "cls_and_gene":
1398
- full_perturbation_emb = get_embs(
1399
- model,
1400
- perturbation_minibatch,
1401
- "gene",
1402
- layer_to_quant,
1403
- self.pad_token_id,
1404
- self.forward_batch_size,
1405
- self.token_gene_dict,
1406
- summary_stat=None,
1407
- silent=True,
1408
- )
1409
-
1410
- # need to remove overexpressed gene and cls/eos to quantify cosine shifts
1411
- if self.perturb_type == "overexpress":
1412
- perturbation_emb = (
1413
- full_perturbation_emb[:, 1 + num_inds_perturbed : -1, :]
1414
- .clone()
1415
- .detach()
1416
- )
1417
- elif self.perturb_type == "delete":
1418
- perturbation_emb = (
1419
- full_perturbation_emb[:, 1:-1, :].clone().detach()
1420
- )
1421
-
1422
- original_emb_minibatch = pu.make_comparison_batch(
1423
- full_original_emb, indices_to_perturb_mini, perturb_group=False
1424
- )
1425
-
1426
- original_emb_minibatch = (
1427
- original_emb_minibatch[:, 1:-1, :].clone().detach()
1428
- )
1429
- gene_cos_sims = pu.quant_cos_sims(
1430
- perturbation_emb,
1431
- original_emb_minibatch,
1432
- self.cell_states_to_model,
1433
- self.state_embs_dict,
1434
- emb_mode="gene",
1435
- )
1436
-
1437
- for perturbation_i, perturbed_gene in enumerate(gene_list_mini):
1438
- for gene_j, affected_gene in enumerate(
1439
- perturbed_gene_dict[perturbed_gene]
1440
- ):
1441
- try:
1442
- stored_gene_embs_dict[
1443
- (perturbed_gene, affected_gene)
1444
- ].append(gene_cos_sims[perturbation_i, gene_j].item())
1445
- except KeyError:
1446
- stored_gene_embs_dict[
1447
- (perturbed_gene, affected_gene)
1448
- ] = gene_cos_sims[perturbation_i, gene_j].item()
1449
-
1450
- # get cls emb
1451
- perturbation_cls_emb = (
1452
- full_perturbation_emb[:, 0, :].clone().detach()
1453
- )
1454
-
1455
- cls_cos_sims = pu.quant_cos_sims(
1456
- perturbation_cls_emb,
1457
- original_cls_emb,
1458
- self.cell_states_to_model,
1459
- self.state_embs_dict,
1460
- emb_mode="cell",
1461
- )
1462
-
1463
- if self.cell_states_to_model is None:
1464
- cos_sims_dict = self.update_perturbation_dictionary(
1465
- cos_sims_dict,
1466
- cls_cos_sims,
1467
- gene_list_mini,
1468
- )
1469
- else:
1470
- for state in cos_sims_dict.keys():
1471
- cos_sims_dict[state] = self.update_perturbation_dictionary(
1472
- cos_sims_dict[state],
1473
- cls_cos_sims[state],
1474
- gene_list_mini,
1475
- )
1476
-
1477
- del perturbation_minibatch
1478
- del original_emb_minibatch
1479
- del full_perturbation_emb
1480
- del perturbation_emb
1481
- del perturbation_cls_emb
1482
- del cls_cos_sims
1483
- del gene_cos_sims
1484
-
1485
- # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells
1486
- if i % max(1, self.clear_mem_ncells // 10) == 0:
1487
- pu.write_perturbation_dictionary(
1488
- cos_sims_dict,
1489
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1490
- )
1491
- if self.emb_mode == "cls_and_gene":
1492
- pu.write_perturbation_dictionary(
1493
- stored_gene_embs_dict,
1494
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1495
- )
1496
-
1497
- # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell
1498
- if i % self.clear_mem_ncells == 0:
1499
- pickle_batch += 1
1500
- if self.cell_states_to_model is None:
1501
- cos_sims_dict = defaultdict(list)
1502
- else:
1503
- cos_sims_dict = {
1504
- state: defaultdict(list)
1505
- for state in pu.get_possible_states(
1506
- self.cell_states_to_model
1507
- )
1508
- }
1509
-
1510
- if self.emb_mode == "cls_and_gene":
1511
- stored_gene_embs_dict = defaultdict(list)
1512
-
1513
- torch.cuda.empty_cache()
1514
-
1515
- pu.write_perturbation_dictionary(
1516
- cos_sims_dict,
1517
- f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}",
1518
- )
1519
-
1520
- if self.emb_mode == "cls_and_gene":
1521
- pu.write_perturbation_dictionary(
1522
- stored_gene_embs_dict,
1523
- f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}",
1524
- )
1525
-
1526
- pickle_batch = -1
1527
- if self.cell_states_to_model is None:
1528
- cos_sims_dict = defaultdict(list)
1529
- else:
1530
- cos_sims_dict = {
1531
- state: defaultdict(list)
1532
- for state in pu.get_possible_states(self.cell_states_to_model)
1533
- }
1534
-
1535
- if self.emb_mode == "cls_and_gene":
1536
- stored_gene_embs_dict = defaultdict(list)
1537
-
1538
- # clear memory between cells
1539
- del perturbation_batch
1540
- del original_cls_emb
1541
- if self.emb_mode == "cls_and_gene":
1542
- del full_original_emb
1543
- torch.cuda.empty_cache()
1544
-
1545
- def update_perturbation_dictionary(
1546
- self,
1547
- cos_sims_dict: defaultdict,
1548
- cos_sims_data: torch.Tensor,
1549
- gene_list=None,
1550
- ):
1551
- if gene_list is not None and cos_sims_data.shape[0] != len(gene_list):
1552
- logger.error(
1553
- f"len(cos_sims_data.shape[0]) != len(gene_list). \n \
1554
- {cos_sims_data.shape[0]=}.\n \
1555
- {len(gene_list)=}."
1556
- )
1557
- raise
1558
-
1559
- if self.perturb_group is True:
1560
- if len(self.tokens_to_perturb) > 1:
1561
- perturbed_genes = tuple(self.tokens_to_perturb)
1562
- else:
1563
- perturbed_genes = self.tokens_to_perturb[0]
1564
-
1565
- # if cell embeddings, can just append
1566
- # shape will be (batch size, 1)
1567
- cos_sims_data = torch.squeeze(cos_sims_data).tolist()
1568
-
1569
- # handle case of single cell left
1570
- if not isinstance(cos_sims_data, list):
1571
- cos_sims_data = [cos_sims_data]
1572
-
1573
- cos_sims_dict[(perturbed_genes, "cell_emb")] += cos_sims_data
1574
-
1575
- else:
1576
- for i, cos in enumerate(cos_sims_data.tolist()):
1577
- cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1578
-
1579
- return cos_sims_dict
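The dictionaries written above can be reloaded directly; a reading sketch, where the exact file name is a placeholder and the "_raw.pickle" suffix is an assumption consistent with the stats module below:

    import pickle

    with open("in_silico_delete_run1_cell_embs_dict_[token]_raw.pickle", "rb") as f:
        cos_sims_dict = pickle.load(f)

    # keys are (perturbed_gene_token_or_tuple, "cell_emb"); values are lists of
    # cosine shifts, one per simulated cell
    for (perturbed, _), shifts in cos_sims_dict.items():
        print(perturbed, len(shifts))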
 
geneformer/in_silico_perturber_stats.py DELETED
@@ -1,1104 +0,0 @@
1
- """
2
- Geneformer in silico perturber stats generator.
3
-
4
- **Usage:**
5
-
6
- .. code-block :: python
7
-
8
- >>> from geneformer import InSilicoPerturberStats
9
- >>> ispstats = InSilicoPerturberStats(mode="goal_state_shift",
10
- ... cell_states_to_model={"state_key": "disease",
11
- ... "start_state": "dcm",
12
- ... "goal_state": "nf",
13
- ... "alt_states": ["hcm", "other1", "other2"]})
14
- >>> ispstats.get_stats("path/to/input_data",
15
- ... None,
16
- ... "path/to/output_directory",
17
- ... "output_prefix")
18
-
19
- **Description:**
20
-
21
- | Aggregates data or calculates statistics for in silico perturbations, based on the type of statistics specified in InSilicoPerturberStats.
22
- | Input data is raw in silico perturbation results in the form of dictionaries output by ``in_silico_perturber``.
23
-
24
- """
25
-
26
-
27
- import logging
28
- import os
29
- import pickle
30
- import random
31
- from pathlib import Path
32
-
33
- import numpy as np
34
- import pandas as pd
35
- import statsmodels.stats.multitest as smt
36
- from scipy.stats import ranksums
37
- from sklearn.mixture import GaussianMixture
38
- from tqdm.auto import tqdm, trange
39
-
40
- from . import ENSEMBL_DICTIONARY_FILE, TOKEN_DICTIONARY_FILE
41
- from .perturber_utils import flatten_list, validate_cell_states_to_model
42
-
43
- logger = logging.getLogger(__name__)
44
-
45
-
46
- # invert dictionary keys/values
47
- def invert_dict(dictionary):
48
- return {v: k for k, v in dictionary.items()}
49
-
50
-
51
- def read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token):
52
- if cell_or_gene_emb == "cell":
53
- cell_emb_dict = {
54
- k: v for k, v in cos_sims_dict.items() if v and "cell_emb" in k
55
- }
56
- return [cell_emb_dict]
57
- elif cell_or_gene_emb == "gene":
58
- if anchor_token is None:
59
- gene_emb_dict = {k: v for k, v in cos_sims_dict.items() if v}
60
- else:
61
- gene_emb_dict = {
62
- k: v for k, v in cos_sims_dict.items() if v and anchor_token == k[0]
63
- }
64
- return [gene_emb_dict]
65
-
66
-
67
- # read raw dictionary files
68
- def read_dictionaries(
69
- input_data_directory,
70
- cell_or_gene_emb,
71
- anchor_token,
72
- cell_states_to_model,
73
- pickle_suffix,
74
- ):
75
- file_found = False
76
- file_path_list = []
77
- if cell_states_to_model is None:
78
- dict_list = []
79
- else:
80
- validate_cell_states_to_model(cell_states_to_model)
81
- cell_states_to_model_valid = {
82
- state: value
83
- for state, value in cell_states_to_model.items()
84
- if state != "state_key"
85
- and cell_states_to_model[state] is not None
86
- and cell_states_to_model[state] != []
87
- }
88
- cell_states_list = []
89
- # flatten all state values into list
90
- for state in cell_states_to_model_valid:
91
- value = cell_states_to_model_valid[state]
92
- if isinstance(value, list):
93
- cell_states_list += value
94
- else:
95
- cell_states_list.append(value)
96
- state_dict = {state_value: dict() for state_value in cell_states_list}
97
- for file in os.listdir(input_data_directory):
98
- # process only files with given suffix (e.g. "_raw.pickle")
99
- if file.endswith(pickle_suffix):
100
- file_found = True
101
- file_path_list += [f"{input_data_directory}/{file}"]
102
- for file_path in tqdm(file_path_list):
103
- with open(file_path, "rb") as fp:
104
- cos_sims_dict = pickle.load(fp)
105
- if cell_states_to_model is None:
106
- dict_list += read_dict(cos_sims_dict, cell_or_gene_emb, anchor_token)
107
- else:
108
- for state_value in cell_states_list:
109
- new_dict = read_dict(
110
- cos_sims_dict[state_value], cell_or_gene_emb, anchor_token
111
- )[0]
112
- for key in new_dict:
113
- try:
114
- state_dict[state_value][key] += new_dict[key]
115
- except KeyError:
116
- state_dict[state_value][key] = new_dict[key]
117
-
118
- if not file_found:
119
- logger.error(
120
- "No raw data for processing found within provided directory. "
121
- "Please ensure data files end with '{pickle_suffix}'."
122
- )
123
- raise
124
- if cell_states_to_model is None:
125
- return dict_list
126
- else:
127
- return state_dict
128
-
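The return value differs by mode; an illustrative sketch of the two shapes, inferred from read_dict above, with hypothetical token ids and values:

    token_a = 16437  # hypothetical gene token id

    # Without cell_states_to_model: a list of per-file dictionaries.
    dict_list = [{(token_a, "cell_emb"): [0.91, 0.88]}]

    # With cell_states_to_model: one merged dictionary per state value.
    state_dict = {
        "nf": {(token_a, "cell_emb"): [0.95]},
        "hcm": {(token_a, "cell_emb"): [0.80]},
    }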
129
-
130
- # get complete gene list
131
- def get_gene_list(dict_list, mode):
132
- if mode == "cell":
133
- position = 0
134
- elif mode == "gene":
135
- position = 1
136
- gene_set = set()
137
- if isinstance(dict_list, list):
138
- for dict_i in dict_list:
139
- gene_set.update([k[position] for k, v in dict_i.items() if v])
140
- elif isinstance(dict_list, dict):
141
- for state, dict_i in dict_list.items():
142
- gene_set.update([k[position] for k, v in dict_i.items() if v])
143
- else:
144
- logger.error(
145
- "dict_list should be a list, or if modeling shift to goal states, a dict. "
146
- f"{type(dict_list)} is not the correct format."
147
- )
148
- raise
149
- gene_list = list(gene_set)
150
- if mode == "gene":
151
- gene_list.remove("cell_emb")
152
- gene_list.sort()
153
- return gene_list
154
-
155
-
156
- def token_tuple_to_ensembl_ids(token_tuple, gene_token_id_dict):
157
- try:
158
- return tuple([gene_token_id_dict.get(i, np.nan) for i in token_tuple])
159
- except TypeError:
160
- return gene_token_id_dict.get(token_tuple, np.nan)
161
-
162
-
163
- def n_detections(token, dict_list, mode, anchor_token):
164
- cos_sim_megalist = []
165
- for dict_i in dict_list:
166
- if mode == "cell":
167
- cos_sim_megalist += dict_i.get((token, "cell_emb"), [])
168
- elif mode == "gene":
169
- cos_sim_megalist += dict_i.get((anchor_token, token), [])
170
- return len(cos_sim_megalist)
171
-
172
-
173
- def get_fdr(pvalues):
174
- return list(smt.multipletests(pvalues, alpha=0.05, method="fdr_bh")[1])
175
-
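get_fdr returns Benjamini-Hochberg adjusted p-values in the input order; a worked example:

    pvals = [0.001, 0.010, 0.400]
    fdrs = get_fdr(pvals)  # -> [0.003, 0.015, 0.400] under BH correction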
176
-
177
- def get_impact_component(test_value, gaussian_mixture_model):
178
- impact_border = gaussian_mixture_model.means_[0][0]
179
- nonimpact_border = gaussian_mixture_model.means_[1][0]
180
- if test_value > nonimpact_border:
181
- impact_component = 0
182
- elif test_value < impact_border:
183
- impact_component = 1
184
- else:
185
- impact_component_raw = gaussian_mixture_model.predict([[test_value]])[0]
186
- if impact_component_raw == 1:
187
- impact_component = 0
188
- elif impact_component_raw == 0:
189
- impact_component = 1
190
- return impact_component
191
-
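get_impact_component expects a two-component GaussianMixture fit on 1-D shift values, treating means_[0] as the "impact" border; a fitting sketch under that assumption, with toy data:

    import numpy as np
    from sklearn.mixture import GaussianMixture

    shifts = np.array([[0.70], [0.72], [0.69], [0.97], [0.98], [0.96]])  # toy values
    gmm = GaussianMixture(n_components=2, random_state=0).fit(shifts)
    label = get_impact_component(0.71, gmm)  # 1 if in the low-shift (impact) component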
192
-
193
- # aggregate data for single perturbation in multiple cells
194
- def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
195
- names = ["Cosine_sim", "Gene"]
196
- cos_sims_full_dfs = []
197
- if isinstance(genes_perturbed, list):
198
- if len(genes_perturbed) > 1:
199
- gene_ids_df = cos_sims_df.loc[
200
- np.isin(
201
- [set(idx) for idx in cos_sims_df["Ensembl_ID"]],
202
- set(genes_perturbed),
203
- ),
204
- :,
205
- ]
206
- else:
207
- gene_ids_df = cos_sims_df.loc[
208
- np.isin(cos_sims_df["Ensembl_ID"], genes_perturbed), :
209
- ]
210
- else:
211
- logger.error(
212
- "aggregate_data is for perturbation of single gene or single group of genes. genes_to_perturb should be formatted as list."
213
- )
214
- raise
215
-
216
- if gene_ids_df.empty:
217
- logger.error("genes_to_perturb not found in data.")
218
- raise
219
-
220
- tokens = gene_ids_df["Gene"]
221
- symbols = gene_ids_df["Gene_name"]
222
-
223
- for token, symbol in zip(tokens, symbols):
224
- cos_shift_data = []
225
- for dict_i in dict_list:
226
- cos_shift_data += dict_i.get((token, "cell_emb"), [])
227
-
228
- df = pd.DataFrame(columns=names)
229
- df["Cosine_sim"] = cos_shift_data
230
- df["Gene"] = symbol
231
- cos_sims_full_dfs.append(df)
232
-
233
- return pd.concat(cos_sims_full_dfs)
234
-
235
-
236
- def find(variable, x):
237
- try:
238
- if x in variable: # Test if variable is iterable and contains x
239
- return True
240
- elif x == variable:
241
- return True
242
- except (ValueError, TypeError):
243
- return x == variable # Test if variable is x if non-iterable
244
-
245
-
246
- def isp_aggregate_gene_shifts(
247
- cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype
248
- ):
249
- cos_shift_data = dict()
250
- for i in trange(cos_sims_df.shape[0]):
251
- token = cos_sims_df["Gene"][i]
252
- for dict_i in dict_list:
253
- if token_dtype == "nontuple":
254
- affected_pairs = [k for k, v in dict_i.items() if k[0] == token]
255
- else:
256
- affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)]
257
- for key in affected_pairs:
258
- if key in cos_shift_data.keys():
259
- cos_shift_data[key] += dict_i.get(key, [])
260
- else:
261
- cos_shift_data[key] = dict_i.get(key, [])
262
-
263
- cos_data_mean = {
264
- k: [np.mean(v), np.std(v), len(v)] for k, v in cos_shift_data.items()
265
- }
266
- cos_sims_full_df = pd.DataFrame()
267
- cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()]
268
- cos_sims_full_df["Gene_name"] = [
269
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item()
270
- for k, v in cos_data_mean.items()
271
- ]
272
- cos_sims_full_df["Ensembl_ID"] = [
273
- cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item()
274
- for k, v in cos_data_mean.items()
275
- ]
276
-
277
- cos_sims_full_df["Affected"] = [k[1] for k, v in cos_data_mean.items()]
278
- cos_sims_full_df["Affected_gene_name"] = [
279
- gene_id_name_dict.get(gene_token_id_dict.get(token, np.nan), np.nan)
280
- for token in cos_sims_full_df["Affected"]
281
- ]
282
- cos_sims_full_df["Affected_Ensembl_ID"] = [
283
- gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
284
- ]
285
- cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
286
- cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
287
- cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
288
-
289
- specific_val = "cell_emb"
290
- cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
291
- # reorder so cell embs are at the top, with rows then ordered by magnitude of cosine similarity
292
- cos_sims_full_df = cos_sims_full_df.sort_values(
293
- by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
294
- ).drop("temp", axis=1)
295
-
296
- return cos_sims_full_df
297
-
298
-
299
- # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
- def isp_stats_to_goal_state(
-     cos_sims_df, result_dict, cell_states_to_model, genes_perturbed
- ):
-     if (
-         ("alt_states" not in cell_states_to_model.keys())
-         or (len(cell_states_to_model["alt_states"]) == 0)
-         or (cell_states_to_model["alt_states"] == [None])
-     ):
-         alt_end_state_exists = False
-     elif (len(cell_states_to_model["alt_states"]) > 0) and (
-         cell_states_to_model["alt_states"] != [None]
-     ):
-         alt_end_state_exists = True
-
-     # for single perturbation in multiple cells, there are no random perturbations to compare to
-     if genes_perturbed != "all":
-         cos_sims_full_df = pd.DataFrame()
-
-         cos_shift_data_end = []
-         token = cos_sims_df["Gene"][0]
-         cos_shift_data_end += result_dict[cell_states_to_model["goal_state"]].get(
-             (token, "cell_emb"), []
-         )
-         cos_sims_full_df["Shift_to_goal_end"] = [np.mean(cos_shift_data_end)]
-         if alt_end_state_exists is True:
-             for alt_state in cell_states_to_model["alt_states"]:
-                 cos_shift_data_alt_state = []
-                 cos_shift_data_alt_state += result_dict.get(alt_state).get(
-                     (token, "cell_emb"), []
-                 )
-                 cos_sims_full_df[f"Shift_to_alt_end_{alt_state}"] = [
-                     np.mean(cos_shift_data_alt_state)
-                 ]
-
-         # sort by shift to desired state
-         cos_sims_full_df = cos_sims_full_df.sort_values(
-             by=["Shift_to_goal_end"], ascending=[False]
-         )
-         return cos_sims_full_df
-
-     elif genes_perturbed == "all":
-         goal_end_random_megalist = []
-         if alt_end_state_exists is True:
-             alt_end_state_random_dict = {
-                 alt_state: [] for alt_state in cell_states_to_model["alt_states"]
-             }
-         for i in trange(cos_sims_df.shape[0]):
-             token = cos_sims_df["Gene"][i]
-             goal_end_random_megalist += result_dict[
-                 cell_states_to_model["goal_state"]
-             ].get((token, "cell_emb"), [])
-             if alt_end_state_exists is True:
-                 for alt_state in cell_states_to_model["alt_states"]:
-                     alt_end_state_random_dict[alt_state] += result_dict[alt_state].get(
-                         (token, "cell_emb"), []
-                     )
-
-         # downsample to improve speed of ranksums
-         if len(goal_end_random_megalist) > 100_000:
-             random.seed(42)
-             goal_end_random_megalist = random.sample(
-                 goal_end_random_megalist, k=100_000
-             )
-         if alt_end_state_exists is True:
-             for alt_state in cell_states_to_model["alt_states"]:
-                 if len(alt_end_state_random_dict[alt_state]) > 100_000:
-                     random.seed(42)
-                     alt_end_state_random_dict[alt_state] = random.sample(
-                         alt_end_state_random_dict[alt_state], k=100_000
-                     )
-
-         names = [
-             "Gene",
-             "Gene_name",
-             "Ensembl_ID",
-             "Shift_to_goal_end",
-             "Goal_end_vs_random_pval",
-         ]
-         if alt_end_state_exists is True:
-             [
-                 names.append(f"Shift_to_alt_end_{alt_state}")
-                 for alt_state in cell_states_to_model["alt_states"]
-             ]
-             names.append(names.pop(names.index("Goal_end_vs_random_pval")))
-             [
-                 names.append(f"Alt_end_vs_random_pval_{alt_state}")
-                 for alt_state in cell_states_to_model["alt_states"]
-             ]
-         cos_sims_full_df = pd.DataFrame(columns=names)
-
-         n_detections_dict = dict()
-         for i in trange(cos_sims_df.shape[0]):
-             token = cos_sims_df["Gene"][i]
-             name = cos_sims_df["Gene_name"][i]
-             ensembl_id = cos_sims_df["Ensembl_ID"][i]
-             goal_end_cos_sim_megalist = result_dict[
-                 cell_states_to_model["goal_state"]
-             ].get((token, "cell_emb"), [])
-             n_detections_dict[token] = len(goal_end_cos_sim_megalist)
-             mean_goal_end = np.mean(goal_end_cos_sim_megalist)
-             pval_goal_end = ranksums(
-                 goal_end_random_megalist, goal_end_cos_sim_megalist
-             ).pvalue
-
-             if alt_end_state_exists is True:
-                 alt_end_state_dict = {
-                     alt_state: [] for alt_state in cell_states_to_model["alt_states"]
-                 }
-                 for alt_state in cell_states_to_model["alt_states"]:
-                     alt_end_state_dict[alt_state] = result_dict[alt_state].get(
-                         (token, "cell_emb"), []
-                     )
-                     alt_end_state_dict[f"{alt_state}_mean"] = np.mean(
-                         alt_end_state_dict[alt_state]
-                     )
-                     alt_end_state_dict[f"{alt_state}_pval"] = ranksums(
-                         alt_end_state_random_dict[alt_state],
-                         alt_end_state_dict[alt_state],
-                     ).pvalue
-
-             results_dict = dict()
-             results_dict["Gene"] = token
-             results_dict["Gene_name"] = name
-             results_dict["Ensembl_ID"] = ensembl_id
-             results_dict["Shift_to_goal_end"] = mean_goal_end
-             results_dict["Goal_end_vs_random_pval"] = pval_goal_end
-             if alt_end_state_exists is True:
-                 for alt_state in cell_states_to_model["alt_states"]:
-                     results_dict[f"Shift_to_alt_end_{alt_state}"] = alt_end_state_dict[
-                         f"{alt_state}_mean"
-                     ]
-                     results_dict[
-                         f"Alt_end_vs_random_pval_{alt_state}"
-                     ] = alt_end_state_dict[f"{alt_state}_pval"]
-
-             cos_sims_df_i = pd.DataFrame(results_dict, index=[i])
-             cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
-
-         cos_sims_full_df["Goal_end_FDR"] = get_fdr(
-             list(cos_sims_full_df["Goal_end_vs_random_pval"])
-         )
-         if alt_end_state_exists is True:
-             for alt_state in cell_states_to_model["alt_states"]:
-                 cos_sims_full_df[f"Alt_end_FDR_{alt_state}"] = get_fdr(
-                     list(cos_sims_full_df[f"Alt_end_vs_random_pval_{alt_state}"])
-                 )
-
-         # quantify number of detections of each gene
-         cos_sims_full_df["N_Detections"] = [
-             n_detections_dict[token] for token in cos_sims_full_df["Gene"]
-         ]
-
-         # sort by shift to desired state
-         cos_sims_full_df["Sig"] = [
-             1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Goal_end_FDR"]
-         ]
-         cos_sims_full_df = cos_sims_full_df.sort_values(
-             by=["Sig", "Shift_to_goal_end", "Goal_end_FDR"],
-             ascending=[False, False, True],
-         )
-
-         return cos_sims_full_df
-
-
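For reference, the test used throughout the function above is a Wilcoxon rank-sum of per-gene cosine shifts against a random-perturbation background, followed by Benjamini-Hochberg correction. A self-contained sketch with synthetic values (the deleted module's own `get_fdr` helper applies the same correction; `statsmodels` is used here purely for illustration):

```python
# Standalone sketch of the pattern above: rank-sum vs. a random background,
# then BH FDR. All data here is synthetic.
import numpy as np
from scipy.stats import ranksums
from statsmodels.stats.multitest import multipletests

rng = np.random.default_rng(42)
background = rng.normal(0.0, 0.01, size=5000)  # shifts from random perturbations
per_gene_shifts = {
    f"gene{i}": rng.normal(0.02 * (i % 2), 0.01, size=50) for i in range(10)
}

pvals = [ranksums(background, shifts).pvalue for shifts in per_gene_shifts.values()]
fdrs = multipletests(pvals, method="fdr_bh")[1]  # BH-adjusted values
significant = fdrs < 0.05
```
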
- # stats comparing cos sim shifts of test perturbations vs null distribution
- def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
-     cos_sims_full_df = cos_sims_df.copy()
-
-     cos_sims_full_df["Test_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
-     cos_sims_full_df["Null_avg_shift"] = np.zeros(cos_sims_df.shape[0], dtype=float)
-     cos_sims_full_df["Test_vs_null_avg_shift"] = np.zeros(
-         cos_sims_df.shape[0], dtype=float
-     )
-     cos_sims_full_df["Test_vs_null_pval"] = np.zeros(cos_sims_df.shape[0], dtype=float)
-     cos_sims_full_df["Test_vs_null_FDR"] = np.zeros(cos_sims_df.shape[0], dtype=float)
-     cos_sims_full_df["N_Detections_test"] = np.zeros(
-         cos_sims_df.shape[0], dtype="uint32"
-     )
-     cos_sims_full_df["N_Detections_null"] = np.zeros(
-         cos_sims_df.shape[0], dtype="uint32"
-     )
-
-     for i in trange(cos_sims_df.shape[0]):
-         token = cos_sims_df["Gene"][i]
-         test_shifts = []
-         null_shifts = []
-
-         for dict_i in dict_list:
-             test_shifts += dict_i.get((token, "cell_emb"), [])
-
-         for dict_i in null_dict_list:
-             null_shifts += dict_i.get((token, "cell_emb"), [])
-
-         cos_sims_full_df.loc[i, "Test_avg_shift"] = np.mean(test_shifts)
-         cos_sims_full_df.loc[i, "Null_avg_shift"] = np.mean(null_shifts)
-         cos_sims_full_df.loc[i, "Test_vs_null_avg_shift"] = np.mean(
-             test_shifts
-         ) - np.mean(null_shifts)
-         cos_sims_full_df.loc[i, "Test_vs_null_pval"] = ranksums(
-             test_shifts, null_shifts, nan_policy="omit"
-         ).pvalue
-         # remove nan values
-         cos_sims_full_df.Test_vs_null_pval = np.where(
-             np.isnan(cos_sims_full_df.Test_vs_null_pval),
-             1,
-             cos_sims_full_df.Test_vs_null_pval,
-         )
-         cos_sims_full_df.loc[i, "N_Detections_test"] = len(test_shifts)
-         cos_sims_full_df.loc[i, "N_Detections_null"] = len(null_shifts)
-
-     cos_sims_full_df["Test_vs_null_FDR"] = get_fdr(
-         cos_sims_full_df["Test_vs_null_pval"]
-     )
-
-     cos_sims_full_df["Sig"] = [
-         1 if fdr < 0.05 else 0 for fdr in cos_sims_full_df["Test_vs_null_FDR"]
-     ]
-     cos_sims_full_df = cos_sims_full_df.sort_values(
-         by=["Sig", "Test_vs_null_avg_shift", "Test_vs_null_FDR"],
-         ascending=[False, False, True],
-     )
-     return cos_sims_full_df
-
-
- # stats for identifying perturbations with the largest effect within a given set of cells
- # fits a 2-component mixture model (impact vs. non-impact) and
- # reports the most likely component for each test perturbation
- # Note: because this assumes a given perturbation has a consistent effect in the cells tested,
- # we recommend only using the mixture model strategy with uniform cell populations
- def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
-     names = ["Gene", "Gene_name", "Ensembl_ID"]
-
-     if combos == 0:
-         names += ["Test_avg_shift"]
-     elif combos == 1:
-         names += [
-             "Anchor_shift",
-             "Test_token_shift",
-             "Sum_of_indiv_shifts",
-             "Combo_shift",
-             "Combo_minus_sum_shift",
-         ]
-
-     names += ["Impact_component", "Impact_component_percent"]
-
-     cos_sims_full_df = pd.DataFrame(columns=names)
-     avg_values = []
-     gene_names = []
-
-     for i in trange(cos_sims_df.shape[0]):
-         token = cos_sims_df["Gene"][i]
-         name = cos_sims_df["Gene_name"][i]
-         ensembl_id = cos_sims_df["Ensembl_ID"][i]
-         cos_shift_data = []
-
-         for dict_i in dict_list:
-             if (combos == 0) and (anchor_token is not None):
-                 cos_shift_data += dict_i.get((anchor_token, token), [])
-             else:
-                 cos_shift_data += dict_i.get((token, "cell_emb"), [])
-
-         # Extract values for current gene
-         if combos == 0:
-             test_values = cos_shift_data
-         elif combos == 1:
-             test_values = []
-             for tup in cos_shift_data:
-                 test_values.append(tup[2])
-
-         if len(test_values) > 0:
-             avg_value = np.mean(test_values)
-             avg_values.append(avg_value)
-             gene_names.append(name)
-
-     # fit Gaussian mixture model to dataset of mean for each gene
-     avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
-     gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
-
-     for i in trange(cos_sims_df.shape[0]):
-         token = cos_sims_df["Gene"][i]
-         name = cos_sims_df["Gene_name"][i]
-         ensembl_id = cos_sims_df["Ensembl_ID"][i]
-         cos_shift_data = []
-
-         for dict_i in dict_list:
-             if (combos == 0) and (anchor_token is not None):
-                 cos_shift_data += dict_i.get((anchor_token, token), [])
-             else:
-                 cos_shift_data += dict_i.get((token, "cell_emb"), [])
-
-         if combos == 0:
-             mean_test = np.mean(cos_shift_data)
-             impact_components = [
-                 get_impact_component(value, gm) for value in cos_shift_data
-             ]
-         elif combos == 1:
-             anchor_cos_sim_megalist = [
-                 anchor for anchor, token, combo in cos_shift_data
-             ]
-             token_cos_sim_megalist = [token for anchor, token, combo in cos_shift_data]
-             anchor_plus_token_cos_sim_megalist = [
-                 1 - ((1 - anchor) + (1 - token))
-                 for anchor, token, combo in cos_shift_data
-             ]
-             combo_anchor_token_cos_sim_megalist = [
-                 combo for anchor, token, combo in cos_shift_data
-             ]
-             combo_minus_sum_cos_sim_megalist = [
-                 combo - (1 - ((1 - anchor) + (1 - token)))
-                 for anchor, token, combo in cos_shift_data
-             ]
-
-             mean_anchor = np.mean(anchor_cos_sim_megalist)
-             mean_token = np.mean(token_cos_sim_megalist)
-             mean_sum = np.mean(anchor_plus_token_cos_sim_megalist)
-             mean_test = np.mean(combo_anchor_token_cos_sim_megalist)
-             mean_combo_minus_sum = np.mean(combo_minus_sum_cos_sim_megalist)
-
-             impact_components = [
-                 get_impact_component(value, gm)
-                 for value in combo_anchor_token_cos_sim_megalist
-             ]
-
-         impact_component = get_impact_component(mean_test, gm)
-         impact_component_percent = np.mean(impact_components) * 100
-
-         data_i = [token, name, ensembl_id]
-         if combos == 0:
-             data_i += [mean_test]
-         elif combos == 1:
-             data_i += [
-                 mean_anchor,
-                 mean_token,
-                 mean_sum,
-                 mean_test,
-                 mean_combo_minus_sum,
-             ]
-         data_i += [impact_component, impact_component_percent]
-
-         cos_sims_df_i = pd.DataFrame(dict(zip(names, data_i)), index=[i])
-         cos_sims_full_df = pd.concat([cos_sims_full_df, cos_sims_df_i])
-
-     # quantify number of detections of each gene
-     if anchor_token is None:
-         cos_sims_full_df["N_Detections"] = [
-             n_detections(i, dict_list, "cell", anchor_token)
-             for i in cos_sims_full_df["Gene"]
-         ]
-     else:
-         cos_sims_full_df["N_Detections"] = [
-             n_detections(i, dict_list, "gene", anchor_token)
-             for i in cos_sims_full_df["Gene"]
-         ]
-
-     if combos == 0:
-         cos_sims_full_df = cos_sims_full_df.sort_values(
-             by=["Impact_component", "Test_avg_shift"], ascending=[False, True]
-         )
-     elif combos == 1:
-         cos_sims_full_df = cos_sims_full_df.sort_values(
-             by=["Impact_component", "Combo_minus_sum_shift"], ascending=[False, True]
-         )
-     return cos_sims_full_df
-
-
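For intuition, the mixture-model step above reduces to fitting two Gaussian components to the per-gene mean shifts and assigning each perturbation to the impact or non-impact component. A synthetic sketch of that step (`get_impact_component` itself lives in the deleted `perturber_utils.py`; all values below are made up):

```python
# Synthetic sketch of the mixture-model step: two Gaussian components over
# per-gene mean cosine similarities; the component with the lower mean
# (i.e. the larger shift) plays the role of the "impact" component.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
mean_shifts = np.concatenate(
    [rng.normal(0.99, 0.002, 900),   # non-impact genes: cos sim near 1
     rng.normal(0.93, 0.010, 100)]   # impact genes: larger shift
).reshape(-1, 1)
gm = GaussianMixture(n_components=2, random_state=0).fit(mean_shifts)

components = gm.predict(mean_shifts)
impact_component = int(np.argmin(gm.means_.ravel()))
n_impact = int((components == impact_component).sum())
```
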
- class InSilicoPerturberStats:
-     valid_option_dict = {
-         "mode": {
-             "goal_state_shift",
-             "vs_null",
-             "mixture_model",
-             "aggregate_data",
-             "aggregate_gene_shifts",
-         },
-         "genes_perturbed": {"all", list},
-         "combos": {0, 1},
-         "anchor_gene": {None, str},
-         "cell_states_to_model": {None, dict},
-         "pickle_suffix": {None, str},
-     }
-
-     def __init__(
-         self,
-         mode="mixture_model",
-         genes_perturbed="all",
-         combos=0,
-         anchor_gene=None,
-         cell_states_to_model=None,
-         pickle_suffix="_raw.pickle",
-         token_dictionary_file=TOKEN_DICTIONARY_FILE,
-         gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
-     ):
- """
693
- Initialize in silico perturber stats generator.
694
-
695
- **Parameters:**
696
-
697
- mode : {"goal_state_shift", "vs_null", "mixture_model", "aggregate_data", "aggregate_gene_shifts"}
698
- | Type of stats.
699
- | "goal_state_shift": perturbation vs. random for desired cell state shift
700
- | "vs_null": perturbation vs. null from provided null distribution dataset
701
- | "mixture_model": perturbation in impact vs. no impact component of mixture model (no goal direction)
702
- | "aggregate_data": aggregates cosine shifts for single perturbation in multiple cells
703
- | "aggregate_gene_shifts": aggregates cosine shifts of genes in response to perturbation(s)
704
- genes_perturbed : "all", list
705
- | Genes perturbed in isp experiment.
706
- | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell).
707
- | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together.
708
- combos : {0,1,2}
709
- | Whether genex perturbed in isp experiment were perturbed individually (0), in pairs (1), or in triplets (2).
710
- anchor_gene : None, str
711
- | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
712
- | For example, if combos=1 and anchor_gene="ENSG00000136574":
713
- | analyzes data for anchor gene perturbed in combination with each other gene.
714
- | However, if combos=0 and anchor_gene="ENSG00000136574":
715
- | analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
716
- cell_states_to_model: None, dict
717
- | Cell states to model if testing perturbations that achieve goal state change.
718
- | Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
719
- | state_key: key specifying name of column in .dataset that defines the start/goal states
720
- | start_state: value in the state_key column that specifies the start state
721
- | goal_state: value in the state_key column taht specifies the goal end state
722
- | alt_states: list of values in the state_key column that specify the alternate end states
723
- | For example: {"state_key": "disease",
724
- | "start_state": "dcm",
725
- | "goal_state": "nf",
726
- | "alt_states": ["hcm", "other1", "other2"]}
727
- token_dictionary_file : Path
728
- | Path to pickle file containing token dictionary (Ensembl ID:token).
729
- gene_name_id_dictionary_file : Path
730
- | Path to pickle file containing gene name to ID dictionary (gene name:Ensembl ID).
731
- """
732
-
733
-         self.mode = mode
-         self.genes_perturbed = genes_perturbed
-         self.combos = combos
-         self.anchor_gene = anchor_gene
-         self.cell_states_to_model = cell_states_to_model
-         self.pickle_suffix = pickle_suffix
-
-         self.validate_options()
-
-         # load token dictionary (Ensembl IDs:token)
-         with open(token_dictionary_file, "rb") as f:
-             self.gene_token_dict = pickle.load(f)
-
-         # load gene name dictionary (gene name:Ensembl ID)
-         with open(gene_name_id_dictionary_file, "rb") as f:
-             self.gene_name_id_dict = pickle.load(f)
-
-         if anchor_gene is None:
-             self.anchor_token = None
-         else:
-             self.anchor_token = self.gene_token_dict[self.anchor_gene]
-
-     def validate_options(self):
-         for attr_name, valid_options in self.valid_option_dict.items():
-             attr_value = self.__dict__[attr_name]
-             if type(attr_value) not in {list, dict}:
-                 if attr_name in {"anchor_gene"}:
-                     continue
-                 elif attr_value in valid_options:
-                     continue
-             valid_type = False
-             for option in valid_options:
-                 if (option in [str, int, list, dict]) and isinstance(
-                     attr_value, option
-                 ):
-                     valid_type = True
-                     break
-             if not valid_type:
-                 logger.error(
-                     f"Invalid option for {attr_name}. "
-                     f"Valid options for {attr_name}: {valid_options}"
-                 )
-                 raise ValueError(f"Invalid option for {attr_name}.")
-
-         if self.cell_states_to_model is not None:
-             if len(self.cell_states_to_model.items()) == 1:
-                 logger.warning(
-                     "The single value dictionary for cell_states_to_model will be "
-                     "replaced with a dictionary with named keys for start, goal, and alternate states. "
-                     "Please specify state_key, start_state, goal_state, and alt_states "
-                     "in the cell_states_to_model dictionary for future use. "
-                     "For example, cell_states_to_model={"
-                     "'state_key': 'disease', "
-                     "'start_state': 'dcm', "
-                     "'goal_state': 'nf', "
-                     "'alt_states': ['hcm', 'other1', 'other2']}"
-                 )
-                 for key, value in self.cell_states_to_model.items():
-                     if (len(value) == 3) and isinstance(value, tuple):
-                         if (
-                             isinstance(value[0], list)
-                             and isinstance(value[1], list)
-                             and isinstance(value[2], list)
-                         ):
-                             if len(value[0]) == 1 and len(value[1]) == 1:
-                                 all_values = value[0] + value[1] + value[2]
-                                 if len(all_values) == len(set(all_values)):
-                                     continue
-                 # reformat to the new named key format
-                 state_values = flatten_list(list(self.cell_states_to_model.values()))
-                 self.cell_states_to_model = {
-                     "state_key": list(self.cell_states_to_model.keys())[0],
-                     "start_state": state_values[0][0],
-                     "goal_state": state_values[1][0],
-                     "alt_states": state_values[2:][0],
-                 }
-             elif set(self.cell_states_to_model.keys()) == {
-                 "state_key",
-                 "start_state",
-                 "goal_state",
-                 "alt_states",
-             }:
-                 if (
-                     (self.cell_states_to_model["state_key"] is None)
-                     or (self.cell_states_to_model["start_state"] is None)
-                     or (self.cell_states_to_model["goal_state"] is None)
-                 ):
-                     logger.error(
-                         "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
-                     )
-                     raise ValueError(
-                         "Missing 'state_key', 'start_state', or 'goal_state' in cell_states_to_model."
-                     )
-
-                 if (
-                     self.cell_states_to_model["start_state"]
-                     == self.cell_states_to_model["goal_state"]
-                 ):
-                     logger.error("All states must be unique.")
-                     raise ValueError("All states must be unique.")
-
-                 if self.cell_states_to_model["alt_states"] is not None:
-                     if not isinstance(self.cell_states_to_model["alt_states"], list):
-                         logger.error(
-                             "self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
-                         )
-                         raise ValueError(
-                             "cell_states_to_model['alt_states'] must be a list."
-                         )
-                     if len(self.cell_states_to_model["alt_states"]) != len(
-                         set(self.cell_states_to_model["alt_states"])
-                     ):
-                         logger.error("All states must be unique.")
-                         raise ValueError("All states must be unique.")
-
-             elif set(self.cell_states_to_model.keys()) == {
-                 "state_key",
-                 "start_state",
-                 "goal_state",
-             }:
-                 self.cell_states_to_model["alt_states"] = []
-             else:
-                 logger.error(
-                     "cell_states_to_model must only have the following four keys: "
-                     "'state_key', 'start_state', 'goal_state', 'alt_states'. "
-                     "For example, cell_states_to_model={"
-                     "'state_key': 'disease', "
-                     "'start_state': 'dcm', "
-                     "'goal_state': 'nf', "
-                     "'alt_states': ['hcm', 'other1', 'other2']}"
-                 )
-                 raise ValueError("Invalid keys in cell_states_to_model.")
-
-             if self.anchor_gene is not None:
-                 self.anchor_gene = None
-                 logger.warning(
-                     "anchor_gene set to None. "
-                     "Currently, anchor gene not available "
-                     "when modeling multiple cell states."
-                 )
-
-         if self.combos > 0:
-             if self.anchor_gene is None:
-                 logger.error(
-                     "Currently, stats are only supported for combination "
-                     "in silico perturbation run with anchor gene. Please add "
-                     "anchor gene when using with combos > 0. "
-                 )
-                 raise ValueError("Missing anchor gene for combos > 0.")
-
-         if (self.mode == "mixture_model") and (self.genes_perturbed != "all"):
-             logger.error(
-                 "Mixture model mode requires multiple gene perturbations to fit model "
-                 "so is incompatible with a single grouped perturbation."
-             )
-             raise ValueError(
-                 "Mixture model mode is incompatible with a single grouped perturbation."
-             )
-         if (self.mode == "aggregate_data") and (self.genes_perturbed == "all"):
-             logger.error(
-                 "Simple data aggregation mode is for single perturbation in multiple cells "
-                 "so is incompatible with genes_perturbed being 'all'."
-             )
-             raise ValueError(
-                 "Aggregate data mode is incompatible with genes_perturbed == 'all'."
-             )
-
-     def get_stats(
-         self,
-         input_data_directory,
-         null_dist_data_directory,
-         output_directory,
-         output_prefix,
-         null_dict_list=None,
-     ):
-         """
-         Get stats for in silico perturbation data and save as results in output_directory.
-
-         **Parameters:**
-
-         input_data_directory : Path
-             | Path to directory containing cos_sim dictionary inputs
-         null_dist_data_directory : Path
-             | Path to directory containing null distribution cos_sim dictionary inputs
-         output_directory : Path
-             | Path to directory where perturbation data will be saved as .csv
-         output_prefix : str
-             | Prefix for output .csv
-         null_dict_list : list[dict]
-             | List of loaded null distribution dictionaries if more than one comparison vs. the null is to be performed
-
-         **Outputs:**
-
-         Definition of possible columns in .csv output file.
-
-         | Of note, not all columns will be present in all output files.
-         | Some columns are specific to particular perturbation modes.
-
-         | "Gene": gene token
-         | "Gene_name": gene name
-         | "Ensembl_ID": gene Ensembl ID
-         | "N_Detections": number of cells in which each gene or gene combination was detected in the input dataset
-         | "Sig": 1 if FDR<0.05, otherwise 0
-
-         | "Shift_to_goal_end": cosine shift from start state towards goal end state in response to given perturbation
-         | "Shift_to_alt_end": cosine shift from start state towards alternate end state in response to given perturbation
-         | "Goal_end_vs_random_pval": pvalue of cosine shift from start state towards goal end state by Wilcoxon
-         |     pvalue compares shift caused by perturbing given gene compared to random genes
-         | "Alt_end_vs_random_pval": pvalue of cosine shift from start state towards alternate end state by Wilcoxon
-         |     pvalue compares shift caused by perturbing given gene compared to random genes
-         | "Goal_end_FDR": Benjamini-Hochberg correction of "Goal_end_vs_random_pval"
-         | "Alt_end_FDR": Benjamini-Hochberg correction of "Alt_end_vs_random_pval"
-
-         | "Test_avg_shift": cosine shift in response to given perturbation in cells from test distribution
-         | "Null_avg_shift": cosine shift in response to given perturbation in cells from null distribution (e.g. random cells)
-         | "Test_vs_null_avg_shift": difference in cosine shift in cells from test vs. null distribution
-         |     (i.e. "Test_avg_shift" minus "Null_avg_shift")
-         | "Test_vs_null_pval": pvalue of cosine shift in test vs. null distribution
-         | "Test_vs_null_FDR": Benjamini-Hochberg correction of "Test_vs_null_pval"
-         | "N_Detections_test": "N_Detections" in cells from test distribution
-         | "N_Detections_null": "N_Detections" in cells from null distribution
-
-         | "Anchor_shift": cosine shift in response to given perturbation of anchor gene
-         | "Test_token_shift": cosine shift in response to given perturbation of test gene
-         | "Sum_of_indiv_shifts": sum of cosine shifts in response to individually perturbing test and anchor genes
-         | "Combo_shift": cosine shift in response to given perturbation of both anchor and test gene(s) in combination
-         | "Combo_minus_sum_shift": difference of cosine shifts in response to combo perturbation vs. sum of individual perturbations
-         |     (i.e. "Combo_shift" minus "Sum_of_indiv_shifts")
-         | "Impact_component": whether the given perturbation was modeled to be within the impact component by the mixture model
-         |     1: within impact component; 0: not within impact component
-         | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
-
-         | In case of aggregating data / gene shifts:
-         | "Perturbed": ID(s) of gene(s) being perturbed
-         | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
-         | "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
-         | "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
-         """
-
-         if self.mode not in [
-             "goal_state_shift",
-             "vs_null",
-             "mixture_model",
-             "aggregate_data",
-             "aggregate_gene_shifts",
-         ]:
-             logger.error(
-                 "Currently, only modes available are stats for goal_state_shift, "
-                 "vs_null (comparing to null distribution), "
-                 "mixture_model (fitting mixture model for perturbations with or without impact), "
-                 "and aggregating data for single perturbations or for gene embedding shifts."
-             )
-             raise ValueError(f"Invalid mode: {self.mode}")
-
-         self.gene_token_id_dict = invert_dict(self.gene_token_dict)
-         self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
-
-         # obtain total gene list
-         if (self.combos == 0) and (self.anchor_token is not None):
-             # cos sim data for effect of gene perturbation on the embedding of each other gene
-             dict_list = read_dictionaries(
-                 input_data_directory,
-                 "gene",
-                 self.anchor_token,
-                 self.cell_states_to_model,
-                 self.pickle_suffix,
-             )
-             gene_list = get_gene_list(dict_list, "gene")
-         elif (
-             (self.combos == 0)
-             and (self.anchor_token is None)
-             and (self.mode == "aggregate_gene_shifts")
-         ):
-             dict_list = read_dictionaries(
-                 input_data_directory,
-                 "gene",
-                 self.anchor_token,
-                 self.cell_states_to_model,
-                 self.pickle_suffix,
-             )
-             gene_list = get_gene_list(dict_list, "cell")
-         else:
-             # cos sim data for effect of gene perturbation on the embedding of each cell
-             dict_list = read_dictionaries(
-                 input_data_directory,
-                 "cell",
-                 self.anchor_token,
-                 self.cell_states_to_model,
-                 self.pickle_suffix,
-             )
-             gene_list = get_gene_list(dict_list, "cell")
-
-         # initiate results dataframe
-         cos_sims_df_initial = pd.DataFrame(
-             {
-                 "Gene": gene_list,
-                 "Gene_name": [self.token_to_gene_name(item) for item in gene_list],
-                 "Ensembl_ID": [
-                     token_tuple_to_ensembl_ids(genes, self.gene_token_id_dict)
-                     if self.genes_perturbed != "all"
-                     else self.gene_token_id_dict[genes[1]]
-                     if isinstance(genes, tuple)
-                     else self.gene_token_id_dict[genes]
-                     for genes in gene_list
-                 ],
-             },
-             index=[i for i in range(len(gene_list))],
-         )
-
-         if self.mode == "goal_state_shift":
-             cos_sims_df = isp_stats_to_goal_state(
-                 cos_sims_df_initial,
-                 dict_list,
-                 self.cell_states_to_model,
-                 self.genes_perturbed,
-             )
-
-         elif self.mode == "vs_null":
-             if null_dict_list is None:
-                 null_dict_list = read_dictionaries(
-                     null_dist_data_directory,
-                     "cell",
-                     self.anchor_token,
-                     self.cell_states_to_model,
-                     self.pickle_suffix,
-                 )
-             cos_sims_df = isp_stats_vs_null(
-                 cos_sims_df_initial, dict_list, null_dict_list
-             )
-
-         elif self.mode == "mixture_model":
-             cos_sims_df = isp_stats_mixture_model(
-                 cos_sims_df_initial, dict_list, self.combos, self.anchor_token
-             )
-
-         elif self.mode == "aggregate_data":
-             cos_sims_df = isp_aggregate_grouped_perturb(
-                 cos_sims_df_initial, dict_list, self.genes_perturbed
-             )
-
-         elif self.mode == "aggregate_gene_shifts":
-             if (self.genes_perturbed == "all") and (self.combos == 0):
-                 tuple_types = [
-                     True if isinstance(genes, tuple) else False for genes in gene_list
-                 ]
-                 if all(tuple_types):
-                     token_dtype = "tuple"
-                 elif not any(tuple_types):
-                     token_dtype = "nontuple"
-                 else:
-                     token_dtype = "mix"
-             else:
-                 token_dtype = "mix"
-
-             cos_sims_df = isp_aggregate_gene_shifts(
-                 cos_sims_df_initial,
-                 dict_list,
-                 self.gene_token_id_dict,
-                 self.gene_id_name_dict,
-                 token_dtype,
-             )
-
-         # save perturbation stats to output_path
-         output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
-         cos_sims_df.to_csv(output_path)
-
-     def token_to_gene_name(self, item):
-         if np.issubdtype(type(item), np.integer):
-             return self.gene_id_name_dict.get(
-                 self.gene_token_id_dict.get(item, np.nan), np.nan
-             )
-         if isinstance(item, tuple):
-             return tuple(
-                 [
-                     self.gene_id_name_dict.get(
-                         self.gene_token_id_dict.get(i, np.nan), np.nan
-                     )
-                     for i in item
-                 ]
-             )
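A minimal usage sketch for the deleted class, based only on the constructor and `get_stats` signatures above; the import assumes the package-level re-export from the (also deleted) `geneformer/__init__.py`, and the directory paths and output prefix are hypothetical placeholders:

```python
# Minimal usage sketch for the deleted InSilicoPerturberStats class.
# Paths and prefix below are placeholders, not values from this repo.
from geneformer import InSilicoPerturberStats

ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",   # one of the five modes in valid_option_dict
    genes_perturbed="all",     # must match the in silico perturbation run
    combos=0,                  # individual perturbations
    cell_states_to_model={
        "state_key": "disease",
        "start_state": "dcm",
        "goal_state": "nf",
        "alt_states": ["hcm", "other1", "other2"],
    },
)

# reads the cos_sim pickle dictionaries written by the perturber and writes
# <output_prefix>.csv ranking genes by shift toward the goal state
ispstats.get_stats(
    "path/to/isp_output",     # input_data_directory
    None,                     # null_dist_data_directory (only used in vs_null mode)
    "path/to/stats_output",   # output_directory
    "dcm_to_nf_stats",        # output_prefix
)
```
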
geneformer/mtl/__init__.py DELETED
@@ -1 +0,0 @@
- # ruff: noqa: F401
 
 
geneformer/mtl/collators.py DELETED
@@ -1,76 +0,0 @@
- # imports
- import torch
- import pickle
- from ..collator_for_classification import DataCollatorForGeneClassification
- from .. import TOKEN_DICTIONARY_FILE
-
- """Geneformer collator for multi-task cell classification."""
-
- class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
-     class_type = "cell"
-
-     @staticmethod
-     def load_token_dictionary():
-         with open(TOKEN_DICTIONARY_FILE, 'rb') as f:
-             return pickle.load(f)
-
-     def __init__(self, *args, **kwargs) -> None:
-         # Load the token dictionary
-         token_dictionary = self.load_token_dictionary()
-         # Use the loaded token dictionary
-         super().__init__(token_dictionary=token_dictionary, *args, **kwargs)
-
-     def _prepare_batch(self, features):
-         # Process inputs as usual
-         batch = self.tokenizer.pad(
-             features,
-             class_type=self.class_type,
-             padding=self.padding,
-             max_length=self.max_length,
-             pad_to_multiple_of=self.pad_to_multiple_of,
-             return_tensors="pt",
-         )
-
-         # Check if labels are present
-         if "label" in features[0]:
-             # Initialize labels dictionary for all tasks
-             labels = {task: [] for task in features[0]["label"].keys()}
-             # Populate labels for each task
-             for feature in features:
-                 for task, label in feature["label"].items():
-                     labels[task].append(label)
-
-             # Convert label lists to tensors, handling dictionaries appropriately
-             for task in labels:
-                 if isinstance(labels[task][0], (list, torch.Tensor)):
-                     dtype = torch.long
-                     labels[task] = torch.tensor(labels[task], dtype=dtype)
-                 elif isinstance(labels[task][0], dict):
-                     # Handle dict specifically if needed
-                     pass  # Resolve nested data structure
-
-             # Update the batch to include task-specific labels
-             batch["labels"] = labels
-         else:
-             # If no labels are present, create empty labels for all tasks
-             batch["labels"] = {
-                 task: torch.tensor([], dtype=torch.long)
-                 for task in features[0]["input_ids"].keys()
-             }
-
-         return batch
-
-     def __call__(self, features):
-         batch = self._prepare_batch(features)
-         for k, v in batch.items():
-             if torch.is_tensor(v):
-                 batch[k] = v.clone().detach()
-             elif isinstance(v, dict):
-                 # Assuming nested structure needs conversion
-                 batch[k] = {
-                     task: torch.tensor(labels, dtype=torch.int64)
-                     for task, labels in v.items()
-                 }
-             else:
-                 batch[k] = torch.tensor(v, dtype=torch.int64)
-         return batch
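A toy invocation of the deleted collator, as a sketch only: the feature layout (per-cell `input_ids` plus a per-task `label` dict) mirrors what `transform_dataset` in `geneformer/mtl/data.py` produced, and it assumes the packaged token dictionary referenced by `TOKEN_DICTIONARY_FILE` is available at import time.

```python
# Hypothetical toy batch for the deleted multi-task collator.
import torch

features = [
    {"input_ids": torch.tensor([5, 17, 42]), "label": {"task1": 0, "task2": 2}},
    {"input_ids": torch.tensor([5, 99]),     "label": {"task1": 1, "task2": 0}},
]

collator = DataCollatorForMultitaskCellClassification()
batch = collator(features)
# batch["input_ids"] is padded to a common length and
# batch["labels"] maps each task name to a 1-D label tensor
```
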
geneformer/mtl/data.py DELETED
@@ -1,162 +0,0 @@
- import os
- from .collators import DataCollatorForMultitaskCellClassification
- from .imports import *
-
- def validate_columns(dataset, required_columns, dataset_type):
-     """Ensures required columns are present in the dataset."""
-     missing_columns = [col for col in required_columns if col not in dataset.column_names]
-     if missing_columns:
-         raise KeyError(
-             f"Missing columns in {dataset_type} dataset: {missing_columns}. "
-             f"Available columns: {dataset.column_names}"
-         )
-
-
- def create_label_mappings(dataset, task_to_column):
-     """Creates label mappings for the dataset."""
-     task_label_mappings = {}
-     num_labels_list = []
-     for task, column in task_to_column.items():
-         unique_values = sorted(set(dataset[column]))
-         mapping = {label: idx for idx, label in enumerate(unique_values)}
-         task_label_mappings[task] = mapping
-         num_labels_list.append(len(unique_values))
-     return task_label_mappings, num_labels_list
-
-
- def save_label_mappings(mappings, path):
-     """Saves label mappings to a pickle file."""
-     with open(path, "wb") as f:
-         pickle.dump(mappings, f)
-
-
- def load_label_mappings(path):
-     """Loads label mappings from a pickle file."""
-     with open(path, "rb") as f:
-         return pickle.load(f)
-
-
- def transform_dataset(dataset, task_to_column, task_label_mappings, config, is_test):
-     """Transforms the dataset to the required format."""
-     transformed_dataset = []
-     cell_id_mapping = {}
-
-     for idx, record in enumerate(dataset):
-         transformed_record = {
-             "input_ids": torch.tensor(record["input_ids"], dtype=torch.long),
-             "cell_id": idx,  # Index-based cell ID
-         }
-
-         if not is_test:
-             label_dict = {
-                 task: task_label_mappings[task][record[column]]
-                 for task, column in task_to_column.items()
-             }
-         else:
-             label_dict = {task: -1 for task in config["task_names"]}
-
-         transformed_record["label"] = label_dict
-         transformed_dataset.append(transformed_record)
-         cell_id_mapping[idx] = record.get("unique_cell_id", idx)
-
-     return transformed_dataset, cell_id_mapping
-
-
- def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""):
-     """Main function to load and preprocess data."""
-     try:
-         dataset = load_from_disk(dataset_path)
-
-         # Setup task and column mappings
-         task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))]
-         task_to_column = dict(zip(task_names, config["task_columns"]))
-         config["task_names"] = task_names
-
-         label_mappings_path = os.path.join(
-             config["results_dir"],
-             f"task_label_mappings{'_val' if dataset_type == 'validation' else ''}.pkl"
-         )
-
-         if not is_test:
-             validate_columns(dataset, task_to_column.values(), dataset_type)
-
-             # Create and save label mappings
-             task_label_mappings, num_labels_list = create_label_mappings(dataset, task_to_column)
-             save_label_mappings(task_label_mappings, label_mappings_path)
-         else:
-             # Load existing mappings for test data
-             task_label_mappings = load_label_mappings(label_mappings_path)
-             num_labels_list = [len(mapping) for mapping in task_label_mappings.values()]
-
-         # Transform dataset
-         transformed_dataset, cell_id_mapping = transform_dataset(
-             dataset, task_to_column, task_label_mappings, config, is_test
-         )
-
-         return transformed_dataset, cell_id_mapping, num_labels_list
-
-     except KeyError as e:
-         raise ValueError(f"Configuration error or dataset key missing: {e}")
-     except Exception as e:
-         raise RuntimeError(f"Error during data loading or preprocessing: {e}")
-
-
- def preload_and_process_data(config):
-     """Preloads and preprocesses train and validation datasets."""
-     # Process train data and save mappings
-     train_data = load_and_preprocess_data(config["train_path"], config, dataset_type="train")
-
-     # Process validation data and save mappings
-     val_data = load_and_preprocess_data(config["val_path"], config, dataset_type="validation")
-
-     # Validate that the mappings match
-     validate_label_mappings(config)
-
-     return (*train_data, *val_data[:2])  # Return train and val data along with mappings
-
-
- def validate_label_mappings(config):
-     """Ensures train and validation label mappings are consistent."""
-     train_mappings_path = os.path.join(config["results_dir"], "task_label_mappings.pkl")
-     val_mappings_path = os.path.join(config["results_dir"], "task_label_mappings_val.pkl")
-     train_mappings = load_label_mappings(train_mappings_path)
-     val_mappings = load_label_mappings(val_mappings_path)
-
-     for task_name in config["task_names"]:
-         if train_mappings[task_name] != val_mappings[task_name]:
-             raise ValueError(
-                 f"Mismatch in label mappings for task '{task_name}'.\n"
-                 f"Train Mapping: {train_mappings[task_name]}\n"
-                 f"Validation Mapping: {val_mappings[task_name]}"
-             )
-
-
- def get_data_loader(preprocessed_dataset, batch_size):
-     """Creates a DataLoader with optimal settings."""
-     return DataLoader(
-         preprocessed_dataset,
-         batch_size=batch_size,
-         shuffle=True,
-         collate_fn=DataCollatorForMultitaskCellClassification(),
-         num_workers=os.cpu_count(),
-         pin_memory=True,
-     )
-
-
- def preload_data(config):
-     """Preprocesses train and validation data for trials."""
-     # preload_and_process_data returns (train_ds, train_cell_map, num_labels, val_ds, val_cell_map);
-     # only the datasets are passed to get_data_loader (the original starred call
-     # passed an extra positional argument and would have raised a TypeError)
-     train_ds, _, _, val_ds, _ = preload_and_process_data(config)
-     train_loader = get_data_loader(train_ds, config["batch_size"])
-     val_loader = get_data_loader(val_ds, config["batch_size"])
-     return train_loader, val_loader
-
-
- def load_and_preprocess_test_data(config):
-     """Loads and preprocesses test data."""
-     return load_and_preprocess_data(config["test_path"], config, is_test=True)
-
-
- def prepare_test_loader(config):
-     """Prepares DataLoader for test data."""
-     test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config)
-     test_loader = get_data_loader(test_dataset, config["batch_size"])
-     return test_loader, cell_id_mapping, num_labels_list
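A hypothetical minimal config for these data utilities; the key names are exactly those read by the code above, while every value is an illustrative placeholder:

```python
# Hypothetical config sketch for the deleted mtl data utilities.
config = {
    "train_path": "path/to/train_dataset",     # HuggingFace dataset saved to disk
    "val_path": "path/to/val_dataset",
    "test_path": "path/to/test_dataset",
    "task_columns": ["cell_type", "disease"],  # one column per classification task
    "results_dir": "path/to/results",          # label mappings + predictions land here
    "batch_size": 8,
}

train_loader, val_loader = preload_data(config)
test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
```
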
geneformer/mtl/eval_utils.py DELETED
@@ -1,88 +0,0 @@
- import pandas as pd
-
- from .imports import *  # noqa # isort:skip
- from .data import prepare_test_loader  # noqa # isort:skip
- from .model import GeneformerMultiTask
-
-
- def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config):
-     task_pred_labels = {task_name: [] for task_name in config["task_names"]}
-     task_pred_probs = {task_name: [] for task_name in config["task_names"]}
-     cell_ids = []
-
-     # # Load task label mappings from pickle file
-     # with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f:
-     #     task_label_mappings = pickle.load(f)
-
-     model.eval()
-     with torch.no_grad():
-         for batch in test_loader:
-             input_ids = batch["input_ids"].to(device)
-             attention_mask = batch["attention_mask"].to(device)
-             _, logits, _ = model(input_ids, attention_mask)
-             for sample_idx in range(len(batch["input_ids"])):
-                 cell_id = cell_id_mapping[batch["cell_id"][sample_idx].item()]
-                 cell_ids.append(cell_id)
-                 for i, task_name in enumerate(config["task_names"]):
-                     pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
-                     pred_prob = (
-                         torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
-                     )
-                     task_pred_labels[task_name].append(pred_label)
-                     task_pred_probs[task_name].append(pred_prob)
-
-     # Save test predictions with cell IDs and probabilities to CSV
-     test_results_dir = config["results_dir"]
-     os.makedirs(test_results_dir, exist_ok=True)
-     test_preds_file = os.path.join(test_results_dir, "test_preds.csv")
-
-     rows = []
-     for sample_idx in range(len(cell_ids)):
-         row = {"Cell ID": cell_ids[sample_idx]}
-         for task_name in config["task_names"]:
-             row[f"{task_name} Prediction"] = task_pred_labels[task_name][sample_idx]
-             row[f"{task_name} Probabilities"] = ",".join(
-                 map(str, task_pred_probs[task_name][sample_idx])
-             )
-         rows.append(row)
-
-     df = pd.DataFrame(rows)
-     df.to_csv(test_preds_file, index=False)
-     print(f"Test predictions saved to {test_preds_file}")
-
-
- def load_and_evaluate_test_model(config):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config)
-     model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask")
-     hyperparams_path = os.path.join(model_directory, "hyperparameters.json")
-
-     # Load the saved best hyperparameters
-     with open(hyperparams_path, "r") as f:
-         best_hyperparams = json.load(f)
-
-     # Extract the task weights if present, otherwise set to None
-     task_weights = best_hyperparams.get("task_weights", None)
-     normalized_task_weights = task_weights if task_weights else []
-
-     # Print the loaded hyperparameters
-     print("Loaded hyperparameters:")
-     for param, value in best_hyperparams.items():
-         if param == "task_weights":
-             print(f"normalized_task_weights: {value}")
-         else:
-             print(f"{param}: {value}")
-
-     best_model_path = os.path.join(model_directory, "pytorch_model.bin")
-     best_model = GeneformerMultiTask(
-         config["pretrained_path"],
-         num_labels_list,
-         dropout_rate=best_hyperparams["dropout_rate"],
-         use_task_weights=config["use_task_weights"],
-         task_weights=normalized_task_weights,
-     )
-     best_model.load_state_dict(torch.load(best_model_path))
-     best_model.to(device)
-
-     evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config)
-     print("Evaluation completed.")
geneformer/mtl/imports.py DELETED
@@ -1,43 +0,0 @@
- import functools
- import gc
- import json
- import os
- import pickle
- import sys
- import warnings
- from enum import Enum
- from itertools import chain
- from typing import Dict, List, Optional, Union
-
- import numpy as np
- import optuna
- import pandas as pd
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from datasets import load_from_disk
- from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, roc_curve
- from sklearn.model_selection import train_test_split
- from sklearn.preprocessing import LabelEncoder
- from torch.utils.data import DataLoader
- from transformers import (
-     AdamW,
-     BatchEncoding,
-     BertConfig,
-     BertModel,
-     DataCollatorForTokenClassification,
-     SpecialTokensMixin,
-     get_cosine_schedule_with_warmup,
-     get_linear_schedule_with_warmup,
-     get_scheduler,
- )
- from transformers.utils import logging, to_py_obj
-
- from .collators import DataCollatorForMultitaskCellClassification
-
- # local modules
- from .data import get_data_loader, preload_and_process_data
- from .model import GeneformerMultiTask
- from .optuna_utils import create_optuna_study
- from .utils import save_model
geneformer/mtl/model.py DELETED
@@ -1,121 +0,0 @@
- import torch
- import torch.nn as nn
- from transformers import BertConfig, BertModel
-
-
- class AttentionPool(nn.Module):
-     """Attention-based pooling layer."""
-
-     def __init__(self, hidden_size):
-         super(AttentionPool, self).__init__()
-         self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
-         nn.init.xavier_uniform_(
-             self.attention_weights
-         )  # https://pytorch.org/docs/stable/nn.init.html
-
-     def forward(self, hidden_states):
-         attention_scores = torch.matmul(hidden_states, self.attention_weights)
-         attention_scores = torch.softmax(attention_scores, dim=1)
-         pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
-         return pooled_output
-
-
- class GeneformerMultiTask(nn.Module):
-     def __init__(
-         self,
-         pretrained_path,
-         num_labels_list,
-         dropout_rate=0.1,
-         use_task_weights=False,
-         task_weights=None,
-         max_layers_to_freeze=0,
-         use_attention_pooling=False,
-     ):
-         super(GeneformerMultiTask, self).__init__()
-         self.config = BertConfig.from_pretrained(pretrained_path)
-         self.bert = BertModel(self.config)
-         self.num_labels_list = num_labels_list
-         self.use_task_weights = use_task_weights
-         self.dropout = nn.Dropout(dropout_rate)
-         self.use_attention_pooling = use_attention_pooling
-
-         if use_task_weights and (
-             task_weights is None or len(task_weights) != len(num_labels_list)
-         ):
-             raise ValueError(
-                 "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
-             )
-         self.task_weights = (
-             task_weights if use_task_weights else [1.0] * len(num_labels_list)
-         )
-
-         # Freeze the specified initial layers
-         for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
-             for param in layer.parameters():
-                 param.requires_grad = False
-
-         self.attention_pool = (
-             AttentionPool(self.config.hidden_size) if use_attention_pooling else None
-         )
-
-         self.classification_heads = nn.ModuleList(
-             [
-                 nn.Linear(self.config.hidden_size, num_labels)
-                 for num_labels in num_labels_list
-             ]
-         )
-         # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html
-         for head in self.classification_heads:
-             nn.init.xavier_uniform_(head.weight)
-             nn.init.zeros_(head.bias)
-
-     def forward(self, input_ids, attention_mask, labels=None):
-         try:
-             outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
-         except Exception as e:
-             raise RuntimeError(f"Error during BERT forward pass: {e}")
-
-         sequence_output = outputs.last_hidden_state
-
-         try:
-             pooled_output = (
-                 self.attention_pool(sequence_output)
-                 if self.use_attention_pooling
-                 else sequence_output[:, 0, :]
-             )
-             pooled_output = self.dropout(pooled_output)
-         except Exception as e:
-             raise RuntimeError(f"Error during pooling and dropout: {e}")
-
-         total_loss = 0
-         logits = []
-         losses = []
-
-         for task_id, (head, num_labels) in enumerate(
-             zip(self.classification_heads, self.num_labels_list)
-         ):
-             try:
-                 task_logits = head(pooled_output)
-             except Exception as e:
-                 raise RuntimeError(
-                     f"Error during forward pass of classification head {task_id}: {e}"
-                 )
-
-             logits.append(task_logits)
-
-             if labels is not None:
-                 try:
-                     loss_fct = nn.CrossEntropyLoss()
-                     task_loss = loss_fct(
-                         task_logits.view(-1, num_labels), labels[task_id].view(-1)
-                     )
-                     if self.use_task_weights:
-                         task_loss *= self.task_weights[task_id]
-                     total_loss += task_loss
-                     losses.append(task_loss.item())
-                 except Exception as e:
-                     raise RuntimeError(
-                         f"Error during loss computation for task {task_id}: {e}"
-                     )
-
-         # note: due to operator precedence this parses as a 3-tuple in both cases,
-         # i.e. (total_loss, logits, losses) with labels and (total_loss, logits, logits)
-         # without; callers in train.py and eval_utils.py unpack three values accordingly
-         return total_loss, logits, losses if labels is not None else logits
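A hypothetical smoke test for the deleted model; the pretrained path is a placeholder for any directory containing a BERT `config.json`, and the tiny batch shapes are illustrative only:

```python
# Smoke-test sketch for GeneformerMultiTask (placeholder path, synthetic batch).
import torch

model = GeneformerMultiTask(
    "path/to/pretrained_geneformer",  # directory with a BERT config.json
    num_labels_list=[3, 2],           # e.g. 3-class task1, binary task2
    dropout_rate=0.1,
    use_attention_pooling=True,
)

input_ids = torch.randint(1, 100, (4, 16))  # batch of 4, sequence length 16
attention_mask = torch.ones_like(input_ids)
labels = [torch.randint(0, 3, (4,)), torch.randint(0, 2, (4,))]

total_loss, logits, losses = model(input_ids, attention_mask, labels)
print(total_loss.item(), [l.shape for l in logits])  # logits: (4, 3) and (4, 2)
```
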
geneformer/mtl/optuna_utils.py DELETED
@@ -1,27 +0,0 @@
- import optuna
- from optuna.integration import TensorBoardCallback
-
-
- def save_trial_callback(study, trial, trials_result_path):
-     with open(trials_result_path, "a") as f:
-         f.write(
-             f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n"
-         )
-
-
- def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir):
-     study = optuna.create_study(direction="maximize")
-
-     # init TensorBoard callback
-     tensorboard_callback = TensorBoardCallback(
-         dirname=tensorboard_log_dir, metric_name="F1 Macro"
-     )
-
-     # callback and TensorBoard callback
-     callbacks = [
-         lambda study, trial: save_trial_callback(study, trial, trials_result_path),
-         tensorboard_callback,
-     ]
-
-     study.optimize(objective, n_trials=n_trials, callbacks=callbacks)
-     return study
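A toy use of this helper with a stand-in objective (assuming `tensorboard` is installed for the Optuna integration); any callable taking an Optuna trial and returning the metric to maximize works:

```python
# Toy objective for the deleted create_optuna_study helper; paths are placeholders.
def toy_objective(trial):
    x = trial.suggest_float("x", -5.0, 5.0)
    return -(x - 2.0) ** 2  # maximized at x = 2

study = create_optuna_study(
    toy_objective,
    n_trials=20,
    trials_result_path="trials.txt",
    tensorboard_log_dir="tb_logs",
)
print(study.best_params)
```
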
geneformer/mtl/train.py DELETED
@@ -1,380 +0,0 @@
- import os
- import random
-
- import numpy as np
- import pandas as pd
- import torch
- from torch.utils.tensorboard import SummaryWriter
- from tqdm import tqdm
-
- from .imports import *
- from .model import GeneformerMultiTask
- from .utils import calculate_task_specific_metrics, get_layer_freeze_range
-
-
- def set_seed(seed):
-     random.seed(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-     torch.backends.cudnn.deterministic = True
-     torch.backends.cudnn.benchmark = False
-
-
- def initialize_wandb(config):
-     if config.get("use_wandb", False):
-         import wandb
-
-         wandb.init(project=config["wandb_project"], config=config)
-         print("Weights & Biases (wandb) initialized and will be used for logging.")
-     else:
-         print(
-             "Weights & Biases (wandb) is not enabled. Logging will use other methods."
-         )
-
-
- def create_model(config, num_labels_list, device):
-     model = GeneformerMultiTask(
-         config["pretrained_path"],
-         num_labels_list,
-         dropout_rate=config["dropout_rate"],
-         use_task_weights=config["use_task_weights"],
-         task_weights=config["task_weights"],
-         max_layers_to_freeze=config["max_layers_to_freeze"],
-         use_attention_pooling=config["use_attention_pooling"],
-     )
-     if config["use_data_parallel"]:
-         model = nn.DataParallel(model)
-     return model.to(device)
-
-
- def setup_optimizer_and_scheduler(model, config, total_steps):
-     optimizer = AdamW(
-         model.parameters(),
-         lr=config["learning_rate"],
-         weight_decay=config["weight_decay"],
-     )
-     warmup_steps = int(config["warmup_ratio"] * total_steps)
-
-     if config["lr_scheduler_type"] == "linear":
-         scheduler = get_linear_schedule_with_warmup(
-             optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
-         )
-     elif config["lr_scheduler_type"] == "cosine":
-         scheduler = get_cosine_schedule_with_warmup(
-             optimizer,
-             num_warmup_steps=warmup_steps,
-             num_training_steps=total_steps,
-             num_cycles=0.5,
-         )
-
-     return optimizer, scheduler
-
-
- def train_epoch(
-     model, train_loader, optimizer, scheduler, device, config, writer, epoch
- ):
-     model.train()
-     progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
-     for batch_idx, batch in enumerate(progress_bar):
-         optimizer.zero_grad()
-         input_ids = batch["input_ids"].to(device)
-         attention_mask = batch["attention_mask"].to(device)
-         labels = [
-             batch["labels"][task_name].to(device) for task_name in config["task_names"]
-         ]
-
-         loss, _, _ = model(input_ids, attention_mask, labels)
-         loss.backward()
-
-         if config["gradient_clipping"]:
-             torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"])
-
-         optimizer.step()
-         scheduler.step()
-
-         writer.add_scalar(
-             "Training Loss", loss.item(), epoch * len(train_loader) + batch_idx
-         )
-         if config.get("use_wandb", False):
-             import wandb
-
-             wandb.log({"Training Loss": loss.item()})
-
-         # Update progress bar
-         progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
-
-     return loss.item()  # Return the last batch loss
-
-
- def validate_model(model, val_loader, device, config):
-     model.eval()
-     val_loss = 0.0
-     task_true_labels = {task_name: [] for task_name in config["task_names"]}
-     task_pred_labels = {task_name: [] for task_name in config["task_names"]}
-     task_pred_probs = {task_name: [] for task_name in config["task_names"]}
-
-     with torch.no_grad():
-         for batch in val_loader:
-             input_ids = batch["input_ids"].to(device)
-             attention_mask = batch["attention_mask"].to(device)
-             labels = [
-                 batch["labels"][task_name].to(device)
-                 for task_name in config["task_names"]
-             ]
-             loss, logits, _ = model(input_ids, attention_mask, labels)
-             val_loss += loss.item()
-
-             for sample_idx in range(len(batch["input_ids"])):
-                 for i, task_name in enumerate(config["task_names"]):
-                     true_label = batch["labels"][task_name][sample_idx].item()
-                     pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item()
-                     pred_prob = (
-                         torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy()
-                     )
-                     task_true_labels[task_name].append(true_label)
-                     task_pred_labels[task_name].append(pred_label)
-                     task_pred_probs[task_name].append(pred_prob)
-
-     val_loss /= len(val_loader)
-     return val_loss, task_true_labels, task_pred_labels, task_pred_probs
-
-
- def log_metrics(task_metrics, val_loss, config, writer, epochs):
-     for task_name, metrics in task_metrics.items():
-         print(
-             f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}"
-         )
-         if config.get("use_wandb", False):
-             import wandb
-
-             wandb.log(
-                 {
-                     f"{task_name} Validation F1 Macro": metrics["f1"],
-                     f"{task_name} Validation Accuracy": metrics["accuracy"],
-                 }
-             )
-
-     writer.add_scalar("Validation Loss", val_loss, epochs)
-     for task_name, metrics in task_metrics.items():
-         writer.add_scalar(f"{task_name} - Validation F1 Macro", metrics["f1"], epochs)
-         writer.add_scalar(
-             f"{task_name} - Validation Accuracy", metrics["accuracy"], epochs
-         )
-
-
- def save_validation_predictions(
-     val_cell_id_mapping,
-     task_true_labels,
-     task_pred_labels,
-     task_pred_probs,
-     config,
-     trial_number=None,
- ):
-     if trial_number is not None:
-         trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}")
-         os.makedirs(trial_results_dir, exist_ok=True)
-         val_preds_file = os.path.join(trial_results_dir, "val_preds.csv")
-     else:
-         val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv")
-
-     rows = []
-     for sample_idx in range(len(val_cell_id_mapping)):
-         row = {"Cell ID": val_cell_id_mapping[sample_idx]}
-         for task_name in config["task_names"]:
-             row[f"{task_name} True"] = task_true_labels[task_name][sample_idx]
-             row[f"{task_name} Pred"] = task_pred_labels[task_name][sample_idx]
-             row[f"{task_name} Probabilities"] = ",".join(
-                 map(str, task_pred_probs[task_name][sample_idx])
-             )
-         rows.append(row)
-
-     df = pd.DataFrame(rows)
-     df.to_csv(val_preds_file, index=False)
-     print(f"Validation predictions saved to {val_preds_file}")
-
-
- def train_model(
-     config,
-     device,
-     train_loader,
-     val_loader,
-     train_cell_id_mapping,
-     val_cell_id_mapping,
-     num_labels_list,
- ):
-     set_seed(config["seed"])
-     initialize_wandb(config)
-
-     model = create_model(config, num_labels_list, device)
-     total_steps = len(train_loader) * config["epochs"]
-     optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
-
-     log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run")
-     writer = SummaryWriter(log_dir=log_dir)
-
-     epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress")
-     for epoch in epoch_progress:
-         last_loss = train_epoch(
-             model, train_loader, optimizer, scheduler, device, config, writer, epoch
-         )
-         epoch_progress.set_postfix({"last_loss": f"{last_loss:.4f}"})
-
-     val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
-         model, val_loader, device, config
-     )
-     task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
-
-     log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
-     writer.close()
-
-     save_validation_predictions(
-         val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config
-     )
-
-     if config.get("use_wandb", False):
-         import wandb
-
-         wandb.finish()
-
-     print(f"\nFinal Validation Loss: {val_loss:.4f}")
-     return val_loss, model  # Return both the validation loss and the trained model
-
-
- def objective(
-     trial,
-     train_loader,
-     val_loader,
-     train_cell_id_mapping,
-     val_cell_id_mapping,
-     num_labels_list,
-     config,
-     device,
- ):
-     set_seed(config["seed"])  # Set the seed before each trial
-     initialize_wandb(config)
-
-     # Hyperparameters
-     config["learning_rate"] = trial.suggest_float(
-         "learning_rate",
-         config["hyperparameters"]["learning_rate"]["low"],
-         config["hyperparameters"]["learning_rate"]["high"],
-         log=config["hyperparameters"]["learning_rate"]["log"],
-     )
-     config["warmup_ratio"] = trial.suggest_float(
-         "warmup_ratio",
-         config["hyperparameters"]["warmup_ratio"]["low"],
-         config["hyperparameters"]["warmup_ratio"]["high"],
-     )
-     config["weight_decay"] = trial.suggest_float(
-         "weight_decay",
-         config["hyperparameters"]["weight_decay"]["low"],
-         config["hyperparameters"]["weight_decay"]["high"],
-     )
-     config["dropout_rate"] = trial.suggest_float(
-         "dropout_rate",
-         config["hyperparameters"]["dropout_rate"]["low"],
-         config["hyperparameters"]["dropout_rate"]["high"],
-     )
-     config["lr_scheduler_type"] = trial.suggest_categorical(
-         "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
-     )
-     config["use_attention_pooling"] = trial.suggest_categorical(
-         "use_attention_pooling", [False]
-     )
-
-     if config["use_task_weights"]:
-         config["task_weights"] = [
-             trial.suggest_float(
-                 f"task_weight_{i}",
-                 config["hyperparameters"]["task_weights"]["low"],
-                 config["hyperparameters"]["task_weights"]["high"],
-             )
-             for i in range(len(num_labels_list))
-         ]
-         weight_sum = sum(config["task_weights"])
-         config["task_weights"] = [
-             weight / weight_sum for weight in config["task_weights"]
-         ]
-     else:
-         config["task_weights"] = None
-
-     # Dynamic range for max_layers_to_freeze
303
- freeze_range = get_layer_freeze_range(config["pretrained_path"])
304
- config["max_layers_to_freeze"] = trial.suggest_int(
305
- "max_layers_to_freeze",
306
- freeze_range["min"],
307
- freeze_range["max"]
308
- )
309
-
310
- model = create_model(config, num_labels_list, device)
311
- total_steps = len(train_loader) * config["epochs"]
312
- optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps)
313
-
314
- log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}")
315
- writer = SummaryWriter(log_dir=log_dir)
316
-
317
- for epoch in range(config["epochs"]):
318
- train_epoch(
319
- model, train_loader, optimizer, scheduler, device, config, writer, epoch
320
- )
321
-
322
- val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(
323
- model, val_loader, device, config
324
- )
325
- task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels)
326
-
327
- log_metrics(task_metrics, val_loss, config, writer, config["epochs"])
328
- writer.close()
329
-
330
- save_validation_predictions(
331
- val_cell_id_mapping,
332
- task_true_labels,
333
- task_pred_labels,
334
- task_pred_probs,
335
- config,
336
- trial.number,
337
- )
338
-
339
- trial.set_user_attr("model_state_dict", model.state_dict())
340
- trial.set_user_attr("task_weights", config["task_weights"])
341
-
342
- trial.report(val_loss, config["epochs"])
343
-
344
- if trial.should_prune():
345
- raise optuna.TrialPruned()
346
-
347
- if config.get("use_wandb", False):
348
- import wandb
349
-
350
- wandb.log(
351
- {
352
- "trial_number": trial.number,
353
- "val_loss": val_loss,
354
- **{
355
- f"{task_name}_f1": metrics["f1"]
356
- for task_name, metrics in task_metrics.items()
357
- },
358
- **{
359
- f"{task_name}_accuracy": metrics["accuracy"]
360
- for task_name, metrics in task_metrics.items()
361
- },
362
- **{
363
- k: v
364
- for k, v in config.items()
365
- if k
366
- in [
367
- "learning_rate",
368
- "warmup_ratio",
369
- "weight_decay",
370
- "dropout_rate",
371
- "lr_scheduler_type",
372
- "use_attention_pooling",
373
- "max_layers_to_freeze",
374
- ]
375
- },
376
- }
377
- )
378
- wandb.finish()
379
-
380
- return val_loss
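The search space that objective reads from config["hyperparameters"] mirrors the default dictionary defined in mtl_classifier.py (shown later in this commit); a minimal sketch, with illustrative bounds:

    hyperparameters = {
        "learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True},
        "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
        "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
        "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
        "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
        "task_weights": {"type": "float", "low": 0.1, "high": 2.0},  # sampled per task, then normalized to sum to 1
    }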
geneformer/mtl/train_utils.py DELETED
@@ -1,161 +0,0 @@
1
- import random
2
-
3
- from .data import get_data_loader, preload_and_process_data
4
- from .imports import *
5
- from .model import GeneformerMultiTask
6
- from .train import objective, train_model
7
- from .utils import save_model
8
-
9
-
10
- def set_seed(seed):
11
- random.seed(seed)
12
- np.random.seed(seed)
13
- torch.manual_seed(seed)
14
- torch.cuda.manual_seed_all(seed)
15
- torch.backends.cudnn.deterministic = True
16
- torch.backends.cudnn.benchmark = False
17
-
18
-
19
- def run_manual_tuning(config):
20
- # Set seed for reproducibility
21
- set_seed(config["seed"])
22
-
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- (
25
- train_dataset,
26
- train_cell_id_mapping,
27
- val_dataset,
28
- val_cell_id_mapping,
29
- num_labels_list,
30
- ) = preload_and_process_data(config)
31
- train_loader = get_data_loader(train_dataset, config["batch_size"])
32
- val_loader = get_data_loader(val_dataset, config["batch_size"])
33
-
34
- # Print the manual hyperparameters being used
35
- print("\nManual hyperparameters being used:")
36
- for key, value in config["manual_hyperparameters"].items():
37
- print(f"{key}: {value}")
38
- print() # Add an empty line for better readability
39
-
40
- # Use the manual hyperparameters
41
- for key, value in config["manual_hyperparameters"].items():
42
- config[key] = value
43
-
44
- # Train the model
45
- val_loss, trained_model = train_model(
46
- config,
47
- device,
48
- train_loader,
49
- val_loader,
50
- train_cell_id_mapping,
51
- val_cell_id_mapping,
52
- num_labels_list,
53
- )
54
-
55
- print(f"\nValidation loss with manual hyperparameters: {val_loss}")
56
-
57
- # Save the trained model
58
- model_save_directory = os.path.join(
59
- config["model_save_path"], "GeneformerMultiTask"
60
- )
61
- save_model(trained_model, model_save_directory)
62
-
63
- # Save the hyperparameters
64
- hyperparams_to_save = {
65
- **config["manual_hyperparameters"],
66
- "dropout_rate": config["dropout_rate"],
67
- "use_task_weights": config["use_task_weights"],
68
- "task_weights": config["task_weights"],
69
- "max_layers_to_freeze": config["max_layers_to_freeze"],
70
- "use_attention_pooling": config["use_attention_pooling"],
71
- }
72
- hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
73
- with open(hyperparams_path, "w") as f:
74
- json.dump(hyperparams_to_save, f)
75
- print(f"Manual hyperparameters saved to {hyperparams_path}")
76
-
77
- return val_loss
78
-
79
-
80
- def run_optuna_study(config):
81
- # Set seed for reproducibility
82
- set_seed(config["seed"])
83
-
84
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
- (
86
- train_dataset,
87
- train_cell_id_mapping,
88
- val_dataset,
89
- val_cell_id_mapping,
90
- num_labels_list,
91
- ) = preload_and_process_data(config)
92
- train_loader = get_data_loader(train_dataset, config["batch_size"])
93
- val_loader = get_data_loader(val_dataset, config["batch_size"])
94
-
95
- if config["use_manual_hyperparameters"]:
96
- train_model(
97
- config,
98
- device,
99
- train_loader,
100
- val_loader,
101
- train_cell_id_mapping,
102
- val_cell_id_mapping,
103
- num_labels_list,
104
- )
105
- else:
106
- objective_with_config_and_data = functools.partial(
107
- objective,
108
- train_loader=train_loader,
109
- val_loader=val_loader,
110
- train_cell_id_mapping=train_cell_id_mapping,
111
- val_cell_id_mapping=val_cell_id_mapping,
112
- num_labels_list=num_labels_list,
113
- config=config,
114
- device=device,
115
- )
116
-
117
- study = optuna.create_study(
118
- direction="minimize", # Minimize validation loss
119
- study_name=config["study_name"],
120
- # storage=config["storage"],
121
- load_if_exists=True,
122
- )
123
-
124
- study.optimize(objective_with_config_and_data, n_trials=config["n_trials"])
125
-
126
- # After finding the best trial
127
- best_params = study.best_trial.params
128
- best_task_weights = study.best_trial.user_attrs["task_weights"]
129
- print("Saving the best model and its hyperparameters...")
130
-
131
- # Saving model as before
132
- best_model = GeneformerMultiTask(
133
- config["pretrained_path"],
134
- num_labels_list,
135
- dropout_rate=best_params["dropout_rate"],
136
- use_task_weights=config["use_task_weights"],
137
- task_weights=best_task_weights,
138
- )
139
-
140
- # Get the best model state dictionary
141
- best_model_state_dict = study.best_trial.user_attrs["model_state_dict"]
142
-
143
- # Remove the "module." prefix from the state dictionary keys if present
144
- best_model_state_dict = {
145
- k.replace("module.", ""): v for k, v in best_model_state_dict.items()
146
- }
147
-
148
- # Load the modified state dictionary into the model, skipping unexpected keys
149
- best_model.load_state_dict(best_model_state_dict, strict=False)
150
-
151
- model_save_directory = os.path.join(
152
- config["model_save_path"], "GeneformerMultiTask"
153
- )
154
- save_model(best_model, model_save_directory)
155
-
156
- # Additionally, save the best hyperparameters and task weights
157
- hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json")
158
-
159
- with open(hyperparams_path, "w") as f:
160
- json.dump({**best_params, "task_weights": best_task_weights}, f)
161
- print(f"Best hyperparameters and task weights saved to {hyperparams_path}")
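For context, a hypothetical driver for the runners above; every path is a placeholder, and the full key set (everything consumed by preload_and_process_data, objective, and train_model) is larger than shown. In practice this module is driven through MTLClassifier rather than called directly:

    from geneformer.mtl.train_utils import run_optuna_study

    config = {
        "seed": 42,
        "batch_size": 4,
        "epochs": 1,
        "n_trials": 15,
        "study_name": "mtl",
        "use_manual_hyperparameters": False,
        "pretrained_path": "/path/pretrained/model",  # placeholder
        "train_path": "/path/train/set",              # placeholder
        "val_path": "/path/eval/set",                 # placeholder
        "model_save_path": "/results/save_path",      # placeholder
        "results_dir": "/results",                    # placeholder
        "tensorboard_log_dir": "/results/tblogdir",   # placeholder
        # ...plus task columns, hyperparameter ranges, etc.
    }
    run_optuna_study(config)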
geneformer/mtl/utils.py DELETED
@@ -1,129 +0,0 @@
1
- import os
2
- import shutil
3
-
4
- from sklearn.metrics import accuracy_score, f1_score
5
- from sklearn.preprocessing import LabelEncoder
6
- from transformers import AutoConfig, BertConfig, BertModel
7
-
8
- from .imports import *
9
-
10
-
11
- def save_model(model, model_save_directory):
12
- if not os.path.exists(model_save_directory):
13
- os.makedirs(model_save_directory)
14
-
15
- # Get the state dict
16
- if isinstance(model, nn.DataParallel):
17
- model_state_dict = (
18
- model.module.state_dict()
19
- ) # Use model.module to access the underlying model
20
- else:
21
- model_state_dict = model.state_dict()
22
-
23
- # Remove the "module." prefix from the keys if present
24
- model_state_dict = {
25
- k.replace("module.", ""): v for k, v in model_state_dict.items()
26
- }
27
-
28
- model_save_path = os.path.join(model_save_directory, "pytorch_model.bin")
29
- torch.save(model_state_dict, model_save_path)
30
-
31
- # Save the model configuration
32
- if isinstance(model, nn.DataParallel):
33
- model.module.config.to_json_file(
34
- os.path.join(model_save_directory, "config.json")
35
- )
36
- else:
37
- model.config.to_json_file(os.path.join(model_save_directory, "config.json"))
38
-
39
- print(f"Model and configuration saved to {model_save_directory}")
40
-
41
-
42
- def calculate_task_specific_metrics(task_true_labels, task_pred_labels):
43
- task_metrics = {}
44
- for task_name in task_true_labels.keys():
45
- true_labels = task_true_labels[task_name]
46
- pred_labels = task_pred_labels[task_name]
47
- f1 = f1_score(true_labels, pred_labels, average="macro")
48
- accuracy = accuracy_score(true_labels, pred_labels)
49
- task_metrics[task_name] = {"f1": f1, "accuracy": accuracy}
50
- return task_metrics
51
-
52
-
53
- def calculate_combined_f1(combined_labels, combined_preds):
54
- # Initialize the LabelEncoder
55
- le = LabelEncoder()
56
-
57
- # Fit and transform combined labels and predictions to numerical values
58
- le.fit(combined_labels + combined_preds)
59
- encoded_true_labels = le.transform(combined_labels)
60
- encoded_pred_labels = le.transform(combined_preds)
61
-
62
- # Print out the mapping for sanity check
63
- print("\nLabel Encoder Mapping:")
64
- for index, class_label in enumerate(le.classes_):
65
- print(f"'{class_label}': {index}")
66
-
67
- # Calculate accuracy
68
- accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels)
69
-
70
- # Calculate F1 Macro score
71
- f1 = f1_score(encoded_true_labels, encoded_pred_labels, average="macro")
72
-
73
- return f1, accuracy
74
-
75
-
76
- # def save_model_without_heads(original_model_save_directory):
77
- # # Create a new directory for the model without heads
78
- # new_model_save_directory = original_model_save_directory + "_No_Heads"
79
- # if not os.path.exists(new_model_save_directory):
80
- # os.makedirs(new_model_save_directory)
81
-
82
- # # Load the model state dictionary
83
- # model_state_dict = torch.load(
84
- # os.path.join(original_model_save_directory, "pytorch_model.bin")
85
- # )
86
-
87
- # # Initialize a new BERT model without the classification heads
88
- # config = BertConfig.from_pretrained(
89
- # os.path.join(original_model_save_directory, "config.json")
90
- # )
91
- # model_without_heads = BertModel(config)
92
-
93
- # # Filter the state dict to exclude classification heads
94
- # model_without_heads_state_dict = {
95
- # k: v
96
- # for k, v in model_state_dict.items()
97
- # if not k.startswith("classification_heads")
98
- # }
99
-
100
- # # Load the filtered state dict into the model
101
- # model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
102
-
103
- # # Save the model without heads
104
- # model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
- # torch.save(model_without_heads.state_dict(), model_save_path)
106
-
107
- # # Copy the configuration file
108
- # shutil.copy(
109
- # os.path.join(original_model_save_directory, "config.json"),
110
- # new_model_save_directory,
111
- # )
112
-
113
- # print(f"Model without classification heads saved to {new_model_save_directory}")
114
-
115
-
116
- def get_layer_freeze_range(pretrained_path):
117
- """
118
- Dynamically determines the number of layers to freeze based on the model depth from its configuration.
119
- Args:
120
- pretrained_path (str): Path to the pretrained model directory or model identifier.
121
- Returns:
122
- dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze.
123
- """
124
- if pretrained_path:
125
- config = AutoConfig.from_pretrained(pretrained_path)
126
- total_layers = config.num_hidden_layers
127
- return {"min": 0, "max": total_layers - 1}
128
- else:
129
- return {"min": 0, "max": 0}
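A quick toy run of the metric helpers above, with fabricated labels:

    from geneformer.mtl.utils import calculate_task_specific_metrics

    task_true = {"cell_type": [0, 1, 2, 2], "disease": [0, 0, 1, 1]}
    task_pred = {"cell_type": [0, 1, 2, 1], "disease": [0, 1, 1, 1]}
    metrics = calculate_task_specific_metrics(task_true, task_pred)
    # metrics["cell_type"] -> {"f1": ~0.78 (macro), "accuracy": 0.75}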
geneformer/mtl_classifier.py DELETED
@@ -1,363 +0,0 @@
1
- """
2
- Geneformer multi-task cell classifier.
3
-
4
- **Input data:**
5
-
6
- | Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging.
7
-
8
- **Usage:**
9
-
10
- .. code-block:: python
11
-
12
- >>> from geneformer import MTLClassifier
13
- >>> mc = MTLClassifier(task_columns = ["task1", "task2"],
14
- ... study_name = "mtl",
15
- ... pretrained_path = "/path/pretrained/model",
16
- ... train_path = "/path/train/set",
17
- ... val_path = "/path/eval/set",
18
- ... test_path = "/path/test/set",
19
- ... model_save_path = "/results/directory/save_path",
20
- ... trials_result_path = "/results/directory/results.txt",
21
- ... results_dir = "/results/directory",
22
- ... tensorboard_log_dir = "/results/tblogdir",
23
- ... hyperparameters = hyperparameters)
24
- >>> mc.run_optuna_study()
25
- >>> mc.load_and_evaluate_test_model()
26
- >>> mc.save_model_without_heads()
27
- """
28
-
29
- import logging
30
- import os
31
-
32
- from .mtl import eval_utils, train_utils, utils
33
-
34
- logger = logging.getLogger(__name__)
35
-
36
-
37
- class MTLClassifier:
38
- valid_option_dict = {
39
- "task_columns": {list},
40
- "train_path": {None, str},
41
- "val_path": {None, str},
42
- "test_path": {None, str},
43
- "pretrained_path": {None, str},
44
- "model_save_path": {None, str},
45
- "results_dir": {None, str},
46
- "batch_size": {None, int},
47
- "n_trials": {None, int},
48
- "study_name": {None, str},
49
- "max_layers_to_freeze": {None, dict},
50
- "epochs": {None, int},
51
- "tensorboard_log_dir": {None, str},
52
- "use_data_parallel": {None, bool},
53
- "use_attention_pooling": {None, bool},
54
- "use_task_weights": {None, bool},
55
- "hyperparameters": {None, dict},
56
- "manual_hyperparameters": {None, dict},
57
- "use_manual_hyperparameters": {None, bool},
58
- "use_wandb": {None, bool},
59
- "wandb_project": {None, str},
60
- "gradient_clipping": {None, bool},
61
- "max_grad_norm": {None, int, float},
62
- "seed": {None, int},
63
- "trials_result_path": {None, str},
64
- }
65
-
66
- def __init__(
67
- self,
68
- task_columns=None,
69
- train_path=None,
70
- val_path=None,
71
- test_path=None,
72
- pretrained_path=None,
73
- model_save_path=None,
74
- results_dir=None,
75
- trials_result_path=None,
76
- batch_size=4,
77
- n_trials=15,
78
- study_name="mtl",
79
- max_layers_to_freeze=None,
80
- epochs=1,
81
- tensorboard_log_dir="/results/tblogdir",
82
- use_data_parallel=False,
83
- use_attention_pooling=True,
84
- use_task_weights=True,
85
- hyperparameters=None, # Default is None
86
- manual_hyperparameters=None, # Default is None
87
- use_manual_hyperparameters=False, # Default is False
88
- use_wandb=False,
89
- wandb_project=None,
90
- gradient_clipping=False,
91
- max_grad_norm=None,
92
- seed=42, # Default seed value
93
- ):
94
- """
95
- Initialize Geneformer multi-task classifier.
96
-
97
- **Parameters:**
98
-
99
- task_columns : list
100
- | List of tasks for cell state classification
101
- | Input data columns are labeled with corresponding task names
102
- study_name : None, str
103
- | Study name for labeling output files
104
- pretrained_path : None, str
105
- | Path to pretrained model
106
- train_path : None, str
107
- | Path to training dataset with task columns and "unique_cell_id" column
108
- val_path : None, str
109
- | Path to validation dataset with task columns and "unique_cell_id" column
110
- test_path : None, str
111
- | Path to test dataset with task columns and "unique_cell_id" column
112
- model_save_path : None, str
113
- | Path to directory to save output model (either full model or model without heads)
114
- trials_result_path : None, str
115
- | Path to directory to save hyperparameter tuning trial results
116
- results_dir : None, str
117
- | Path to directory to save results
118
- tensorboard_log_dir : None, str
119
- | Path to directory for Tensorboard logging results
120
- use_data_parallel : None, bool
121
- | Whether to use data parallelization
122
- use_attention_pooling : None, bool
123
- | Whether to use attention pooling
124
- use_task_weights : None, bool
125
- | Whether to use task weights
126
- batch_size : None, int
127
- | Batch size to use
128
- n_trials : None, int
129
- | Number of trials for hyperparameter tuning
130
- epochs : None, int
131
- | Number of epochs for training
132
- max_layers_to_freeze : None, dict
133
- | Dictionary with keys "min" and "max" indicating the min and max layers to freeze from fine-tuning (int)
134
- | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
135
- hyperparameters : None, dict
136
- | Dictionary defining the tuning search space for each hyperparameter: min/max bounds for continuous values, choices for categorical ones
137
- | For example:
138
- | {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}, "task_weights": {...}, ...}
139
- manual_hyperparameters : None, dict
140
- | Dictionary of manually set values for each hyperparameter
141
- | For example:
142
- | {"learning_rate": 0.001, "task_weights": [1, 1], ...}
143
- use_manual_hyperparameters : None, bool
144
- | Whether to use manually set hyperparameters
145
- use_wandb : None, bool
146
- | Whether to use Weights & Biases for logging
147
- wandb_project : None, str
148
- | Weights & Biases project name
149
- gradient_clipping : None, bool
150
- | Whether to use gradient clipping
151
- max_grad_norm : None, int, float
152
- | Maximum norm for gradient clipping
153
- seed : None, int
154
- | Random seed
155
- """
156
-
157
- self.task_columns = task_columns
158
- self.train_path = train_path
159
- self.val_path = val_path
160
- self.test_path = test_path
161
- self.pretrained_path = pretrained_path
162
- self.model_save_path = model_save_path
163
- self.results_dir = results_dir
164
- self.trials_result_path = trials_result_path
165
- self.batch_size = batch_size
166
- self.n_trials = n_trials
167
- self.study_name = study_name
168
-
169
- if max_layers_to_freeze is None:
170
- # Dynamically determine the range of layers to freeze
171
- layer_freeze_range = utils.get_layer_freeze_range(pretrained_path)
172
- self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range["max"]}
173
- else:
174
- self.max_layers_to_freeze = max_layers_to_freeze
175
-
176
- self.epochs = epochs
177
- self.tensorboard_log_dir = tensorboard_log_dir
178
- self.use_data_parallel = use_data_parallel
179
- self.use_attention_pooling = use_attention_pooling
180
- self.use_task_weights = use_task_weights
181
- self.hyperparameters = (
182
- hyperparameters
183
- if hyperparameters is not None
184
- else {
185
- "learning_rate": {
186
- "type": "float",
187
- "low": 1e-5,
188
- "high": 1e-3,
189
- "log": True,
190
- },
191
- "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
192
- "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
193
- "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
194
- "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
195
- "task_weights": {"type": "float", "low": 0.1, "high": 2.0},
196
- }
197
- )
198
- self.manual_hyperparameters = (
199
- manual_hyperparameters
200
- if manual_hyperparameters is not None
201
- else {
202
- "learning_rate": 0.001,
203
- "warmup_ratio": 0.01,
204
- "weight_decay": 0.1,
205
- "dropout_rate": 0.1,
206
- "lr_scheduler_type": "cosine",
207
- "use_attention_pooling": False,
208
- "task_weights": [1, 1],
209
- "max_layers_to_freeze": 2,
210
- }
211
- )
212
- self.use_manual_hyperparameters = use_manual_hyperparameters
213
- self.use_wandb = use_wandb
214
- self.wandb_project = wandb_project
215
- self.gradient_clipping = gradient_clipping
216
- self.max_grad_norm = max_grad_norm
217
- self.seed = seed
218
-
219
- if self.use_manual_hyperparameters:
220
- logger.warning(
221
- "Hyperparameter tuning is highly recommended for optimal results."
222
- )
223
-
224
- self.validate_options()
225
-
226
- # set up output directories
227
- if self.results_dir is not None:
228
- self.trials_results_path = f"{self.results_dir}/results.txt".replace(
229
- "//", "/"
230
- )
231
-
232
- for output_dir in [self.model_save_path, self.results_dir]:
233
- if not os.path.exists(output_dir):
234
- os.makedirs(output_dir)
235
-
236
- self.config = {
237
- key: value
238
- for key, value in self.__dict__.items()
239
- if key in self.valid_option_dict
240
- }
241
-
242
- def validate_options(self):
243
- # confirm arguments are within valid options and compatible with each other
244
- for attr_name, valid_options in self.valid_option_dict.items():
245
- attr_value = self.__dict__[attr_name]
246
- if not isinstance(attr_value, (list, dict)):
247
- if attr_value in valid_options:
248
- continue
249
- valid_type = False
250
- for option in valid_options:
251
- if (option in [int, float, list, dict, bool, str]) and isinstance(
252
- attr_value, option
253
- ):
254
- valid_type = True
255
- break
256
- if valid_type:
257
- continue
258
- logger.error(
259
- f"Invalid option for {attr_name}. "
260
- f"Valid options for {attr_name}: {valid_options}"
261
- )
262
- raise ValueError(
263
- f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}"
264
- )
265
-
266
- def run_manual_tuning(self):
267
- """
268
- Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
269
- """
270
- required_variable_names = [
271
- "train_path",
272
- "val_path",
273
- "pretrained_path",
274
- "model_save_path",
275
- "results_dir",
276
- ]
277
- required_variables = [
278
- self.train_path,
279
- self.val_path,
280
- self.pretrained_path,
281
- self.model_save_path,
282
- self.results_dir,
283
- ]
284
- req_var_dict = dict(zip(required_variable_names, required_variables))
285
- self.validate_additional_options(req_var_dict)
286
-
287
- if not self.use_manual_hyperparameters:
288
- raise ValueError(
289
- "Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True."
290
- )
291
-
292
- # Ensure manual_hyperparameters are set in the config
293
- self.config["manual_hyperparameters"] = self.manual_hyperparameters
294
- self.config["use_manual_hyperparameters"] = True
295
-
296
- train_utils.run_manual_tuning(self.config)
297
-
298
- def validate_additional_options(self, req_var_dict):
299
- missing_variable = False
300
- for variable_name, variable in req_var_dict.items():
301
- if variable is None:
302
- logger.warning(
303
- f"Please provide value to MTLClassifier for required variable {variable_name}"
304
- )
305
- missing_variable = True
306
- if missing_variable is True:
307
- raise ValueError("Missing required variables for MTLClassifier")
308
-
309
- def run_optuna_study(
310
- self,
311
- ):
312
- """
313
- Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
314
- """
315
-
316
- required_variable_names = [
317
- "train_path",
318
- "val_path",
319
- "pretrained_path",
320
- "model_save_path",
321
- "results_dir",
322
- ]
323
- required_variables = [
324
- self.train_path,
325
- self.val_path,
326
- self.pretrained_path,
327
- self.model_save_path,
328
- self.results_dir,
329
- ]
330
- req_var_dict = dict(zip(required_variable_names, required_variables))
331
- self.validate_additional_options(req_var_dict)
332
-
333
- train_utils.run_optuna_study(self.config)
334
-
335
- def load_and_evaluate_test_model(
336
- self,
337
- ):
338
- """
339
- Loads previously fine-tuned multi-task model and evaluates on test data.
340
- """
341
-
342
- required_variable_names = ["test_path", "model_save_path", "results_dir"]
343
- required_variables = [self.test_path, self.model_save_path, self.results_dir]
344
- req_var_dict = dict(zip(required_variable_names, required_variables))
345
- self.validate_additional_options(req_var_dict)
346
-
347
- eval_utils.load_and_evaluate_test_model(self.config)
348
-
349
- # def save_model_without_heads(
350
- # self,
351
- # ):
352
- # """
353
- # Save previously fine-tuned multi-task model without classification heads.
354
- # """
355
-
356
- # required_variable_names = ["model_save_path"]
357
- # required_variables = [self.model_save_path]
358
- # req_var_dict = dict(zip(required_variable_names, required_variables))
359
- # self.validate_additional_options(req_var_dict)
360
-
361
- # utils.save_model_without_heads(
362
- # os.path.join(self.model_save_path, "GeneformerMultiTask")
363
- # )
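A hypothetical manual-tuning counterpart to the Optuna usage in the module docstring; paths are placeholders and the hyperparameter values are illustrative, not recommendations:

    from geneformer import MTLClassifier

    mc = MTLClassifier(
        task_columns=["task1", "task2"],
        study_name="mtl",
        pretrained_path="/path/pretrained/model",        # placeholder
        train_path="/path/train/set",                    # placeholder
        val_path="/path/eval/set",                       # placeholder
        model_save_path="/results/directory/save_path",  # placeholder
        results_dir="/results/directory",                # placeholder
        tensorboard_log_dir="/results/tblogdir",         # placeholder
        use_manual_hyperparameters=True,
        manual_hyperparameters={
            "learning_rate": 1e-4,
            "warmup_ratio": 0.01,
            "weight_decay": 0.1,
            "dropout_rate": 0.1,
            "lr_scheduler_type": "cosine",
            "use_attention_pooling": False,
            "task_weights": [1, 1],
            "max_layers_to_freeze": 2,
        },
    )
    mc.run_manual_tuning()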
geneformer/perturber_utils.py DELETED
@@ -1,919 +0,0 @@
1
- import itertools as it
2
- import logging
3
- import pickle
4
- from collections import defaultdict
5
- from pathlib import Path
6
- from typing import List
7
-
8
- import numpy as np
9
- import pandas as pd
10
- import torch
11
- from datasets import Dataset, load_from_disk
12
- from peft import LoraConfig, get_peft_model
13
- from transformers import (
14
- BertForMaskedLM,
15
- BertForSequenceClassification,
16
- BertForTokenClassification,
17
- BitsAndBytesConfig,
18
- )
19
-
20
- GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
21
- TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
22
- ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
23
-
24
-
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- # load data and filter by defined criteria
29
- def load_and_filter(filter_data, nproc, input_data_file):
30
- data = load_from_disk(input_data_file)
31
- if filter_data is not None:
32
- data = filter_by_dict(data, filter_data, nproc)
33
- return data
34
-
35
-
36
- def filter_by_dict(data, filter_data, nproc):
37
- for key, value in filter_data.items():
38
-
39
- def filter_data_by_criteria(example):
40
- return example[key] in value
41
-
42
- data = data.filter(filter_data_by_criteria, num_proc=nproc)
43
- if len(data) == 0:
44
- logger.error("No cells remain after filtering. Check filtering criteria.")
45
- raise ValueError("No cells remain after filtering. Check filtering criteria.")
46
- return data
47
-
48
-
49
- def filter_data_by_tokens(filtered_input_data, tokens, nproc):
50
- def if_has_tokens(example):
51
- return len(set(example["input_ids"]).intersection(tokens)) == len(tokens)
52
-
53
- filtered_input_data = filtered_input_data.filter(if_has_tokens, num_proc=nproc)
54
- return filtered_input_data
55
-
56
-
57
- def logging_filtered_data_len(filtered_input_data, filtered_tokens_categ):
58
- if len(filtered_input_data) == 0:
59
- logger.error(f"No cells in dataset contain {filtered_tokens_categ}.")
60
- raise ValueError(f"No cells in dataset contain {filtered_tokens_categ}.")
61
- else:
62
- logger.info(f"# cells with {filtered_tokens_categ}: {len(filtered_input_data)}")
63
-
64
-
65
- def filter_data_by_tokens_and_log(
66
- filtered_input_data, tokens, nproc, filtered_tokens_categ
67
- ):
68
- # filter for cells with anchor gene
69
- filtered_input_data = filter_data_by_tokens(filtered_input_data, tokens, nproc)
70
- # logging length of filtered data
71
- logging_filtered_data_len(filtered_input_data, filtered_tokens_categ)
72
-
73
- return filtered_input_data
74
-
75
-
76
- def filter_data_by_start_state(filtered_input_data, cell_states_to_model, nproc):
77
- # confirm that start state is valid to prevent futile filtering
78
- state_key = cell_states_to_model["state_key"]
79
- state_values = filtered_input_data[state_key]
80
- start_state = cell_states_to_model["start_state"]
81
- if start_state not in state_values:
82
- logger.error(
83
- f"Start state {start_state} is not present "
84
- f"in the dataset's {state_key} attribute."
85
- )
86
- raise ValueError(f"Start state {start_state} is not present in the dataset's {state_key} attribute.")
87
-
88
- # filter for start state cells
89
- def filter_for_origin(example):
90
- return example[state_key] in [start_state]
91
-
92
- filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=nproc)
93
- return filtered_input_data
94
-
95
-
96
- def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
97
- if cell_inds_to_perturb["start"] >= len(filtered_input_data):
98
- logger.error(
99
- "cell_inds_to_perturb['start'] is larger than the filtered dataset."
100
- )
101
- raise ValueError("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
102
- if cell_inds_to_perturb["end"] > len(filtered_input_data):
103
- logger.warning(
104
- "cell_inds_to_perturb['end'] is larger than the filtered dataset. \
105
- Setting to the end of the filtered dataset."
106
- )
107
- cell_inds_to_perturb["end"] = len(filtered_input_data)
108
- filtered_input_data = filtered_input_data.select(
109
- [i for i in range(cell_inds_to_perturb["start"], cell_inds_to_perturb["end"])]
110
- )
111
- return filtered_input_data
112
-
113
-
114
- # load model to GPU
115
- def load_model(model_type, num_classes, model_directory, mode, quantize=False):
116
- if model_type == "MTLCellClassifier-Quantized":
117
- model_type = "MTLCellClassifier"
118
- quantize = True
119
-
120
- output_hidden_states = (mode == "eval")
121
-
122
- # Quantization logic
123
- if quantize:
124
- if model_type == "MTLCellClassifier":
125
- quantize_config = BitsAndBytesConfig(load_in_8bit=True)
126
- peft_config = None
127
- else:
128
- quantize_config = BitsAndBytesConfig(
129
- load_in_4bit=True,
130
- bnb_4bit_use_double_quant=True,
131
- bnb_4bit_quant_type="nf4",
132
- bnb_4bit_compute_dtype=torch.bfloat16,
133
- )
134
- peft_config = LoraConfig(
135
- lora_alpha=128,
136
- lora_dropout=0.1,
137
- r=64,
138
- bias="none",
139
- task_type="TokenClassification",
140
- )
141
- else:
142
- quantize_config = None
143
- peft_config = None
144
-
145
- # Model class selection
146
- model_classes = {
147
- "Pretrained": BertForMaskedLM,
148
- "GeneClassifier": BertForTokenClassification,
149
- "CellClassifier": BertForSequenceClassification,
150
- "MTLCellClassifier": BertForMaskedLM
151
- }
152
-
153
- model_class = model_classes.get(model_type)
154
- if not model_class:
155
- raise ValueError(f"Unknown model type: {model_type}")
156
-
157
- # Model loading
158
- model_args = {
159
- "pretrained_model_name_or_path": model_directory,
160
- "output_hidden_states": output_hidden_states,
161
- "output_attentions": False,
162
- }
163
-
164
- if model_type != "Pretrained":
165
- model_args["num_labels"] = num_classes
166
-
167
- if quantize_config:
168
- model_args["quantization_config"] = quantize_config
169
-
170
- # Load the model
171
- model = model_class.from_pretrained(**model_args)
172
-
173
- if mode == "eval":
174
- model.eval()
175
-
176
- # Handle device placement and PEFT
177
- if not quantize:
178
- # Only move non-quantized models
179
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
- model = model.to(device)
181
- elif peft_config:
182
- # Apply PEFT for quantized models (except MTLCellClassifier)
183
- model.enable_input_require_grads()
184
- model = get_peft_model(model, peft_config)
185
-
186
- return model
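A minimal sketch of calling the loader above for inference; the checkpoint path is a placeholder. mode="eval" switches on hidden-state output and puts the model in eval mode:

    model = load_model(
        model_type="Pretrained",
        num_classes=0,  # unused for "Pretrained"
        model_directory="/path/to/gf-12L-95M-i4096",  # placeholder
        mode="eval",
    )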
187
-
188
- def quant_layers(model):
189
- layer_nums = []
190
- for name, parameter in model.named_parameters():
191
- if "layer" in name:
192
- layer_nums += [int(name.split("layer.")[1].split(".")[0])]
193
- return int(max(layer_nums)) + 1
194
-
195
-
196
- def get_model_emb_dims(model):
197
- return model.config.hidden_size
198
-
199
-
200
- def get_model_input_size(model):
201
- return model.config.max_position_embeddings
202
-
203
-
204
- def flatten_list(megalist):
205
- return [item for sublist in megalist for item in sublist]
206
-
207
-
208
- def measure_length(example):
209
- example["length"] = len(example["input_ids"])
210
- return example
211
-
212
-
213
- def downsample_and_sort(data, max_ncells):
214
- num_cells = len(data)
215
- # if max number of cells is defined, then shuffle and subsample to this max number
216
- if max_ncells is not None:
217
- if num_cells > max_ncells:
218
- data = data.shuffle(seed=42)
219
- num_cells = max_ncells
220
- data_subset = data.select([i for i in range(num_cells)])
221
- # sort dataset with largest cell first to encounter any memory errors earlier
222
- data_sorted = data_subset.sort("length", reverse=True)
223
- return data_sorted
224
-
225
-
226
- def get_possible_states(cell_states_to_model):
227
- possible_states = []
228
- for key in ["start_state", "goal_state"]:
229
- possible_states += [cell_states_to_model[key]]
230
- possible_states += cell_states_to_model.get("alt_states", [])
231
- return possible_states
232
-
233
-
234
- def forward_pass_single_cell(model, example_cell, layer_to_quant):
235
- example_cell.set_format(type="torch")
236
- input_data = example_cell["input_ids"]
237
- with torch.no_grad():
238
- outputs = model(input_ids=input_data.to("cuda"))
239
- emb = torch.squeeze(outputs.hidden_states[layer_to_quant])
240
- del outputs
241
- return emb
242
-
243
-
244
- def perturb_emb_by_index(emb, indices):
245
- mask = torch.ones(emb.numel(), dtype=torch.bool)
246
- mask[indices] = False
247
- return emb[mask]
248
-
249
-
250
- def delete_indices(example):
251
- indices = example["perturb_index"]
252
- if any(isinstance(el, list) for el in indices):
253
- indices = flatten_list(indices)
254
- for index in sorted(indices, reverse=True):
255
- del example["input_ids"][index]
256
-
257
- example["length"] = len(example["input_ids"])
258
- return example
259
-
260
-
261
- # for genes_to_perturb = "all" where only genes within cell are overexpressed
262
- def overexpress_indices(example):
263
- indices = example["perturb_index"]
264
- if any(isinstance(el, list) for el in indices):
265
- indices = flatten_list(indices)
266
- insert_pos = 0
267
- for index in sorted(indices, reverse=False):
268
- example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
269
- insert_pos += 1
270
- example["length"] = len(example["input_ids"])
271
- return example
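A toy run of the function above, with fabricated token ids: overexpressing a gene moves its token to the front of the rank value encoding (the highest-rank position) rather than duplicating it:

    example = {"input_ids": [3, 9, 7, 5], "perturb_index": [[2]]}
    example = overexpress_indices(example)
    assert example["input_ids"] == [7, 3, 9, 5]
    assert example["length"] == 4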
272
-
273
-
274
- # if CLS token present, move to 1st rather than 0th position
275
- def overexpress_indices_special(example):
276
- indices = example["perturb_index"]
277
- if any(isinstance(el, list) for el in indices):
278
- indices = flatten_list(indices)
279
- insert_pos = 1 # Insert starting after CLS token
280
- for index in sorted(indices, reverse=False):
281
- example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
282
- insert_pos += 1
283
- example["length"] = len(example["input_ids"])
284
- return example
285
-
286
-
287
- # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
288
- def overexpress_tokens(example, max_len, special_token):
289
- # -100 indicates tokens to overexpress are not present in rank value encoding
290
- if example["perturb_index"] != [-100]:
291
- example = delete_indices(example)
292
- if special_token:
293
- [
294
- example["input_ids"].insert(1, token)
295
- for token in example["tokens_to_perturb"][::-1]
296
- ]
297
- else:
298
- [
299
- example["input_ids"].insert(0, token)
300
- for token in example["tokens_to_perturb"][::-1]
301
- ]
302
-
303
- # truncate to max input size, must also truncate original emb to be comparable
304
- if len(example["input_ids"]) > max_len:
305
- if special_token:
306
- example["input_ids"] = example["input_ids"][0 : max_len - 1] + [
307
- example["input_ids"][-1]
308
- ]
309
- else:
310
- example["input_ids"] = example["input_ids"][0:max_len]
311
- example["length"] = len(example["input_ids"])
312
- return example
313
-
314
-
315
- def calc_n_overflow(max_len, example_len, tokens_to_perturb, indices_to_perturb):
316
- n_to_add = len(tokens_to_perturb) - len(indices_to_perturb)
317
- n_overflow = example_len + n_to_add - max_len
318
- return n_overflow
319
-
320
-
321
- def truncate_by_n_overflow(example):
322
- new_max_len = example["length"] - example["n_overflow"]
323
- example["input_ids"] = example["input_ids"][0:new_max_len]
324
- example["length"] = len(example["input_ids"])
325
- return example
326
-
327
-
328
- def truncate_by_n_overflow_special(example):
329
- if example["n_overflow"] > 0:
330
- new_max_len = example["length"] - example["n_overflow"]
331
- example["input_ids"] = example["input_ids"][0 : new_max_len - 1] + [
332
- example["input_ids"][-1]
333
- ]
334
- example["length"] = len(example["input_ids"])
335
- return example
336
-
337
-
338
- def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
339
- # indices_to_remove is list of indices to remove
340
- indices_to_keep = [
341
- i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove
342
- ]
343
- num_dims = emb.dim()
344
- emb_slice = [
345
- slice(None) if dim != gene_dim else indices_to_keep for dim in range(num_dims)
346
- ]
347
- sliced_emb = emb[emb_slice]
348
- return sliced_emb
349
-
350
-
351
- def remove_indices_from_emb_batch(emb_batch, list_of_indices_to_remove, gene_dim):
352
- output_batch_list = [
353
- remove_indices_from_emb(emb_batch[i, :, :], idxes, gene_dim - 1)
354
- for i, idxes in enumerate(list_of_indices_to_remove)
355
- ]
356
- # add padding given genes are sometimes added that are or are not in original cell
357
- batch_max = max([emb.size()[gene_dim - 1] for emb in output_batch_list])
358
- output_batch_list_padded = [
359
- pad_xd_tensor(emb, 0.000, batch_max, gene_dim - 1) for emb in output_batch_list
360
- ]
361
- return torch.stack(output_batch_list_padded)
362
-
363
-
364
- # removes perturbed indices
365
- # need to handle the various cases where a set of genes is overexpressed
366
- def remove_perturbed_indices_set(
367
- emb,
368
- perturb_type: str,
369
- indices_to_perturb: List[List],
370
- tokens_to_perturb: List[List],
371
- original_lengths: List[int],
372
- input_ids=None,
373
- ):
374
- if perturb_type == "overexpress":
375
- num_perturbed = len(tokens_to_perturb)
376
- if num_perturbed == 1:
377
- indices_to_perturb_orig = [
378
- idx if idx != [-100] else [None] for idx in indices_to_perturb
379
- ]
380
- if all(v == [None] for v in indices_to_perturb_orig):  # equality, not identity ("is" against a list literal is always False)
381
- return emb
382
- else:
383
- indices_to_perturb_orig = []
384
-
385
- for idx_list in indices_to_perturb:
386
- indices_to_perturb_orig.append(
387
- [idx if idx != [-100] else [None] for idx in idx_list]
388
- )
389
-
390
- else:
391
- indices_to_perturb_orig = indices_to_perturb
392
-
393
- emb = remove_indices_from_emb_batch(emb, indices_to_perturb_orig, gene_dim=1)
394
-
395
- return emb
396
-
397
-
398
- def make_perturbation_batch(
399
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
400
- ) -> tuple[Dataset, List[int]]:
401
- if combo_lvl == 0 and tokens_to_perturb == "all":
402
- if perturb_type in ["overexpress", "activate"]:
403
- range_start = 1
404
- elif perturb_type in ["delete", "inhibit"]:
405
- range_start = 0
406
- indices_to_perturb = [
407
- [i] for i in range(range_start, example_cell["length"][0])
408
- ]
409
- # elif combo_lvl > 0 and anchor_token is None:
410
- ## to implement
411
- elif combo_lvl > 0 and (anchor_token is not None):
412
- example_input_ids = example_cell["input_ids"][0]
413
- anchor_index = example_input_ids.index(anchor_token[0])
414
- indices_to_perturb = [
415
- sorted([anchor_index, i]) if i != anchor_index else None
416
- for i in range(example_cell["length"][0])
417
- ]
418
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
419
- else:
420
- example_input_ids = example_cell["input_ids"][0]
421
- indices_to_perturb = [
422
- [example_input_ids.index(token)] if token in example_input_ids else None
423
- for token in tokens_to_perturb
424
- ]
425
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
426
-
427
- # create all combinations of combo_lvl modifiers from tokens_to_perturb
428
- if combo_lvl > 0 and (anchor_token is None):
429
- if tokens_to_perturb != "all":
430
- if len(tokens_to_perturb) == combo_lvl + 1:
431
- indices_to_perturb = [
432
- list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
433
- ]
434
- else:
435
- all_indices = [[i] for i in range(example_cell["length"][0])]
436
- all_indices = [
437
- index for index in all_indices if index not in indices_to_perturb
438
- ]
439
- indices_to_perturb = [
440
- [[j for i in indices_to_perturb for j in i], x] for x in all_indices
441
- ]
442
-
443
- length = len(indices_to_perturb)
444
- perturbation_dataset = Dataset.from_dict(
445
- {
446
- "input_ids": example_cell["input_ids"] * length,
447
- "perturb_index": indices_to_perturb,
448
- }
449
- )
450
-
451
- if length < 400:
452
- num_proc_i = 1
453
- else:
454
- num_proc_i = num_proc
455
-
456
- if perturb_type == "delete":
457
- perturbation_dataset = perturbation_dataset.map(
458
- delete_indices, num_proc=num_proc_i
459
- )
460
- elif perturb_type == "overexpress":
461
- perturbation_dataset = perturbation_dataset.map(
462
- overexpress_indices, num_proc=num_proc_i
463
- )
464
-
465
- perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
466
-
467
- return perturbation_dataset, indices_to_perturb
468
-
469
-
470
- def make_perturbation_batch_special(
471
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
472
- ) -> tuple[Dataset, List[int]]:
473
- if combo_lvl == 0 and tokens_to_perturb == "all":
474
- if perturb_type in ["overexpress", "activate"]:
475
- range_start = 1
476
- elif perturb_type in ["delete", "inhibit"]:
477
- range_start = 0
478
- range_start += 1 # Starting after the CLS token
479
- indices_to_perturb = [
480
- [i]
481
- for i in range(
482
- range_start, example_cell["length"][0] - 1
483
- ) # And excluding the EOS token
484
- ]
485
-
486
- # elif combo_lvl > 0 and anchor_token is None:
487
- ## to implement
488
- elif combo_lvl > 0 and (anchor_token is not None):
489
- example_input_ids = example_cell["input_ids"][0]
490
- anchor_index = example_input_ids.index(anchor_token[0])
491
- indices_to_perturb = [
492
- sorted([anchor_index, i]) if i != anchor_index else None
493
- for i in range(
494
- 1, example_cell["length"][0] - 1
495
- ) # Exclude CLS and EOS tokens
496
- ]
497
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
498
- else:
499
- example_input_ids = example_cell["input_ids"][0]
500
- indices_to_perturb = [
501
- [example_input_ids.index(token)] if token in example_input_ids else None
502
- for token in tokens_to_perturb
503
- ]
504
- indices_to_perturb = [item for item in indices_to_perturb if item is not None]
505
-
506
- # create all combinations of combo_lvl modifiers from tokens_to_perturb
507
- if combo_lvl > 0 and (anchor_token is None):
508
- if tokens_to_perturb != "all":
509
- if len(tokens_to_perturb) == combo_lvl + 1:
510
- indices_to_perturb = [
511
- list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1)
512
- ]
513
- else:
514
- all_indices = [
515
- [i] for i in range(1, example_cell["length"][0] - 1)
516
- ] # Exclude CLS and EOS tokens
517
- all_indices = [
518
- index for index in all_indices if index not in indices_to_perturb
519
- ]
520
- indices_to_perturb = [
521
- [[j for i in indices_to_perturb for j in i], x] for x in all_indices
522
- ]
523
-
524
- length = len(indices_to_perturb)
525
- perturbation_dataset = Dataset.from_dict(
526
- {
527
- "input_ids": example_cell["input_ids"] * length,
528
- "perturb_index": indices_to_perturb,
529
- }
530
- )
531
-
532
- if length < 400:
533
- num_proc_i = 1
534
- else:
535
- num_proc_i = num_proc
536
-
537
- if perturb_type == "delete":
538
- perturbation_dataset = perturbation_dataset.map(
539
- delete_indices, num_proc=num_proc_i
540
- )
541
- elif perturb_type == "overexpress":
542
- perturbation_dataset = perturbation_dataset.map(
543
- overexpress_indices_special, num_proc=num_proc_i
544
- )
545
-
546
- perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
547
-
548
- return perturbation_dataset, indices_to_perturb
549
-
550
-
551
- # original cell emb removing the activated/overexpressed/inhibited gene emb
552
- # so that only non-perturbed gene embeddings are compared to each other
553
- # in original or perturbed context
554
- def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group):
555
- all_embs_list = []
556
-
557
- # if making comparison batch for multiple perturbations in single cell
558
- if perturb_group is False:
559
- # squeeze if single cell
560
- if original_emb_batch.ndim == 3 and original_emb_batch.size()[0] == 1:
561
- original_emb_batch = torch.squeeze(original_emb_batch)
562
- original_emb_list = [original_emb_batch] * len(indices_to_perturb)
563
- # if making comparison batch for single perturbation in multiple cells
564
- elif perturb_group is True:
565
- original_emb_list = original_emb_batch
566
-
567
- for original_emb, indices in zip(original_emb_list, indices_to_perturb):
568
- if indices == [-100]:
569
- all_embs_list += [original_emb[:]]
570
- continue
571
-
572
- emb_list = []
573
- start = 0
574
- if any(isinstance(el, list) for el in indices):
575
- indices = flatten_list(indices)
576
-
577
- # removes indices that were perturbed from the original embedding
578
- for i in sorted(indices):
579
- emb_list += [original_emb[start:i]]
580
- start = i + 1
581
-
582
- emb_list += [original_emb[start:]]
583
- all_embs_list += [torch.cat(emb_list)]
584
-
585
- len_set = set([emb.size()[0] for emb in all_embs_list])
586
- if len(len_set) > 1:
587
- max_len = max(len_set)
588
- all_embs_list = [pad_2d_tensor(emb, None, max_len, 0) for emb in all_embs_list]
589
- return torch.stack(all_embs_list)
590
-
591
-
592
- def pad_list(input_ids, pad_token_id, max_len):
593
- input_ids = np.pad(
594
- input_ids,
595
- (0, max_len - len(input_ids)),
596
- mode="constant",
597
- constant_values=pad_token_id,
598
- )
599
- return input_ids
600
-
601
-
602
- def pad_xd_tensor(tensor, pad_token_id, max_len, dim):
603
- padding_length = max_len - tensor.size()[dim]
604
- # Construct a padding configuration where all padding values are 0, except for the padding dimension
605
- # 2 * number of dimensions (padding before and after for every dimension)
606
- pad_config = [0] * 2 * tensor.dim()
607
- # Set the padding after the desired dimension to the calculated padding length
608
- pad_config[-2 * dim - 1] = padding_length
609
- return torch.nn.functional.pad(
610
- tensor, pad=pad_config, mode="constant", value=pad_token_id
611
- )
612
-
613
-
614
- def pad_tensor(tensor, pad_token_id, max_len):
615
- tensor = torch.nn.functional.pad(
616
- tensor, pad=(0, max_len - tensor.numel()), mode="constant", value=pad_token_id
617
- )
618
-
619
- return tensor
620
-
621
-
622
- def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
623
- if dim == 0:
624
- pad = (0, 0, 0, max_len - tensor.size()[dim])
625
- elif dim == 1:
626
- pad = (0, max_len - tensor.size()[dim], 0, 0)
627
- tensor = torch.nn.functional.pad(
628
- tensor, pad=pad, mode="constant", value=pad_token_id
629
- )
630
- return tensor
631
-
632
-
633
- def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
634
- if dim == 0:
635
- raise Exception("dim 0 usually does not need to be padded.")
636
- if dim == 1:
637
- pad = (0, 0, 0, max_len - tensor.size()[dim])
638
- elif dim == 2:
639
- pad = (0, max_len - tensor.size()[dim], 0, 0)
640
- tensor = torch.nn.functional.pad(
641
- tensor, pad=pad, mode="constant", value=pad_token_id
642
- )
643
- return tensor
644
-
645
-
646
- def pad_or_truncate_encoding(encoding, pad_token_id, max_len):
647
- if isinstance(encoding, torch.Tensor):
648
- encoding_len = encoding.size()[0]
649
- elif isinstance(encoding, list):
650
- encoding_len = len(encoding)
651
- if encoding_len > max_len:
652
- encoding = encoding[0:max_len]
653
- elif encoding_len < max_len:
654
- if isinstance(encoding, torch.Tensor):
655
- encoding = pad_tensor(encoding, pad_token_id, max_len)
656
- elif isinstance(encoding, list):
657
- encoding = pad_list(encoding, pad_token_id, max_len)
658
- return encoding
659
-
660
-
661
- # pad list of tensors and convert to tensor
662
- def pad_tensor_list(
663
- tensor_list,
664
- dynamic_or_constant,
665
- pad_token_id,
666
- model_input_size,
667
- dim=None,
668
- padding_func=None,
669
- ):
670
- # determine maximum tensor length
671
- if dynamic_or_constant == "dynamic":
672
- max_len = max([tensor.squeeze().numel() for tensor in tensor_list])
673
- elif isinstance(dynamic_or_constant, int):
674
- max_len = dynamic_or_constant
675
- else:
676
- max_len = model_input_size
677
- logger.warning(
678
- "If padding style is constant, must provide integer value. "
679
- f"Setting padding to max input size {model_input_size}."
680
- )
681
-
682
- # pad all tensors to maximum length
683
- if dim is None:
684
- tensor_list = [
685
- pad_tensor(tensor, pad_token_id, max_len) for tensor in tensor_list
686
- ]
687
- else:
688
- tensor_list = [
689
- padding_func(tensor, pad_token_id, max_len, dim) for tensor in tensor_list
690
- ]
691
- # return stacked tensors
692
- if padding_func != pad_3d_tensor:
693
- return torch.stack(tensor_list)
694
- else:
695
- return torch.cat(tensor_list, 0)
696
-
697
-
698
- def gen_attention_mask(minibatch_encoding, max_len=None):
699
- if max_len is None:
700
- max_len = max(minibatch_encoding["length"])
701
- original_lens = minibatch_encoding["length"]
702
- attention_mask = [
703
- [1] * original_len + [0] * (max_len - original_len)
704
- if original_len <= max_len
705
- else [1] * max_len
706
- for original_len in original_lens
707
- ]
708
- return torch.tensor(attention_mask, device="cuda")
709
-
710
-
711
- # get cell embeddings excluding padding
712
- def mean_nonpadding_embs(embs, original_lens, dim=1):
713
- # create a mask tensor based on padding lengths
714
- mask = torch.arange(embs.size(dim), device=embs.device) < original_lens.unsqueeze(1)
715
- if embs.dim() == 3:
716
- # fill the masked positions in embs with zeros
717
- masked_embs = embs.masked_fill(~mask.unsqueeze(2), 0.0)
718
-
719
- # compute the mean across the non-padding dimensions
720
- mean_embs = masked_embs.sum(dim) / original_lens.view(-1, 1).float()
721
-
722
- elif embs.dim() == 2:
723
- masked_embs = embs.masked_fill(~mask, 0.0)
724
- mean_embs = masked_embs.sum(dim) / original_lens.float()
725
- return mean_embs
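A toy check of the masked mean above, with fabricated values; the padded position is zeroed out and excluded from the average:

    import torch

    embs = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [0.0, 0.0]]])  # 1 cell x 3 positions x 2 dims
    lens = torch.tensor([2])  # only the first two positions are real
    mean_nonpadding_embs(embs, lens)  # tensor([[2., 2.]])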
726
-
727
-
728
- # get cell embeddings when there is no padding
729
- def compute_nonpadded_cell_embedding(embs, cell_emb_style):
730
- if cell_emb_style == "mean_pool":
731
- return torch.mean(embs, dim=embs.ndim - 2)
732
-
733
-
734
- # quantify shifts for a set of genes
735
- def quant_cos_sims(
736
- perturbation_emb,
737
- original_emb,
738
- cell_states_to_model,
739
- state_embs_dict,
740
- emb_mode="gene",
741
- ):
742
- if emb_mode == "gene":
743
- cos = torch.nn.CosineSimilarity(dim=2)
744
- elif emb_mode == "cell":
745
- cos = torch.nn.CosineSimilarity(dim=1)
746
-
747
- # if emb_mode == "gene", can only calculate gene cos sims
748
- # against original cell
749
- if cell_states_to_model is None or emb_mode == "gene":
750
- cos_sims = cos(perturbation_emb, original_emb).to("cuda")
751
-
752
- elif cell_states_to_model is not None and emb_mode == "cell":
753
- possible_states = get_possible_states(cell_states_to_model)
754
- cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))]))
755
- for state in possible_states:
756
- cos_sims[state] = cos_sim_shift(
757
- original_emb,
758
- perturbation_emb,
759
- state_embs_dict[state].to("cuda"), # required to move to cuda here
760
- cos,
761
- )
762
-
763
- return cos_sims
764
-
765
-
766
- # calculate cos sim shift of perturbation with respect to origin and alternative cell
767
- def cos_sim_shift(original_emb, perturbed_emb, end_emb, cos):
768
- origin_v_end = cos(original_emb, end_emb)
769
- perturb_v_end = cos(perturbed_emb, end_emb)
770
-
771
- return perturb_v_end - origin_v_end
772
-
773
-
774
- def concatenate_cos_sims(cos_sims):
775
- if isinstance(cos_sims, list):
776
- return torch.cat(cos_sims)
777
- else:
778
- for state in cos_sims.keys():
779
- cos_sims[state] = torch.cat(cos_sims[state])
780
- return cos_sims
781
-
782
-
783
- def write_perturbation_dictionary(cos_sims_dict: defaultdict, output_path_prefix: str):
784
- with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
785
- pickle.dump(cos_sims_dict, fp)
786
-
787
-
788
- def tensor_list_to_pd(tensor_list):
789
- tensor = torch.cat(tensor_list).cpu().numpy()
790
- df = pd.DataFrame(tensor)
791
- return df
792
-
793
-
794
- def validate_cell_states_to_model(cell_states_to_model):
795
- if cell_states_to_model is not None:
796
- if len(cell_states_to_model.items()) == 1:
797
- logger.warning(
798
- "The single value dictionary for cell_states_to_model will be "
799
- "replaced with a dictionary with named keys for start, goal, and alternate states. "
800
- "Please specify state_key, start_state, goal_state, and alt_states "
801
- "in the cell_states_to_model dictionary for future use. "
802
- "For example, cell_states_to_model={"
803
- "'state_key': 'disease', "
804
- "'start_state': 'dcm', "
805
- "'goal_state': 'nf', "
806
- "'alt_states': ['hcm', 'other1', 'other2']}"
807
- )
808
- for key, value in cell_states_to_model.items():
809
- if (len(value) == 3) and isinstance(value, tuple):
810
- if (
811
- isinstance(value[0], list)
812
- and isinstance(value[1], list)
813
- and isinstance(value[2], list)
814
- ):
815
- if len(value[0]) == 1 and len(value[1]) == 1:
816
- all_values = value[0] + value[1] + value[2]
817
- if len(all_values) == len(set(all_values)):
818
- continue
819
- # reformat to the new named key format
820
- state_values = flatten_list(list(cell_states_to_model.values()))
821
-
822
- cell_states_to_model = {
823
- "state_key": list(cell_states_to_model.keys())[0],
824
- "start_state": state_values[0][0],
825
- "goal_state": state_values[1][0],
826
- "alt_states": state_values[2:][0],
827
- }
828
- elif set(cell_states_to_model.keys()).issuperset(
829
- {"state_key", "start_state", "goal_state"}
830
- ):
831
- if (
832
- (cell_states_to_model["state_key"] is None)
833
- or (cell_states_to_model["start_state"] is None)
834
- or (cell_states_to_model["goal_state"] is None)
835
- ):
836
- logger.error(
837
- "Please specify 'state_key', 'start_state', and 'goal_state' in cell_states_to_model."
838
- )
839
- raise
840
-
841
- if (
842
- cell_states_to_model["start_state"]
843
- == cell_states_to_model["goal_state"]
844
- ):
845
- logger.error("All states must be unique.")
846
- raise
847
-
848
- if "alt_states" in set(cell_states_to_model.keys()):
849
- if cell_states_to_model["alt_states"] is not None:
850
- if not isinstance(cell_states_to_model["alt_states"], list):
851
- logger.error(
852
- "cell_states_to_model['alt_states'] must be a list (even if it is one element)."
853
- )
854
- raise
855
- if len(cell_states_to_model["alt_states"]) != len(
856
- set(cell_states_to_model["alt_states"])
857
- ):
858
- logger.error("All states must be unique.")
859
- raise
860
- else:
861
- cell_states_to_model["alt_states"] = []
862
-
863
- else:
864
- logger.error(
865
- "cell_states_to_model must only have the following four keys: "
866
- "'state_key', 'start_state', 'goal_state', 'alt_states'."
867
- "For example, cell_states_to_model={"
868
- "'state_key': 'disease', "
869
- "'start_state': 'dcm', "
870
- "'goal_state': 'nf', "
871
- "'alt_states': ['hcm', 'other1', 'other2']}"
872
- )
873
- raise
874
-
875
-
876
- class GeneIdHandler:
877
- def __init__(self, raise_errors=False):
878
- def invert_dict(dict_obj):
879
- return {v: k for k, v in dict_obj.items()}
880
-
881
- self.raise_errors = raise_errors
882
-
883
- with open(TOKEN_DICTIONARY_FILE, "rb") as f:
884
- self.gene_token_dict = pickle.load(f)
885
- self.token_gene_dict = invert_dict(self.gene_token_dict)
886
-
887
- with open(ENSEMBL_DICTIONARY_FILE, "rb") as f:
888
- self.id_gene_dict = pickle.load(f)
889
- self.gene_id_dict = invert_dict(self.id_gene_dict)
890
-
891
- def ens_to_token(self, ens_id):
892
- if not self.raise_errors:
893
- return self.gene_token_dict.get(ens_id, ens_id)
894
- else:
895
- return self.gene_token_dict[ens_id]
896
-
897
- def token_to_ens(self, token):
898
- if not self.raise_errors:
899
- return self.token_gene_dict.get(token, token)
900
- else:
901
- return self.token_gene_dict[token]
902
-
903
- def ens_to_symbol(self, ens_id):
904
- if not self.raise_errors:
905
- return self.gene_id_dict.get(ens_id, ens_id)
906
- else:
907
- return self.gene_id_dict[ens_id]
908
-
909
- def symbol_to_ens(self, symbol):
910
- if not self.raise_errors:
911
- return self.id_gene_dict.get(symbol, symbol)
912
- else:
913
- return self.id_gene_dict[symbol]
914
-
915
- def token_to_symbol(self, token):
916
- return self.ens_to_symbol(self.token_to_ens(token))
917
-
918
- def symbol_to_token(self, symbol):
919
- return self.ens_to_token(self.symbol_to_ens(symbol))
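
Note: for downstream users of the removed perturber_utils helpers, a minimal usage sketch under the pre-removal package layout (the Ensembl ID below is a placeholder, not a real gene):

    from geneformer import perturber_utils as pu  # pre-removal layout

    # validate_cell_states_to_model expects the named-key format shown in its messages
    cell_states_to_model = {
        "state_key": "disease",
        "start_state": "dcm",
        "goal_state": "nf",
        "alt_states": ["hcm"],
    }
    pu.validate_cell_states_to_model(cell_states_to_model)

    # GeneIdHandler round-trips Ensembl IDs, tokens, and gene symbols;
    # with raise_errors=False, unknown keys pass through unchanged
    handler = pu.GeneIdHandler(raise_errors=False)
    token = handler.ens_to_token("ENSG00000000001")  # placeholder ID
    symbol = handler.token_to_symbol(token)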
 
geneformer/pretrainer.py DELETED
@@ -1,640 +0,0 @@
- """
- Geneformer precollator and pretrainer.
-
- Huggingface data collator and trainer modified to accommodate single-cell transcriptomics data.
- """
- import collections
- import math
- import pickle
- import warnings
- from enum import Enum
- from typing import Dict, List, Optional, Union
-
- import numpy as np
- import torch
- from datasets import Dataset
- from packaging import version
- from torch.utils.data.sampler import RandomSampler
- from transformers import (
-     BatchEncoding,
-     DataCollatorForLanguageModeling,
-     SpecialTokensMixin,
-     Trainer,
- )
- from transformers.file_utils import is_datasets_available, is_sagemaker_dp_enabled
- from transformers.trainer_pt_utils import (
-     LengthGroupedSampler,
- )
- from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
- from transformers.utils.generic import _is_tensorflow, _is_torch
-
- logger = logging.get_logger(__name__)
- EncodedInput = List[int]
- VERY_LARGE_INTEGER = int(
-     1e30
- )  # This is used to set the max input length for a model with infinite size input
- LARGE_INTEGER = int(
-     1e20
- )  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
-
- if is_sagemaker_dp_enabled():
-     import smdistributed.dataparallel.torch.distributed as dist
- else:
-     import torch.distributed as dist
-
- _is_torch_generator_available = False
- if version.parse(torch.__version__) >= version.parse("1.6"):
-     _is_torch_generator_available = True
-
-
- class ExplicitEnum(Enum):
-     """
-     Enum with more explicit error message for missing values.
-     """
-
-     @classmethod
-     def _missing_(cls, value):
-         raise ValueError(
-             "%r is not a valid %s, please select one of %s"
-             % (value, cls.__name__, str(list(cls._value2member_map_.keys())))
-         )
-
-
- class TruncationStrategy(ExplicitEnum):
-     """
-     Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
-     tab-completion in an IDE.
-     """
-
-     ONLY_FIRST = "only_first"
-     ONLY_SECOND = "only_second"
-     LONGEST_FIRST = "longest_first"
-     DO_NOT_TRUNCATE = "do_not_truncate"
-
-
- class PaddingStrategy(ExplicitEnum):
-     """
-     Possible values for the ``padding`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for tab-completion
-     in an IDE.
-     """
-
-     LONGEST = "longest"
-     MAX_LENGTH = "max_length"
-     DO_NOT_PAD = "do_not_pad"
-
-
- class TensorType(ExplicitEnum):
-     """
-     Possible values for the ``return_tensors`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
-     tab-completion in an IDE.
-     """
-
-     PYTORCH = "pt"
-     TENSORFLOW = "tf"
-     NUMPY = "np"
-     JAX = "jax"
-
-
- class GeneformerPreCollator(SpecialTokensMixin):
-     def __init__(self, *args, **kwargs) -> None:
-         super().__init__(mask_token="<mask>", pad_token="<pad>")
-
-         self.token_dictionary = kwargs.get("token_dictionary")
-         self.padding_side = "right"
-         self.model_input_names = ["input_ids"]
-
-     def convert_ids_to_tokens(self, value):
-         return self.token_dictionary.get(value)
-
-     def _get_padding_truncation_strategies(
-         self,
-         padding=False,
-         truncation=False,
-         max_length=None,
-         pad_to_multiple_of=None,
-         verbose=True,
-         **kwargs,
-     ):
-         """
-         Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
-         and pad_to_max_length) and behaviors.
-         """
-         old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
-         old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
-
-         # Backward compatibility for previous behavior, maybe we should deprecate it:
-         # If you only set max_length, it activates truncation for max_length
-         if max_length is not None and padding is False and truncation is False:
-             if verbose:
-                 if not self.deprecation_warnings.get(
-                     "Truncation-not-explicitly-activated", False
-                 ):
-                     logger.warning(
-                         "Truncation was not explicitly activated but `max_length` is provided a specific value, "
-                         "please use `truncation=True` to explicitly truncate examples to max length. "
-                         "Defaulting to 'longest_first' truncation strategy. "
-                         "If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
-                         "more precisely by providing a specific strategy to `truncation`."
-                     )
-                 self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
-             truncation = "longest_first"
-
-         # Get padding strategy
-         if padding is False and old_pad_to_max_length:
-             if verbose:
-                 warnings.warn(
-                     "The `pad_to_max_length` argument is deprecated and will be removed in a future version, "
-                     "use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or "
-                     "use `padding='max_length'` to pad to a max length. In this case, you can give a specific "
-                     "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the "
-                     "maximal input size of the model (e.g. 512 for Bert).",
-                     FutureWarning,
-                 )
-             if max_length is None:
-                 padding_strategy = PaddingStrategy.LONGEST
-             else:
-                 padding_strategy = PaddingStrategy.MAX_LENGTH
-         elif padding is not False:
-             if padding is True:
-                 padding_strategy = (
-                     PaddingStrategy.LONGEST
-                 )  # Default to pad to the longest sequence in the batch
-             elif not isinstance(padding, PaddingStrategy):
-                 padding_strategy = PaddingStrategy(padding)
-             elif isinstance(padding, PaddingStrategy):
-                 padding_strategy = padding
-         else:
-             padding_strategy = PaddingStrategy.DO_NOT_PAD
-
-         # Get truncation strategy
-         if truncation is False and old_truncation_strategy != "do_not_truncate":
-             if verbose:
-                 warnings.warn(
-                     "The `truncation_strategy` argument is deprecated and will be removed in a future version, "
-                     "use `truncation=True` to truncate examples to a max length. You can give a specific "
-                     "length with `max_length` (e.g. `max_length=45`) or leave max_length to None to truncate to the "
-                     "maximal input size of the model (e.g. 512 for Bert). "
-                     " If you have pairs of inputs, you can give a specific truncation strategy selected among "
-                     "`truncation='only_first'` (will only truncate the first sentence in the pairs) "
-                     "`truncation='only_second'` (will only truncate the second sentence in the pairs) "
-                     "or `truncation='longest_first'` (will iteratively remove tokens from the longest sentence in the pairs).",
-                     FutureWarning,
-                 )
-             truncation_strategy = TruncationStrategy(old_truncation_strategy)
-         elif truncation is not False:
-             if truncation is True:
-                 truncation_strategy = (
-                     TruncationStrategy.LONGEST_FIRST
-                 )  # Default to truncate the longest sequences in pairs of inputs
-             elif not isinstance(truncation, TruncationStrategy):
-                 truncation_strategy = TruncationStrategy(truncation)
-             elif isinstance(truncation, TruncationStrategy):
-                 truncation_strategy = truncation
-         else:
-             truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
-
-         # Set max length if needed
-         if max_length is None:
-             if padding_strategy == PaddingStrategy.MAX_LENGTH:
-                 if self.model_max_length > LARGE_INTEGER:
-                     if verbose:
-                         if not self.deprecation_warnings.get(
-                             "Asking-to-pad-to-max_length", False
-                         ):
-                             logger.warning(
-                                 "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. "
-                                 "Default to no padding."
-                             )
-                         self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
-                     padding_strategy = PaddingStrategy.DO_NOT_PAD
-                 else:
-                     max_length = self.model_max_length
-
-             if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
-                 if self.model_max_length > LARGE_INTEGER:
-                     if verbose:
-                         if not self.deprecation_warnings.get(
-                             "Asking-to-truncate-to-max_length", False
-                         ):
-                             logger.warning(
-                                 "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. "
-                                 "Default to no truncation."
-                             )
-                         self.deprecation_warnings[
-                             "Asking-to-truncate-to-max_length"
-                         ] = True
-                     truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
-                 else:
-                     max_length = self.model_max_length
-
-         # Test if we have a padding token
-         if padding_strategy != PaddingStrategy.DO_NOT_PAD and (
-             not self.pad_token or self.pad_token_id < 0
-         ):
-             raise ValueError(
-                 "Asking to pad but the tokenizer does not have a padding token. "
-                 "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
-                 "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
-             )
-
-         # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
-         if (
-             truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
-             and padding_strategy != PaddingStrategy.DO_NOT_PAD
-             and pad_to_multiple_of is not None
-             and max_length is not None
-             and (max_length % pad_to_multiple_of != 0)
-         ):
-             raise ValueError(
-                 f"Truncation and padding are both activated but "
-                 f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
-             )
-
-         return padding_strategy, truncation_strategy, max_length, kwargs
-
-     def pad(
-         self,
-         encoded_inputs: Union[
-             BatchEncoding,
-             List[BatchEncoding],
-             Dict[str, EncodedInput],
-             Dict[str, List[EncodedInput]],
-             List[Dict[str, EncodedInput]],
-         ],
-         padding: Union[bool, str, PaddingStrategy] = True,
-         max_length: Optional[int] = None,
-         pad_to_multiple_of: Optional[int] = None,
-         return_attention_mask: Optional[bool] = True,
-         return_tensors: Optional[Union[str, TensorType]] = None,
-         verbose: bool = True,
-     ) -> BatchEncoding:
-         """
-         Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
-         in the batch.
-
-         Padding side (left/right) padding token ids are defined at the tokenizer level (with ``self.padding_side``,
-         ``self.pad_token_id`` and ``self.pad_token_type_id``)
-
-         .. note::
-
-             If the ``encoded_inputs`` passed are dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
-             result will use the same type unless you provide a different tensor type with ``return_tensors``. In the
-             case of PyTorch tensors, you will lose the specific device of your tensors however.
-
-         Args:
-             encoded_inputs (:class:`~transformers.BatchEncoding`, list of :class:`~transformers.BatchEncoding`, :obj:`Dict[str, List[int]]`, :obj:`Dict[str, List[List[int]]` or :obj:`List[Dict[str, List[int]]]`):
-                 Tokenized inputs. Can represent one input (:class:`~transformers.BatchEncoding` or :obj:`Dict[str,
-                 List[int]]`) or a batch of tokenized inputs (list of :class:`~transformers.BatchEncoding`, `Dict[str,
-                 List[List[int]]]` or `List[Dict[str, List[int]]]`) so you can use this method during preprocessing as
-                 well as in a PyTorch Dataloader collate function.
-
-                 Instead of :obj:`List[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
-                 see the note above for the return type.
-             padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
-                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                 index) among:
-
-                 * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-                   single sequence is provided).
-                 * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-                   maximum acceptable input length for the model if that argument is not provided.
-                 * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-                   different lengths).
-             max_length (:obj:`int`, `optional`):
-                 Maximum length of the returned list and optionally padding length (see above).
-             pad_to_multiple_of (:obj:`int`, `optional`):
-                 If set will pad the sequence to a multiple of the provided value.
-
-                 This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
-                 >= 7.5 (Volta).
-             return_attention_mask (:obj:`bool`, `optional`):
-                 Whether to return the attention mask. If left to the default, will return the attention mask according
-                 to the specific tokenizer's default, defined by the :obj:`return_outputs` attribute.
-
-                 `What are attention masks? <../glossary.html#attention-mask>`__
-             return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
-                 If set, will return tensors instead of list of python integers. Acceptable values are:
-
-                 * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-                 * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-                 * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-             verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
-                 Whether or not to print more information and warnings.
-         """
-         # If we have a list of dicts, let's convert it in a dict of lists
-         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-         if isinstance(encoded_inputs, (list, tuple)) and isinstance(
-             encoded_inputs[0], (dict, BatchEncoding)
-         ):
-             encoded_inputs = {
-                 key: [example[key] for example in encoded_inputs]
-                 for key in encoded_inputs[0].keys()
-             }
-
-         # The model's main input name, usually `input_ids`, has to be passed for padding
-         if self.model_input_names[0] not in encoded_inputs:
-             raise ValueError(
-                 "You should supply an encoding or a list of encodings to this method "
-                 f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
-             )
-
-         required_input = encoded_inputs[self.model_input_names[0]]
-
-         if not required_input:
-             if return_attention_mask:
-                 encoded_inputs["attention_mask"] = []
-             return encoded_inputs
-
-         # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
-         # and rebuild them afterwards if no return_tensors is specified
-         # Note that we lose the specific device the tensor may be on for PyTorch
-
-         first_element = required_input[0]
-         if isinstance(first_element, (list, tuple)):
-             # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
-             index = 0
-             while len(required_input[index]) == 0:
-                 index += 1
-             if index < len(required_input):
-                 first_element = required_input[index][0]
-         # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
-         if not isinstance(first_element, (int, list, tuple)):
-             if is_tf_available() and _is_tensorflow(first_element):
-                 return_tensors = "tf" if return_tensors is None else return_tensors
-             elif is_torch_available() and _is_torch(first_element):
-                 return_tensors = "pt" if return_tensors is None else return_tensors
-             elif isinstance(first_element, np.ndarray):
-                 return_tensors = "np" if return_tensors is None else return_tensors
-             else:
-                 raise ValueError(
-                     f"type of {first_element} unknown: {type(first_element)}. "
-                     f"Should be one of a python, numpy, pytorch or tensorflow object."
-                 )
-
-             for key, value in encoded_inputs.items():
-                 encoded_inputs[key] = to_py_obj(value)
-
-         # Convert padding_strategy in PaddingStrategy
-         padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
-             padding=padding, max_length=max_length, verbose=verbose
-         )
-
-         required_input = encoded_inputs[self.model_input_names[0]]
-         if required_input and not isinstance(required_input[0], (list, tuple)):
-             encoded_inputs = self._pad(
-                 encoded_inputs,
-                 max_length=max_length,
-                 padding_strategy=padding_strategy,
-                 pad_to_multiple_of=pad_to_multiple_of,
-                 return_attention_mask=return_attention_mask,
-             )
-             return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
-
-         batch_size = len(required_input)
-         assert all(
-             len(v) == batch_size for v in encoded_inputs.values()
-         ), "Some items in the output dictionary have a different batch size than others."
-
-         if padding_strategy == PaddingStrategy.LONGEST:
-             max_length = max(len(inputs) for inputs in required_input)
-             padding_strategy = PaddingStrategy.MAX_LENGTH
-
-         batch_outputs = {}
-         for i in range(batch_size):
-             inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
-             outputs = self._pad(
-                 inputs,
-                 max_length=max_length,
-                 padding_strategy=padding_strategy,
-                 pad_to_multiple_of=pad_to_multiple_of,
-                 return_attention_mask=return_attention_mask,
-             )
-
-             for key, value in outputs.items():
-                 if key not in batch_outputs:
-                     batch_outputs[key] = []
-                 batch_outputs[key].append(value)
-
-         return BatchEncoding(batch_outputs, tensor_type=return_tensors)
-
-     def _pad(
-         self,
-         encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
-         max_length: Optional[int] = None,
-         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
-         pad_to_multiple_of: Optional[int] = None,
-         return_attention_mask: Optional[bool] = None,
-     ) -> dict:
-         """
-         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
-
-         Args:
-             encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
-             max_length: maximum length of the returned list and optionally padding length (see below).
-                 Will truncate by taking into account the special tokens.
-             padding_strategy: PaddingStrategy to use for padding.
-
-                 - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
-                 - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
-                 - PaddingStrategy.DO_NOT_PAD: Do not pad
-                 The tokenizer padding sides are defined in self.padding_side:
-
-                     - 'left': pads on the left of the sequences
-                     - 'right': pads on the right of the sequences
-             pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
-                 This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
-                 >= 7.5 (Volta).
-             return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics)
-         """
-         # Load from model defaults
-         if return_attention_mask is None:
-             return_attention_mask = "attention_mask" in self.model_input_names
-
-         required_input = encoded_inputs[self.model_input_names[0]]
-
-         if padding_strategy == PaddingStrategy.LONGEST:
-             max_length = len(required_input)
-
-         if (
-             max_length is not None
-             and pad_to_multiple_of is not None
-             and (max_length % pad_to_multiple_of != 0)
-         ):
-             max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
-
-         needs_to_be_padded = (
-             padding_strategy != PaddingStrategy.DO_NOT_PAD
-             and len(required_input) != max_length
-         )
-
-         if needs_to_be_padded:
-             difference = max_length - len(required_input)
-             if self.padding_side == "right":
-                 if return_attention_mask:
-                     encoded_inputs["attention_mask"] = [1] * len(required_input) + [
-                         0
-                     ] * difference
-                 if "token_type_ids" in encoded_inputs:
-                     encoded_inputs["token_type_ids"] = (
-                         encoded_inputs["token_type_ids"]
-                         + [self.pad_token_type_id] * difference
-                     )
-                 if "special_tokens_mask" in encoded_inputs:
-                     encoded_inputs["special_tokens_mask"] = (
-                         encoded_inputs["special_tokens_mask"] + [1] * difference
-                     )
-                 encoded_inputs[self.model_input_names[0]] = (
-                     required_input + [self.pad_token_id] * difference
-                 )
-             elif self.padding_side == "left":
-                 if return_attention_mask:
-                     encoded_inputs["attention_mask"] = [0] * difference + [1] * len(
-                         required_input
-                     )
-                 if "token_type_ids" in encoded_inputs:
-                     encoded_inputs["token_type_ids"] = [
-                         self.pad_token_type_id
-                     ] * difference + encoded_inputs["token_type_ids"]
-                 if "special_tokens_mask" in encoded_inputs:
-                     encoded_inputs["special_tokens_mask"] = [
-                         1
-                     ] * difference + encoded_inputs["special_tokens_mask"]
-                 encoded_inputs[self.model_input_names[0]] = [
-                     self.pad_token_id
-                 ] * difference + required_input
-             else:
-                 raise ValueError("Invalid padding strategy: " + str(self.padding_side))
-         elif return_attention_mask and "attention_mask" not in encoded_inputs:
-             encoded_inputs["attention_mask"] = [1] * len(required_input)
-
-         return encoded_inputs
-
-     def get_special_tokens_mask(
-         self,
-         token_ids_0: List[int],
-         token_ids_1: Optional[List[int]] = None,
-         already_has_special_tokens: bool = False,
-     ) -> List[int]:
-         """
-         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
-         special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
-         Args:
-             token_ids_0 (:obj:`List[int]`):
-                 List of ids of the first sequence.
-             token_ids_1 (:obj:`List[int]`, `optional`):
-                 List of ids of the second sequence.
-             already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                 Whether or not the token list is already formatted with special tokens for the model.
-         Returns:
-             A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
-         """
-         assert already_has_special_tokens and token_ids_1 is None, (
-             "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
-             "Please use a slow (full python) tokenizer to activate this argument. "
-             "Or set `return_special_tokens_mask=True` when calling the encoding method "
-             "to get the special tokens mask in any tokenizer. "
-         )
-
-         all_special_ids = self.all_special_ids  # cache the property
-
-         special_tokens_mask = [
-             1 if token in all_special_ids else 0 for token in token_ids_0
-         ]
-
-         return special_tokens_mask
-
-     def convert_tokens_to_ids(
-         self, tokens: Union[str, List[str]]
-     ) -> Union[int, List[int]]:
-         """
-         Converts a token string (or a sequence of tokens) to a single integer id (or a sequence of ids), using the
-         vocabulary.
-         Args:
-             tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
-         Returns:
-             :obj:`int` or :obj:`List[int]`: The token id or list of token ids.
-         """
-         if tokens is None:
-             return None
-
-         if isinstance(tokens, str):
-             return self._convert_token_to_id_with_added_voc(tokens)
-
-         ids = []
-         for token in tokens:
-             ids.append(self._convert_token_to_id_with_added_voc(token))
-         return ids
-
-     def _convert_token_to_id_with_added_voc(self, token):
-         if token is None:
-             return None
-
-         return self.token_dictionary.get(token)
-
-     def __len__(self):
-         return len(self.token_dictionary)
-
-
- class GeneformerPretrainer(Trainer):
-     def __init__(self, *args, **kwargs):
-         data_collator = kwargs.get("data_collator", None)
-         token_dictionary = kwargs.pop("token_dictionary")
-         mlm = kwargs.pop("mlm", True)
-         mlm_probability = kwargs.pop("mlm_probability", 0.15)
-
-         if data_collator is None:
-             precollator = GeneformerPreCollator(token_dictionary=token_dictionary)
-
-             # Data Collator Functions
-             data_collator = DataCollatorForLanguageModeling(
-                 tokenizer=precollator, mlm=mlm, mlm_probability=mlm_probability
-             )
-             kwargs["data_collator"] = data_collator
-
-         # load previously saved length vector for dataset to speed up LengthGroupedSampler
-         # pre-obtained with [dataset[i]["length"] for i in range(len(dataset))]
-         example_lengths_file = kwargs.pop("example_lengths_file")
-         if example_lengths_file:
-             with open(example_lengths_file, "rb") as f:
-                 self.example_lengths = pickle.load(f)
-         else:
-             raise Exception(
-                 "example_lengths_file is required; e.g. https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/genecorpus_30M_2048_sorted_lengths.pkl"
-             )
-         super().__init__(*args, **kwargs)
-
-     # updated to not use distributed sampler since Trainer now distributes with accelerate
-     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-         if not isinstance(self.train_dataset, collections.abc.Sized):
-             return None
-
-         generator = None
-         if self.args.world_size <= 1 and _is_torch_generator_available:
-             generator = torch.Generator()
-             generator.manual_seed(
-                 int(torch.empty((), dtype=torch.int64).random_().item())
-             )
-
-         # Build the sampler.
-         if self.args.group_by_length:
-             if is_datasets_available() and isinstance(self.train_dataset, Dataset):
-                 lengths = self.example_lengths
-             else:
-                 lengths = None
-             model_input_name = (
-                 self.tokenizer.model_input_names[0]
-                 if self.tokenizer is not None
-                 else None
-             )
-             return LengthGroupedSampler(
-                 dataset=self.train_dataset,
-                 batch_size=self.args.train_batch_size,
-                 lengths=lengths,
-                 model_input_name=model_input_name,
-                 generator=generator,
-             )
-
-         else:
-             if _is_torch_generator_available:
-                 return RandomSampler(self.train_dataset, generator=generator)
-             return RandomSampler(self.train_dataset)
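
Note: for reference, a minimal sketch of how the removed GeneformerPretrainer was typically wired up (file paths below are hypothetical; mlm=True and mlm_probability=0.15 are the defaults popped in __init__, and example_lengths_file is required):

    import pickle

    from datasets import load_from_disk
    from transformers import BertConfig, BertForMaskedLM, TrainingArguments

    from geneformer.pretrainer import GeneformerPretrainer  # pre-removal layout

    with open("token_dictionary.pkl", "rb") as f:  # hypothetical path
        token_dictionary = pickle.load(f)

    config = BertConfig(
        vocab_size=len(token_dictionary),
        pad_token_id=token_dictionary.get("<pad>"),
    )
    trainer = GeneformerPretrainer(
        model=BertForMaskedLM(config),
        args=TrainingArguments(output_dir="models", group_by_length=True),
        train_dataset=load_from_disk("genecorpus_30M_2048.dataset"),  # hypothetical path
        token_dictionary=token_dictionary,
        example_lengths_file="genecorpus_30M_2048_sorted_lengths.pkl",
    )
    trainer.train()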
 
geneformer/tokenizer.py DELETED
@@ -1,685 +0,0 @@
1
- """
2
- Geneformer tokenizer.
3
-
4
- **Input data:**
5
-
6
- | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
7
- | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
8
- | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
9
-
10
- | *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
11
- | *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
12
-
13
- **Usage:**
14
-
15
- .. code-block :: python
16
-
17
- >>> from geneformer import TranscriptomeTokenizer
18
- >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
19
- >>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
20
-
21
- **Description:**
22
-
23
- | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
24
-
25
- | The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
26
-
27
- | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
28
-
29
- | No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
30
-
31
- | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
32
-
33
- | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
34
-
35
- | OF NOTE: Take care that the correct token dictionary and gene median file is used for the correct model.
36
-
37
- | OF NOTE: For 95M model series, special_token should be True and model_input_size should be 4096. For 30M model series, special_token should be False and model_input_size should be 2048.
38
-
39
- """
40
-
41
- from __future__ import annotations
42
-
43
- import logging
44
- import os
45
- import pickle
46
- import warnings
47
- from collections import Counter
48
- from pathlib import Path
49
- from typing import Literal
50
-
51
- import loompy as lp
52
- import numpy as np
53
- import pandas as pd
54
- import scanpy as sc
55
- import scipy.sparse as sp
56
- from datasets import Dataset
57
- from tqdm import tqdm
58
-
59
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa
60
- import loompy as lp # noqa
61
-
62
- logger = logging.getLogger(__name__)
63
-
64
- from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
-
66
- def rank_genes(gene_vector, gene_tokens):
67
- """
68
- Rank gene expression vector.
69
- """
70
- # sort by median-scaled gene values
71
- sorted_indices = np.argsort(-gene_vector)
72
- return gene_tokens[sorted_indices]
73
-
74
-
75
- def tokenize_cell(gene_vector, gene_tokens):
76
- """
77
- Convert normalized gene expression vector to tokenized rank value encoding.
78
- """
79
- # create array of gene vector with token indices
80
- # mask undetected genes
81
- nonzero_mask = np.nonzero(gene_vector)[0]
82
- # rank by median-scaled gene values
83
- return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
84
-
85
-
86
- def sum_ensembl_ids(
87
- data_directory,
88
- collapse_gene_ids,
89
- gene_mapping_dict,
90
- gene_token_dict,
91
- custom_attr_name_dict,
92
- file_format="loom",
93
- chunk_size=512,
94
- ):
95
- if file_format == "loom":
96
- """
97
- Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
98
- """
99
- with lp.connect(data_directory) as data:
100
- assert (
101
- "ensembl_id" in data.ra.keys()
102
- ), "'ensembl_id' column missing from data.ra.keys()"
103
-
104
- assert (
105
- "ensembl_id_collapsed" not in data.ra.keys()
106
- ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
107
-
108
- assert (
109
- "n_counts" in data.ca.keys()
110
- ), "'n_counts' column missing from data.ca.keys()"
111
-
112
- if custom_attr_name_dict is not None:
113
- for label in custom_attr_name_dict:
114
- assert label in data.ca.keys(), f"Attribute `{label}` not present in dataset features"
115
-
116
- # Get the ensembl ids that exist in data
117
- ensembl_ids = data.ra.ensembl_id
118
- # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
119
- # Comparing to gene_token_dict here, would not perform any mapping steps
120
- if not collapse_gene_ids:
121
- ensembl_id_check = [
122
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
123
- ]
124
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
125
- return data_directory
126
- else:
127
- raise ValueError("Error: data Ensembl IDs non-unique.")
128
-
129
- # Get the genes that exist in the mapping dictionary and the value of those genes
130
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
131
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
132
-
133
- # if the genes in the mapping dict and the value of those genes are of the same length,
134
- # simply return the mapped values
135
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
136
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
137
- data.ra["ensembl_id_collapsed"] = mapped_vals
138
- return data_directory
139
- # Genes need to be collapsed
140
- else:
141
- dedup_filename = data_directory.with_name(
142
- data_directory.stem + "__dedup.loom"
143
- )
144
- mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
145
- data.ra["ensembl_id_collapsed"] = mapped_vals
146
- dup_genes = [
147
- idx
148
- for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
149
- if count > 1
150
- ]
151
- num_chunks = int(np.ceil(data.shape[1] / chunk_size))
152
- first_chunk = True
153
- for _, _, view in tqdm(
154
- data.scan(axis=1, batch_size=chunk_size), total=num_chunks
155
- ):
156
-
157
- def process_chunk(view, duplic_genes):
158
- data_count_view = pd.DataFrame(
159
- view, index=data.ra["ensembl_id_collapsed"]
160
- )
161
- unique_data_df = data_count_view.loc[
162
- ~data_count_view.index.isin(duplic_genes)
163
- ]
164
- dup_data_df = data_count_view.loc[
165
- data_count_view.index.isin(
166
- [i for i in duplic_genes if "None" not in i]
167
- )
168
- ]
169
- summed_data = dup_data_df.groupby(dup_data_df.index).sum()
170
- if not summed_data.index.is_unique:
171
- raise ValueError(
172
- "Error: Ensembl IDs in summed data frame non-unique."
173
- )
174
- data_count_view = pd.concat(
175
- [unique_data_df, summed_data], axis=0
176
- )
177
- if not data_count_view.index.is_unique:
178
- raise ValueError(
179
- "Error: Ensembl IDs in final data frame non-unique."
180
- )
181
- return data_count_view
182
-
183
- processed_chunk = process_chunk(view[:, :], dup_genes)
184
- processed_array = processed_chunk.to_numpy()
185
- new_row_attrs = {"ensembl_id_collapsed": processed_chunk.index.to_numpy()}
186
-
187
- if "n_counts" not in view.ca.keys():
188
- total_count_view = np.sum(view[:, :], axis=0).astype(int)
189
- view.ca["n_counts"] = total_count_view
190
-
191
- if first_chunk: # Create the Loom file with the first chunk
192
- lp.create(
193
- f"{dedup_filename}",
194
- processed_array,
195
- row_attrs=new_row_attrs,
196
- col_attrs=view.ca,
197
- )
198
- first_chunk = False
199
- else: # Append subsequent chunks
200
- with lp.connect(dedup_filename, mode="r+") as dsout:
201
- dsout.add_columns(processed_array, col_attrs=view.ca)
202
- return dedup_filename
203
-
204
- elif file_format == "h5ad":
205
- """
206
- Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
207
- Returns adata object with deduplicated Ensembl IDs.
208
- """
209
-
210
- data = sc.read_h5ad(str(data_directory))
211
-
212
- assert (
213
- "ensembl_id" in data.var.columns
214
- ), "'ensembl_id' column missing from data.var"
215
-
216
- assert (
217
- "ensembl_id_collapsed" not in data.var.columns
218
- ), "'ensembl_id_collapsed' column already exists in data.var"
219
- assert (
220
- "n_counts" in data.obs.columns
221
- ), "'n_counts' column missing from data.obs"
222
-
223
- if custom_attr_name_dict is not None:
224
- for label in custom_attr_name_dict:
225
- assert label in data.obs.columns, f"Attribute `{label}` not present in data.obs"
226
-
227
-
228
- # Get the ensembl ids that exist in data
229
- ensembl_ids = data.var.ensembl_id
230
- # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
231
- # Comparing to gene_token_dict here, would not perform any mapping steps
232
- if not collapse_gene_ids:
233
- ensembl_id_check = [
234
- gene for gene in ensembl_ids if gene in gene_token_dict.keys()
235
- ]
236
- if len(ensembl_id_check) == len(set(ensembl_id_check)):
237
- return data_directory
238
- else:
239
- raise ValueError("Error: data Ensembl IDs non-unique.")
240
-
241
- # Get the genes that exist in the mapping dictionary and the value of those genes
242
- genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
243
- vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
244
-
245
- # if the genes in the mapping dict and the value of those genes are of the same length,
246
- # simply return the mapped values
247
- if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
248
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
249
- return data
250
- # Genes need to be collapsed
251
- else:
252
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
253
- data.var_names = data.var["ensembl_id_collapsed"]
254
- data = data[:, ~data.var.index.isna()]
255
- dup_genes = [
256
- idx for idx, count in Counter(data.var_names).items() if count > 1
257
- ]
258
-
259
- num_chunks = int(np.ceil(data.shape[0] / chunk_size))
260
-
261
- processed_genes = []
262
- for i in tqdm(range(num_chunks)):
263
- start_idx = i * chunk_size
264
- end_idx = min((i + 1) * chunk_size, data.shape[0])
265
- data_chunk = data[start_idx:end_idx, :]
266
-
267
- processed_chunks = []
268
- for dup_gene in dup_genes:
269
- data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
270
- df = pd.DataFrame.sparse.from_spmatrix(
271
- data_dup_gene.X,
272
- index=data_dup_gene.obs_names,
273
- columns=data_dup_gene.var_names,
274
- )
275
- df_sum = pd.DataFrame(df.sum(axis=1))
276
- df_sum.columns = [dup_gene]
277
- df_sum.index = data_dup_gene.obs.index
278
- processed_chunks.append(df_sum)
279
-
280
- processed_chunks = pd.concat(processed_chunks, axis=1)
281
- processed_genes.append(processed_chunks)
282
- processed_genes = pd.concat(processed_genes, axis=0)
283
- var_df = pd.DataFrame({"ensembl_id_collapsed": processed_genes.columns})
284
- var_df.index = processed_genes.columns
285
- processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
286
-
287
- data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
288
- data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
289
- data_dedup.obs = data.obs
290
- return data_dedup
291
-
292
-
293
- class TranscriptomeTokenizer:
294
- def __init__(
295
- self,
296
- custom_attr_name_dict=None,
297
- nproc=1,
298
- chunk_size=512,
299
- model_input_size=4096,
300
- special_token=True,
301
- collapse_gene_ids=True,
302
- gene_median_file=GENE_MEDIAN_FILE,
303
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
304
- gene_mapping_file=ENSEMBL_MAPPING_FILE,
305
- ):
306
- """
307
- Initialize tokenizer.
308
-
309
- **Parameters:**
310
-
311
- custom_attr_name_dict : None, dict
312
- | Dictionary of custom attributes to be added to the dataset.
313
- | Keys are the names of the attributes in the loom file.
314
- | Values are the names of the attributes in the dataset.
315
- nproc : int
316
- | Number of processes to use for dataset mapping.
317
- chunk_size : int = 512
318
- | Chunk size for anndata tokenizer.
319
- model_input_size : int = 4096
320
- | Max input size of model to truncate input to.
321
- | For the 30M model series, should be 2048. For the 95M model series, should be 4096.
322
- special_token : bool = True
323
- | Adds CLS token before and EOS token after rank value encoding.
324
- | For the 30M model series, should be False. For the 95M model series, should be True.
325
- collapse_gene_ids : bool = True
326
- | Whether to collapse gene IDs based on gene mapping dictionary.
327
- gene_median_file : Path
328
- | Path to pickle file containing dictionary of non-zero median
329
- | gene expression values across Genecorpus-30M.
330
- token_dictionary_file : Path
331
- | Path to pickle file containing token dictionary (Ensembl IDs:token).
332
- gene_mapping_file : None, Path
333
- | Path to pickle file containing dictionary for collapsing gene IDs.
334
-
335
- """
336
- # dictionary of custom attributes {output dataset column name: input .loom column name}
337
- self.custom_attr_name_dict = custom_attr_name_dict
338
-
339
- # number of processes for dataset mapping
340
- self.nproc = nproc
341
-
342
- # chunk size for anndata tokenizer
343
- self.chunk_size = chunk_size
344
-
345
- # input size for tokenization
346
- self.model_input_size = model_input_size
347
-
348
- # add CLS and EOS tokens
349
- self.special_token = special_token
350
-
351
- # load dictionary of gene normalization factors
352
- # (non-zero median value of expression across Genecorpus-30M)
353
- with open(gene_median_file, "rb") as f:
354
- self.gene_median_dict = pickle.load(f)
355
-
356
- # load token dictionary (Ensembl IDs:token)
357
- with open(token_dictionary_file, "rb") as f:
358
- self.gene_token_dict = pickle.load(f)
359
-
360
- # check for special token in gene_token_dict
361
- if self.special_token:
362
- if ("<cls>" not in self.gene_token_dict.keys()) and (
363
- "<eos>" not in self.gene_token_dict.keys()
364
- ):
365
- logger.error(
366
- "<cls> and <eos> required in gene_token_dict when special_token = True."
367
- )
368
- raise
369
-
370
- if not self.special_token:
371
- if ("<cls>" in self.gene_token_dict.keys()) and (
372
- "<eos>" in self.gene_token_dict.keys()
373
- ):
374
- logger.warning(
375
- "<cls> and <eos> are in gene_token_dict but special_token = False. Please note that for 95M model series, special_token should be True."
376
- )
377
-
378
- # if collapsing duplicate gene IDs
379
- self.collapse_gene_ids = collapse_gene_ids
380
-
381
- # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
382
- if gene_mapping_file is not None:
383
- with open(gene_mapping_file, "rb") as f:
384
- self.gene_mapping_dict = pickle.load(f)
385
- else:
386
- self.gene_mapping_dict = {k: k for k, _ in self.gene_token_dict.items()}
387
-
388
- # gene keys for full vocabulary
389
- self.gene_keys = list(self.gene_token_dict.keys())
390
-
391
- # Filter gene mapping dict for items that exist in gene_token_dict
392
- gene_keys_set = set(self.gene_token_dict.keys())
393
- self.gene_mapping_dict = {
394
- k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set
395
- }
396
-
397
- # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
398
- self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
399
-
400
- def tokenize_data(
401
- self,
402
- data_directory: Path | str,
403
- output_directory: Path | str,
404
- output_prefix: str,
405
- file_format: Literal["loom", "h5ad"] = "loom",
406
- use_generator: bool = False,
407
- ):
408
- """
409
- Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
410
-
411
- **Parameters:**
412
-
413
- data_directory : Path
414
- | Path to directory containing loom files or anndata files
415
- output_directory : Path
416
- | Path to directory where tokenized data will be saved as .dataset
417
- output_prefix : str
418
- | Prefix for output .dataset
419
- file_format : str
420
- | Format of input files. Can be "loom" or "h5ad".
421
- use_generator : bool
422
- | Whether to use generator or dict for tokenization.
423
-
424
- """
425
- tokenized_cells, cell_metadata = self.tokenize_files(
426
- Path(data_directory), file_format
427
- )
428
- tokenized_dataset = self.create_dataset(
429
- tokenized_cells,
430
- cell_metadata,
431
- use_generator=use_generator,
432
- )
433
-
434
- output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
435
- tokenized_dataset.save_to_disk(str(output_path))
436
-
437
- def tokenize_files(
438
- self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
439
- ):
440
- tokenized_cells = []
441
- if self.custom_attr_name_dict is not None:
442
- cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
443
- cell_metadata = {
444
- attr_key: [] for attr_key in self.custom_attr_name_dict.values()
445
- }
446
-
447
- # loops through directories to tokenize .loom files
448
- file_found = 0
449
- # loops through directories to tokenize .loom or .h5ad files
450
- tokenize_file_fn = (
451
- self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
452
- )
453
- for file_path in data_directory.glob(f"*.{file_format}"):
454
- file_found = 1
455
- print(f"Tokenizing {file_path}")
456
- file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
457
- tokenized_cells += file_tokenized_cells
458
- if self.custom_attr_name_dict is not None:
459
- for k in cell_attr:
460
- cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
461
- k
462
- ]
463
- else:
464
- cell_metadata = None
465
-
466
- if file_found == 0:
467
- logger.error(
468
- f"No .{file_format} files found in directory {data_directory}."
469
- )
470
- raise
471
- return tokenized_cells, cell_metadata
472
-
473
-     def tokenize_anndata(self, adata_file_path, target_sum=10_000):
-         adata = sum_ensembl_ids(
-             adata_file_path,
-             self.collapse_gene_ids,
-             self.gene_mapping_dict,
-             self.gene_token_dict,
-             self.custom_attr_name_dict,
-             file_format="h5ad",
-             chunk_size=self.chunk_size,
-         )
-
-         if self.custom_attr_name_dict is not None:
-             file_cell_metadata = {
-                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
-             }
-
-         coding_miRNA_loc = np.where(
-             [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id_collapsed"]]
-         )[0]
-         norm_factor_vector = np.array(
-             [
-                 self.gene_median_dict[i]
-                 for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
-             ]
-         )
-         coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
-         coding_miRNA_tokens = np.array(
-             [self.gene_token_dict[i] for i in coding_miRNA_ids]
-         )
-
-         try:
-             _ = adata.obs["filter_pass"]
-         except KeyError:
-             var_exists = False
-         else:
-             var_exists = True
-
-         if var_exists:
-             filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
-         else:
-             print(
-                 f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
-             )
-             filter_pass_loc = np.arange(adata.shape[0])
-
-         tokenized_cells = []
-
-         for i in range(0, len(filter_pass_loc), self.chunk_size):
-             idx = filter_pass_loc[i : i + self.chunk_size]
-
-             n_counts = adata[idx].obs["n_counts"].values[:, None]
-             X_view0 = adata[idx, :].X
-             X_view = X_view0[:, coding_miRNA_loc]
-             X_norm = X_view / n_counts * target_sum / norm_factor_vector
-             X_norm = sp.csr_matrix(X_norm)
-
-             tokenized_cells += [
-                 rank_genes(X_norm[j].data, coding_miRNA_tokens[X_norm[j].indices])
-                 for j in range(X_norm.shape[0])
-             ]
-
-             # add custom attributes for this chunk to dict
-             if self.custom_attr_name_dict is not None:
-                 for k in file_cell_metadata.keys():
-                     file_cell_metadata[k] += adata[idx].obs[k].tolist()
-             else:
-                 file_cell_metadata = None
-
-         return tokenized_cells, file_cell_metadata
-
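The chunk loop above reduces each cell to one vector operation: scale raw counts to `target_sum` per cell, then divide by precomputed per-gene nonzero medians before ranking. A minimal sketch with toy values (the medians and token ids below are hypothetical, not from the shipped dictionaries):

```python
import numpy as np

# Sketch of the normalization applied above, under assumed toy inputs;
# real factors come from gene_median_dict and gene_token_dict.
def normalize_counts(counts, n_counts, gene_medians, target_sum=10_000):
    return counts / n_counts * target_sum / gene_medians

counts = np.array([5.0, 20.0, 1.0])        # raw counts for three detected genes
gene_medians = np.array([2.0, 16.0, 0.1])  # hypothetical nonzero-median factors
tokens = np.array([101, 102, 103])         # hypothetical gene token ids

norm = normalize_counts(counts, counts.sum(), gene_medians)
print(tokens[np.argsort(-norm)])           # [103 101 102]: highest value ranks first
```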
-     def tokenize_loom(self, loom_file_path, target_sum=10_000):
-         if self.custom_attr_name_dict is not None:
-             file_cell_metadata = {
-                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
-             }
-         loom_file_path_original = loom_file_path
-
-         dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
-         loom_file_path = sum_ensembl_ids(
-             loom_file_path,
-             self.collapse_gene_ids,
-             self.gene_mapping_dict,
-             self.gene_token_dict,
-             self.custom_attr_name_dict,
-             file_format="loom",
-             chunk_size=self.chunk_size,
-         )
-
-         with lp.connect(str(loom_file_path)) as data:
-             # define coordinates of detected protein-coding or miRNA genes
-             # and vector of their normalization factors
-             coding_miRNA_loc = np.where(
-                 [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id_collapsed"]]
-             )[0]
-             norm_factor_vector = np.array(
-                 [
-                     self.gene_median_dict[i]
-                     for i in data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
-                 ]
-             )
-             coding_miRNA_ids = data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
-             coding_miRNA_tokens = np.array(
-                 [self.gene_token_dict[i] for i in coding_miRNA_ids]
-             )
-
-             # define coordinates of cells passing filters for inclusion (e.g. QC)
-             try:
-                 data.ca["filter_pass"]
-             except AttributeError:
-                 var_exists = False
-             else:
-                 var_exists = True
-
-             if var_exists:
-                 filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
-             else:
-                 print(
-                     f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
-                 )
-                 filter_pass_loc = np.arange(data.shape[1])
-
-             # scan through .loom file and tokenize cells
-             tokenized_cells = []
-             for _ix, _selection, view in data.scan(
-                 items=filter_pass_loc, axis=1, batch_size=self.chunk_size
-             ):
-                 # select subview with protein-coding and miRNA genes
-                 subview = view.view[coding_miRNA_loc, :]
-
-                 # normalize by total counts per cell, multiply by target_sum
-                 # to allocate bits to precision, and normalize by gene medians
-                 subview_norm_array = (
-                     subview[:, :]
-                     / subview.ca.n_counts
-                     * target_sum
-                     / norm_factor_vector[:, None]
-                 )
-                 # tokenize subview gene vectors
-                 tokenized_cells += [
-                     tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
-                     for i in range(subview_norm_array.shape[1])
-                 ]
-
-                 # add custom attributes for subview to dict
-                 if self.custom_attr_name_dict is not None:
-                     for k in file_cell_metadata.keys():
-                         file_cell_metadata[k] += subview.ca[k].tolist()
-                 else:
-                     file_cell_metadata = None
-
-         # clean up temporary deduplicated file and collapsed ids, if created
-         if str(dedup_filename) == str(loom_file_path):
-             os.remove(str(dedup_filename))
-
-         with lp.connect(str(loom_file_path_original)) as data:
-             if "ensembl_id_collapsed" in data.ra.keys():
-                 del data.ra["ensembl_id_collapsed"]
-
-         return tokenized_cells, file_cell_metadata
-
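`tokenize_cell` and `rank_genes` are defined earlier in tokenizer.py, outside this hunk; based on their call sites above, their behavior amounts to the following sketch (signatures assumed):

```python
import numpy as np

def rank_genes(gene_vector, gene_tokens):
    # order token ids by descending normalized expression value
    return gene_tokens[np.argsort(-gene_vector)]

def tokenize_cell(gene_vector, gene_tokens):
    # keep only detected (nonzero) genes, then rank the remainder
    nonzero = np.nonzero(gene_vector)[0]
    return rank_genes(gene_vector[nonzero], gene_tokens[nonzero])
```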
-     def create_dataset(
-         self,
-         tokenized_cells,
-         cell_metadata,
-         use_generator=False,
-         keep_uncropped_input_ids=False,
-     ):
-         print("Creating dataset.")
-         # create dict for dataset creation
-         dataset_dict = {"input_ids": tokenized_cells}
-         if self.custom_attr_name_dict is not None:
-             dataset_dict.update(cell_metadata)
-
-         # create dataset
-         if use_generator:
-
-             def dict_generator():
-                 for i in range(len(tokenized_cells)):
-                     yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
-
-             output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
-         else:
-             output_dataset = Dataset.from_dict(dataset_dict)
-
-         def format_cell_features(example):
-             # store original uncropped input_ids in separate feature
-             if keep_uncropped_input_ids:
-                 example["input_ids_uncropped"] = example["input_ids"]
-                 example["length_uncropped"] = len(example["input_ids"])
-
-             # truncate/crop input_ids to model input size
-             if self.special_token:
-                 example["input_ids"] = example["input_ids"][
-                     0 : self.model_input_size - 2
-                 ]  # truncate to leave space for CLS and EOS tokens
-                 example["input_ids"] = np.insert(
-                     example["input_ids"], 0, self.gene_token_dict.get("<cls>")
-                 )
-                 example["input_ids"] = np.insert(
-                     example["input_ids"],
-                     len(example["input_ids"]),
-                     self.gene_token_dict.get("<eos>"),
-                 )
-             else:
-                 example["input_ids"] = example["input_ids"][0 : self.model_input_size]
-             example["length"] = len(example["input_ids"])
-
-             return example
-
-         output_dataset_truncated = output_dataset.map(
-             format_cell_features, num_proc=self.nproc
-         )
-         return output_dataset_truncated
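When `special_token` is set, `format_cell_features` crops to `model_input_size - 2` before adding the two special tokens, so the final length still equals `model_input_size`. A toy check of that logic (hypothetical token ids; the real ones come from `gene_token_dict`):

```python
import numpy as np

model_input_size = 6
cls_id, eos_id = 1, 2                               # hypothetical special-token ids

input_ids = np.array([10, 11, 12, 13, 14, 15, 16])  # ranked gene tokens, too long
cropped = input_ids[: model_input_size - 2]         # leave room for <cls>/<eos>
cropped = np.insert(cropped, 0, cls_id)
cropped = np.insert(cropped, len(cropped), eos_id)
print(cropped, len(cropped))                        # [ 1 10 11 12 13  2] 6
```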
gf-12L-30M-i2048/config.json DELETED
@@ -1,23 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.02,
-   "gradient_checkpointing": false,
-   "hidden_act": "relu",
-   "hidden_dropout_prob": 0.02,
-   "hidden_size": 512,
-   "initializer_range": 0.02,
-   "intermediate_size": 1024,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 2048,
-   "model_type": "bert",
-   "num_attention_heads": 8,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "transformers_version": "4.6.0",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 25426
- }
gf-12L-30M-i2048/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:812f8d85e5ecf9d64c268f052f6ece2c1906bc4f1aecf70d5144b2598386b615
- size 158467410
gf-12L-30M-i2048/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:259cf6067211e24e198690d00f0a222ee5550ad57e23d04ced0d0ca2e1b3738e
- size 2607
gf-12L-95M-i4096/config.json DELETED
@@ -1,24 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.02,
-   "classifier_dropout": null,
-   "hidden_act": "relu",
-   "hidden_dropout_prob": 0.02,
-   "hidden_size": 512,
-   "initializer_range": 0.02,
-   "intermediate_size": 1024,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 4096,
-   "model_type": "bert",
-   "num_attention_heads": 8,
-   "num_hidden_layers": 12,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "torch_dtype": "float32",
-   "transformers_version": "4.37.1",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 20275
- }
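For reference, this deleted file is a standard `transformers` BERT configuration; a sketch of rebuilding the same architecture in code, with the values copied from the JSON above:

```python
from transformers import BertConfig

config = BertConfig(
    vocab_size=20275,
    hidden_size=512,
    num_hidden_layers=12,
    num_attention_heads=8,
    intermediate_size=1024,
    max_position_embeddings=4096,
    hidden_act="relu",
    hidden_dropout_prob=0.02,
    attention_probs_dropout_prob=0.02,
    pad_token_id=0,
)
# 512 hidden units over 8 heads -> 64-dimensional attention heads
assert config.hidden_size % config.num_attention_heads == 0
```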
gf-12L-95M-i4096/generation_config.json DELETED
@@ -1,5 +0,0 @@
- {
-   "_from_model_config": true,
-   "pad_token_id": 0,
-   "transformers_version": "4.37.1"
- }
gf-12L-95M-i4096/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c
- size 152012980
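Rough sanity check on the pointer's size field, assuming float32 (4-byte) weights throughout: roughly 38M parameters, consistent with a 12-layer, 512-hidden BERT plus its embeddings and MLM head.

```python
size_bytes = 152_012_980  # from the LFS pointer above
print(f"~{size_bytes / 4 / 1e6:.0f}M float32 parameters")  # ~38M
```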
gf-12L-95M-i4096/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d
- size 4920
gf-20L-95M-i4096/config.json DELETED
@@ -1,24 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.02,
-   "classifier_dropout": null,
-   "hidden_act": "relu",
-   "hidden_dropout_prob": 0.02,
-   "hidden_size": 896,
-   "initializer_range": 0.02,
-   "intermediate_size": 1792,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 4096,
-   "model_type": "bert",
-   "num_attention_heads": 14,
-   "num_hidden_layers": 20,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "torch_dtype": "float32",
-   "transformers_version": "4.37.1",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 20275
- }
gf-20L-95M-i4096/generation_config.json DELETED
@@ -1,5 +0,0 @@
- {
-   "_from_model_config": true,
-   "pad_token_id": 0,
-   "transformers_version": "4.37.1"
- }
gf-20L-95M-i4096/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf
- size 605292732
gf-20L-95M-i4096/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc
- size 5048
gf-20L-95M-i4096_config.json DELETED
@@ -1,24 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.02,
-   "classifier_dropout": null,
-   "hidden_act": "relu",
-   "hidden_dropout_prob": 0.02,
-   "hidden_size": 896,
-   "initializer_range": 0.02,
-   "intermediate_size": 1792,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 4096,
-   "model_type": "bert",
-   "num_attention_heads": 14,
-   "num_hidden_layers": 20,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "torch_dtype": "float32",
-   "transformers_version": "4.37.1",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 20275
- }
gf-20L-95M-i4096_generation_config.json DELETED
@@ -1,5 +0,0 @@
- {
-   "_from_model_config": true,
-   "pad_token_id": 0,
-   "transformers_version": "4.37.1"
- }
gf-6L-30M-i2048/config.json DELETED
@@ -1,23 +0,0 @@
- {
-   "architectures": [
-     "BertForMaskedLM"
-   ],
-   "attention_probs_dropout_prob": 0.02,
-   "gradient_checkpointing": false,
-   "hidden_act": "relu",
-   "hidden_dropout_prob": 0.02,
-   "hidden_size": 256,
-   "initializer_range": 0.02,
-   "intermediate_size": 512,
-   "layer_norm_eps": 1e-12,
-   "max_position_embeddings": 2048,
-   "model_type": "bert",
-   "num_attention_heads": 4,
-   "num_hidden_layers": 6,
-   "pad_token_id": 0,
-   "position_embedding_type": "absolute",
-   "transformers_version": "4.6.0",
-   "type_vocab_size": 2,
-   "use_cache": true,
-   "vocab_size": 25426
- }
gf-6L-30M-i2048/model.safetensors DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14
- size 41183536
gf-6L-30M-i2048/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8d860e2125884475dd42bc2cd9a0e60c60808a7351241e08f2154931ffc142da
- size 41216562
gf-6L-30M-i2048/training_args.bin DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b
- size 2607