---
library_name: setfit
tags:
- setfit
- sentence-transformers
- text-classification
- generated_from_setfit_trainer
metrics:
- accuracy
- precision
- recall
- f1
widget:
- text: "<p><a href=\"https://kwotsin.github.io/tech/2017/02/11/transfer-learning.html\"\
    \ rel=\"nofollow noreferrer\">https://kwotsin.github.io/tech/2017/02/11/transfer-learning.html</a>\n\
    I followed the above link to make a image classifier</p>\n\n<p>Training code:</p>\n\
    \n<pre><code>slim = tf.contrib.slim\n\ndataset_dir = './data'\nlog_dir = './log'\n\
    checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'\nimage_size = 299\n\
    num_classes = 21\nvlabels_file = './labels.txt'\nlabels = open(labels_file, 'r')\n\
    labels_to_name = {}\nfor line in labels:\n    label, string_name = line.split(':')\n\
    \    string_name = string_name[:-1]\n    labels_to_name[int(label)] = string_name\n\
    \nfile_pattern = 'test_%s_*.tfrecord'\n\nitems_to_descriptions = {\n    'image':\
    \ 'A 3-channel RGB coloured product image',\n    'label': 'A label that from 20\
    \ labels'\n}\n\nnum_epochs = 10\nbatch_size = 16\ninitial_learning_rate = 0.001\n\
    learning_rate_decay_factor = 0.7\nnum_epochs_before_decay = 4\n\ndef get_split(split_name,\
    \ dataset_dir, file_pattern=file_pattern, file_pattern_for_counting='products'):\n\
    \    if split_name not in ['train', 'validation']:\n        raise ValueError(\n\
    \            'The split_name %s is not recognized. Please input either train or\
    \ validation as the split_name' % (\n            split_name))\n\n    file_pattern_path\
    \ = os.path.join(dataset_dir, file_pattern % (split_name))\n\n    num_samples\
    \ = 0\n    file_pattern_for_counting = file_pattern_for_counting + '_' + split_name\n\
    \    tfrecords_to_count = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir)\
    \ if\n                          file.startswith(file_pattern_for_counting)]\n\
    \    for tfrecord_file in tfrecords_to_count:\n        for record in tf.python_io.tf_record_iterator(tfrecord_file):\n\
    \            num_samples += 1\n\n    test = num_samples\n\n    reader = tf.TFRecordReader\n\
    \n    keys_to_features = {\n        'image/encoded': tf.FixedLenFeature((), tf.string,\
    \ default_value=''),\n        'image/format': tf.FixedLenFeature((), tf.string,\
    \ default_value='jpg'),\n        'image/class/label': tf.FixedLenFeature(\n  \
    \          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),\n    }\n\
    \n    items_to_handlers = {\n        'image': slim.tfexample_decoder.Image(),\n\
    \        'label': slim.tfexample_decoder.Tensor('image/class/label'),\n    }\n\
    \n    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)\n\
    \n    labels_to_name_dict = labels_to_name\n\n    dataset = slim.dataset.Dataset(\n\
    \        data_sources=file_pattern_path,\n        decoder=decoder,\n        reader=reader,\n\
    \        num_readers=4,\n        num_samples=num_samples,\n        num_classes=num_classes,\n\
    \        labels_to_name=labels_to_name_dict,\n        items_to_descriptions=items_to_descriptions)\n\
    \n    return dataset\n\ndef load_batch(dataset, batch_size, height=image_size,\
    \ width=image_size, is_training=True):\n    '''\n    Loads a batch for training.\n\
    \n    INPUTS:\n    - dataset(Dataset): a Dataset class object that is created\
    \ from the get_split function\n    - batch_size(int): determines how big of a\
    \ batch to train\n    - height(int): the height of the image to resize to during\
    \ preprocessing\n    - width(int): the width of the image to resize to during\
    \ preprocessing\n    - is_training(bool): to determine whether to perform a training\
    \ or evaluation preprocessing\n\n    OUTPUTS:\n    - images(Tensor): a Tensor\
    \ of the shape (batch_size, height, width, channels) that contain one batch of\
    \ images\n    - labels(Tensor): the batch's labels with the shape (batch_size,)\
    \ (requires one_hot_encoding).\n\n    '''\n    # First create the data_provider\
    \ object\n    data_provider = slim.dataset_data_provider.DatasetDataProvider(\n\
    \        dataset,\n        common_queue_capacity=24 + 3 * batch_size,\n      \
    \  common_queue_min=24)\n\n    # Obtain the raw image using the get method\n \
    \   raw_image, label = data_provider.get(['image', 'label'])\n\n    # Perform\
    \ the correct preprocessing for this image depending if it is training or evaluating\n\
    \    image = inception_preprocessing.preprocess_image(raw_image, height, width,\
    \ is_training)\n\n    # As for the raw images, we just do a simple reshape to\
    \ batch it up\n    raw_image = tf.expand_dims(raw_image, 0)\n    raw_image = tf.image.resize_nearest_neighbor(raw_image,\
    \ [height, width])\n    raw_image = tf.squeeze(raw_image)\n\n    # Batch up the\
    \ image by enqueing the tensors internally in a FIFO queue and dequeueing many\
    \ elements with tf.train.batch.\n    images, raw_images, labels = tf.train.batch(\n\
    \        [image, raw_image, label],\n        batch_size=batch_size,\n        num_threads=4,\n\
    \        capacity=4 * batch_size,\n        allow_smaller_final_batch=True)\n\n\
    \    return images, raw_images, labels\n\n\ndef run():\n    # Create the log directory\
    \ here. Must be done here otherwise import will activate this unneededly.\n  \
    \  if not os.path.exists(log_dir):\n        os.mkdir(log_dir)\n\n    # =======================\
    \ TRAINING PROCESS =========================\n    # Now we start to construct\
    \ the graph and build our model\n    with tf.Graph().as_default() as graph:\n\
    \        tf.logging.set_verbosity(tf.logging.INFO)  # Set the verbosity to INFO\
    \ level\n\n        # First create the dataset and load one batch\n        dataset\
    \ = get_split('train', dataset_dir, file_pattern=file_pattern)\n        images,\
    \ _, labels = load_batch(dataset, batch_size=batch_size)\n\n        # Know the\
    \ number steps to take before decaying the learning rate and batches per epoch\n\
    \        num_batches_per_epoch = int(dataset.num_samples / batch_size)\n     \
    \   num_steps_per_epoch = num_batches_per_epoch  # Because one step is one batch\
    \ processed\n        decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)\n\
    \n        # Create the model inference\n        with slim.arg_scope(inception_resnet_v2_arg_scope()):\n\
    \            logits, end_points = inception_resnet_v2(images, num_classes=dataset.num_classes,\
    \ is_training=True)\n\n        # Define the scopes that you want to exclude for\
    \ restoration\n        exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']\n\
    \        variables_to_restore = slim.get_variables_to_restore(exclude=exclude)\n\
    \n        # Perform one-hot-encoding of the labels (Try one-hot-encoding within\
    \ the load_batch function!)\n        one_hot_labels = slim.one_hot_encoding(labels,\
    \ dataset.num_classes)\n\n        # Performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits\
    \ but enhanced with checks\n        loss = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels,\
    \ logits=logits)\n        total_loss = tf.losses.get_total_loss()  # obtain the\
    \ regularization losses as well\n\n        # Create the global step for monitoring\
    \ the learning_rate and training.\n        global_step = get_or_create_global_step()\n\
    \n        # Define your exponentially decaying learning rate\n        lr = tf.train.exponential_decay(\n\
    \            learning_rate=initial_learning_rate,\n            global_step=global_step,\n\
    \            decay_steps=decay_steps,\n            decay_rate=learning_rate_decay_factor,\n\
    \            staircase=True)\n\n        # Now we can define the optimizer that\
    \ takes on the learning rate\n        optimizer = tf.train.AdamOptimizer(learning_rate=lr)\n\
    \n        # Create the train_op.\n        train_op = slim.learning.create_train_op(total_loss,\
    \ optimizer)\n\n        # State the metrics that you want to predict. We get a\
    \ predictions that is not one_hot_encoded.\n        predictions = tf.argmax(end_points['Predictions'],\
    \ 1)\n        probabilities = end_points['Predictions']\n        accuracy, accuracy_update\
    \ = tf.contrib.metrics.streaming_accuracy(predictions, labels)\n        metrics_op\
    \ = tf.group(accuracy_update, probabilities)\n\n        # Now finally create all\
    \ the summaries you need to monitor and group them into one summary op.\n    \
    \    tf.summary.scalar('losses/Total_Loss', total_loss)\n        tf.summary.scalar('accuracy',\
    \ accuracy)\n        tf.summary.scalar('learning_rate', lr)\n        my_summary_op\
    \ = tf.summary.merge_all()\n\n        # Now we need to create a training step\
    \ function that runs both the train_op, metrics_op and updates the global_step\
    \ concurrently.\n        def train_step(sess, train_op, global_step):\n      \
    \      '''\n            Simply runs a session for the three arguments provided\
    \ and gives a logging on the time elapsed for each global step\n            '''\n\
    \            # Check the time for each sess run\n            start_time = time.time()\n\
    \            total_loss, global_step_count, _ = sess.run([train_op, global_step,\
    \ metrics_op])\n            time_elapsed = time.time() - start_time\n\n      \
    \      # Run the logging to print some results\n            logging.info('global\
    \ step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)\n\
    \n            return total_loss, global_step_count\n\n        # Now we create\
    \ a saver function that actually restores the variables from a checkpoint file\
    \ in a sess\n        saver = tf.train.Saver(variables_to_restore)\n\n        def\
    \ restore_fn(sess):\n            return saver.restore(sess, checkpoint_file)\n\
    \n        # Define your supervisor for running a managed session. Do not run the\
    \ summary_op automatically or else it will consume too much memory\n        sv\
    \ = tf.train.Supervisor(logdir=log_dir, summary_op=None, init_fn=restore_fn)\n\
    \n        # Run the managed session\n        with sv.managed_session() as sess:\n\
    \            for step in xrange(num_steps_per_epoch * num_epochs):\n         \
    \       # At the start of every epoch, show the vital information:\n         \
    \       if step % num_batches_per_epoch == 0:\n                    logging.info('Epoch\
    \ %s/%s', step / num_batches_per_epoch + 1, num_epochs)\n                    learning_rate_value,\
    \ accuracy_value = sess.run([lr, accuracy])\n                    logging.info('Current\
    \ Learning Rate: %s', learning_rate_value)\n                    logging.info('Current\
    \ Streaming Accuracy: %s', accuracy_value)\n\n                    # optionally,\
    \ print your logits and predictions for a sanity check that things are going fine.\n\
    \                    logits_value, probabilities_value, predictions_value, labels_value\
    \ = sess.run(\n                        [logits, probabilities, predictions, labels])\n\
    \                    print 'logits: \\n', logits_value\n                    print\
    \ 'Probabilities: \\n', probabilities_value\n                    print 'predictions:\
    \ \\n', predictions_value\n                    print 'Labels:\\n:', labels_value\n\
    \n                # Log the summaries every 10 step.\n                if step\
    \ % 10 == 0:\n                    loss, _ = train_step(sess, train_op, sv.global_step)\n\
    \                    summaries = sess.run(my_summary_op)\n                   \
    \ sv.summary_computed(sess, summaries)\n\n                # If not, simply run\
    \ the training step\n                else:\n                    loss, _ = train_step(sess,\
    \ train_op, sv.global_step)\n\n            # We log the final training loss and\
    \ accuracy\n            logging.info('Final Loss: %s', loss)\n            logging.info('Final\
    \ Accuracy: %s', sess.run(accuracy))\n\n            # Once all the training has\
    \ been done, save the log files and checkpoint model\n            logging.info('Finished\
    \ training! Saving model to disk now.')\n            sv.saver.save(sess, sv.save_path,\
    \ global_step=sv.global_step)\n</code></pre>\n\n<p>This code seems to work an\
    \ I have ran training on some sample data and Im getting 94% accuracy</p>\n\n\
    <p>Evaluation code:</p>\n\n<pre><code>log_dir = './log'\nlog_eval = './log_eval_test'\n\
    dataset_dir = './data'\nbatch_size = 10\nnum_epochs = 1\n\ncheckpoint_file = tf.train.latest_checkpoint('./')\n\
    \n\ndef run():\n    if not os.path.exists(log_eval):\n        os.mkdir(log_eval)\n\
    \    with tf.Graph().as_default() as graph:\n        tf.logging.set_verbosity(tf.logging.INFO)\n\
    \        dataset = get_split('train', dataset_dir)\n        images, raw_images,\
    \ labels = load_batch(dataset, batch_size=batch_size, is_training=False)\n\n \
    \       num_batches_per_epoch = dataset.num_samples / batch_size\n        num_steps_per_epoch\
    \ = num_batches_per_epoch\n\n        with slim.arg_scope(inception_resnet_v2_arg_scope()):\n\
    \            logits, end_points = inception_resnet_v2(images, num_classes=dataset.num_classes,\
    \ is_training=False)\n\n        variables_to_restore = slim.get_variables_to_restore()\n\
    \        saver = tf.train.Saver(variables_to_restore)\n\n        def restore_fn(sess):\n\
    \            return saver.restore(sess, checkpoint_file)\n\n        predictions\
    \ = tf.argmax(end_points['Predictions'], 1)\n        accuracy, accuracy_update\
    \ = tf.contrib.metrics.streaming_accuracy(predictions, labels)\n        metrics_op\
    \ = tf.group(accuracy_update)\n\n        global_step = get_or_create_global_step()\n\
    \        global_step_op = tf.assign(global_step, global_step + 1)\n\n        def\
    \ eval_step(sess, metrics_op, global_step):\n            '''\n            Simply\
    \ takes in a session, runs the metrics op and some logging information.\n    \
    \        '''\n            start_time = time.time()\n            _, global_step_count,\
    \ accuracy_value = sess.run([metrics_op, global_step_op, accuracy])\n        \
    \    time_elapsed = time.time() - start_time\n\n            logging.info('Global\
    \ Step %s: Streaming Accuracy: %.4f (%.2f sec/step)', global_step_count, accuracy_value,\n\
    \                         time_elapsed)\n\n            return accuracy_value\n\
    \n        tf.summary.scalar('Validation_Accuracy', accuracy)\n        my_summary_op\
    \ = tf.summary.merge_all()\n\n        sv = tf.train.Supervisor(logdir=log_eval,\
    \ summary_op=None, saver=None, init_fn=restore_fn)\n\n        with sv.managed_session()\
    \ as sess:\n            for step in xrange(num_steps_per_epoch * num_epochs):\n\
    \                sess.run(sv.global_step)\n                if step % num_batches_per_epoch\
    \ == 0:\n                    logging.info('Epoch: %s/%s', step / num_batches_per_epoch\
    \ + 1, num_epochs)\n                    logging.info('Current Streaming Accuracy:\
    \ %.4f', sess.run(accuracy))\n\n                if step % 10 == 0:\n         \
    \           eval_step(sess, metrics_op=metrics_op, global_step=sv.global_step)\n\
    \                    summaries = sess.run(my_summary_op)\n                   \
    \ sv.summary_computed(sess, summaries)\n\n\n                else:\n          \
    \          eval_step(sess, metrics_op=metrics_op, global_step=sv.global_step)\n\
    \n            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))\n\
    \n            raw_images, labels, predictions = sess.run([raw_images, labels,\
    \ predictions])\n            for i in range(10):\n                image, label,\
    \ prediction = raw_images[i], labels[i], predictions[i]\n                prediction_name,\
    \ label_name = dataset.labels_to_name[prediction], dataset.labels_to_name[label]\n\
    \                text = 'Prediction: %s \\n Ground Truth: %s' % (prediction_name,\
    \ label_name)\n                img_plot = plt.imshow(image)\n\n              \
    \  plt.title(text)\n                img_plot.axes.get_yaxis().set_ticks([])\n\
    \                img_plot.axes.get_xaxis().set_ticks([])\n                plt.show()\n\
    \n            logging.info(\n                'Model evaluation has completed!\
    \ Visit TensorBoard for more information regarding your evaluation.')\n</code></pre>\n\
    \n<p>So after training the model and getting 94% accuracy i tried to evaluate\
    \ the model. On evaluation I get 0-1% accuracy the whole time. I investigated\
    \ this only to find that it is predicting the same class every time</p>\n\n<pre><code>labels:\
    \ [7, 11, 5, 1, 20, 0, 18, 1, 0, 7]\npredictions: [10, 10, 10, 10, 10, 10, 10,\
    \ 10, 10, 10]\n</code></pre>\n\n<p>Can anyone help in where i may be going wrong?</p>\n\
    \n<p>EDIT:</p>\n\n<p>TensorBoard accuracy and loss form training</p>\n\n<p><a\
    \ href=\"https://i.stack.imgur.com/NLiwC.png\" rel=\"nofollow noreferrer\"><img\
    \ src=\"https://i.stack.imgur.com/NLiwC.png\" alt=\"enter image description here\"\
    ></a>\n<a href=\"https://i.stack.imgur.com/QdX6d.png\" rel=\"nofollow noreferrer\"\
    ><img src=\"https://i.stack.imgur.com/QdX6d.png\" alt=\"enter image description\
    \ here\"></a></p>\n\n<p>TensorBoard accuracy from evaluation</p>\n\n<p><a href=\"\
    https://i.stack.imgur.com/TNE5B.png\" rel=\"nofollow noreferrer\"><img src=\"\
    https://i.stack.imgur.com/TNE5B.png\" alt=\"enter image description here\"></a></p>\n\
    \n<p>EDIT:</p>\n\n<p>Ive still not been able to solve this issues. I thought there\
    \ might be a problem with how I am restoring the graph in the eval script so I\
    \ tried using this to restore the model instead</p>\n\n<pre><code>saver = tf.train.import_meta_graph('/log/model.ckpt.meta')\n\
    \ndef restore_fn(sess):\n    return saver.restore(sess, checkpoint_file)\n</code></pre>\n\
    \n<p>instead of</p>\n\n<pre><code>variables_to_restore = slim.get_variables_to_restore()\n\
    \    saver = tf.train.Saver(variables_to_restore)\n\ndef restore_fn(sess):\n \
    \   return saver.restore(sess, checkpoint_file)\n</code></pre>\n\n<p>and just\
    \ just takes a very long time to start and finally errors. I then tried using\
    \ V1 of the writer in the saver (<code>saver = tf.train.Saver(variables_to_restore,\
    \ write_version=saver_pb2.SaveDef.V1)</code>) and retrained and was unable to\
    \ load this checkpoint at all as it said variables was missing.</p>\n\n<p>I also\
    \ attempted to run my eval script with the same data it trained on just to see\
    \ if this may give different results yet I get the same. </p>\n\n<p>Finally I\
    \ re-cloned the repo from the url and ran a train using the same dataset in the\
    \ tutorial and I get 0-3% accuracy when I evaluate even after getting it to 84%\
    \ whilst training. Also my checkpoints must have the correct information as when\
    \ I restart training the accuracy continues from where it left of. It feels like\
    \ i'm not doing something correctly when I restore the model. Would really appreciate\
    \ any suggestions on this as im at a dead end currently :( </p>\n"
- text: '<p>I''ve just started using tensorflow for a project I''m working on. The
    program aims to be a binary classifier with input being 12 features. The output
    is either normal patient or patient with a disease. The prevalence of the disease
    is quite low and so my dataset is very imbalanced, with 502 examples of normal
    controls and only 38 diseased patients. For this reason, I''m trying to use <code>tf.nn.weighted_cross_entropy_with_logits</code>
    as my cost function.</p>


    <p>The code is based on the iris custom estimator from the official tensorflow
    documentation, and works with <code>tf.losses.sparse_softmax_cross_entropy</code>
    as the cost function. However, when I change to <code>weighted_cross_entropy_with_logits</code>,
    I get a shape error and I''m not sure how to fix this.</p>


    <pre><code>ValueError: logits and targets must have the same shape ((?, 2) vs
    (?,))

    </code></pre>


    <p>I have searched and similar problems have been solved by just reshaping the
    labels - I have tried to do this unsuccessfully (and don''t understand why <code>tf.losses.sparse_softmax_cross_entropy</code>
    works fine and the weighted version does not). </p>


    <p>My full code is here

    <a href="https://gist.github.com/revacious/83142573700c17b8d26a4a1b84b0dff7" rel="nofollow
    noreferrer">https://gist.github.com/revacious/83142573700c17b8d26a4a1b84b0dff7</a></p>


    <p>Thanks!</p>

    '
- text: '<p>In the documentation it seems they focus on how to save and restore tf.keras.models,
    but i was wondering how do you save and restore models trained customly through
    some basic iteration loop?</p>


    <p>Now that there isnt a graph or a session, how do we save structure defined
    in a tf function that is customly built without using layer abstractions?</p>

    '
- text: "<p>I simply have <code>train = optimizer.minimize(loss = tf.constant(4,dtype=\"\
    float32\"))</code> Line of code that i change before everything is working. <br/></p>\n\
    \n<p>Why it is giving error ? Because documentation say it can be tensor <a href=\"\
    https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam#minimize\"\
    \ rel=\"nofollow noreferrer\">Here is Docs</a> </p>\n\n<pre><code>W = tf.Variable([0.5],tf.float32)\n\
    b = tf.Variable([0.1],tf.float32)\nx = tf.placeholder(tf.float32)\ny= tf.placeholder(tf.float32)\n\
    discounted_reward = tf.placeholder(tf.float32,shape=[4,], name=\"discounted_reward\"\
    )\nlinear_model = W*x + b\n\nsquared_delta = tf.square(linear_model - y)\nprint(squared_delta)\n\
    loss = tf.reduce_sum(squared_delta*discounted_reward)\nprint(loss)\noptimizer\
    \ = tf.train.GradientDescentOptimizer(0.01)\ntrain = optimizer.minimize(loss =\
    \ tf.constant(4,dtype=\"float32\"))\ninit = tf.global_variables_initializer()\n\
    sess = tf.Session()\n\nsess.run(init)\n\nfor i in range(3):\n    sess.run(train,{x:[1,2,3,4],y:[0,-1,-2,-3],discounted_reward:[1,2,3,4]})\n\
    \nprint(sess.run([W,b]))\n</code></pre>\n\n<hr>\n\n<p>I really need this thing\
    \ to work. In this particular example we can have other ways to solve it but i\
    \ need it to work as my actual code can do this only </p>\n\n<p><hr/> Error is</p>\n\
    \n<pre><code>&gt; ValueError: No gradients provided for any variable, check your\
    \ graph\n&gt; for ops that do not support gradients, between variables\n&gt; [\"\
    &lt;tf.Variable 'Variable:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable\
    \ 'Variable_1:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_2:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_3:0' shape=(1,)\
    \ dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_4:0' shape=(1,) dtype=float32_ref&gt;\"\
    ,\n&gt; \"&lt;tf.Variable 'Variable_5:0' shape=(1,) dtype=float32_ref&gt;\",\n\
    &gt; \"&lt;tf.Variable 'Variable_6:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt;\
    \ \"&lt;tf.Variable 'Variable_7:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"\
    &lt;tf.Variable 'Variable_8:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable\
    \ 'Variable_9:0' shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_10:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_11:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_12:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_13:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_14:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_15:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_16:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_17:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_18:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_19:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_20:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_21:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_22:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_23:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_24:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_25:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_26:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_27:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_28:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_29:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_30:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_31:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_32:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_33:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_34:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_35:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_36:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_37:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_38:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_39:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_40:0'\
    \ shape=(1,) dtype=float32_ref&gt;\",\n&gt; \"&lt;tf.Variable 'Variable_41:0'\
    \ shape=(1,) dtype=float32_ref&gt;\"] and loss\n&gt; Tensor(\"Const_4:0\", shape=(),\
    \ dtype=float32).\n</code></pre>\n"
- text: "<p>I found in the <a href=\"https://www.tensorflow.org/tutorials/recurrent\"\
    \ rel=\"nofollow noreferrer\">tensorflow doc</a>:</p>\n\n<p><code>\nstacked_lstm\
    \ = tf.contrib.rnn.MultiRNNCell([lstm] * number_of_layers,\n                ...\n\
    </code></p>\n\n<p>I need to use MultiRNNCell</p>\n\n<p>but, I write those lines</p>\n\
    \n<p><code>\na = [tf.nn.rnn_cell.BasicLSTMCell(10)]*3\nprint id(a[0]), id(a[1])\n\
    </code></p>\n\n<p>Its output is <code>[4648063696 4648063696]</code>.</p>\n\n\
    <p>Can <code>MultiRNNCell</code> use the same object <code>BasicLSTMCell</code>\
    \ as a list for parameter?</p>\n"
pipeline_tag: text-classification
inference: true
base_model: sentence-transformers/all-mpnet-base-v2
model-index:
- name: SetFit with sentence-transformers/all-mpnet-base-v2
  results:
  - task:
      type: text-classification
      name: Text Classification
    dataset:
      name: Unknown
      type: unknown
      split: test
    metrics:
    - type: accuracy
      value: 0.85
      name: Accuracy
    - type: precision
      value: 0.8535353535353536
      name: Precision
    - type: recall
      value: 0.85
      name: Recall
    - type: f1
      value: 0.8496240601503761
      name: F1
---

# SetFit with sentence-transformers/all-mpnet-base-v2

This is a [SetFit](https://github.com/huggingface/setfit) model that can be used for Text Classification. This SetFit model uses [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) as the Sentence Transformer embedding model. A [LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) instance is used for classification.
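
For reference, here is a minimal inference sketch. It assumes a recent `setfit` release; the model id is a placeholder for wherever this checkpoint is published (substitute the actual repository name or a local path), and the input text is only an example.

```python
from setfit import SetFitModel

# Load the trained SetFit model (placeholder id; replace with the real repo or local path)
model = SetFitModel.from_pretrained("your-username/setfit-model")

# Run inference on raw text: the Sentence Transformer body encodes the input and the
# LogisticRegression head returns the predicted class
preds = model.predict([
    "How do I restore a tf.estimator checkpoint for evaluation?",
])
print(preds)
```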

The model has been trained using an efficient few-shot learning technique that involves:

1. Fine-tuning a [Sentence Transformer](https://www.sbert.net) with contrastive learning.
2. Training a classification head with features from the fine-tuned Sentence Transformer.
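
As an illustration only, the two steps above roughly correspond to the following sketch, assuming `setfit>=1.0` and a small hypothetical labelled dataset; the actual training configuration used for this checkpoint may differ.

```python
from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Hypothetical few-shot dataset: a handful of labelled texts per class
train_dataset = Dataset.from_dict({
    "text": ["example question about TensorFlow ...", "another example ..."],
    "label": [1, 0],
})

# Start from the pretrained Sentence Transformer body (step 1's starting point)
model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

args = TrainingArguments(batch_size=16, num_epochs=1)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# Contrastive fine-tuning of the embedding model (step 1), then
# fitting the LogisticRegression head on the resulting embeddings (step 2)
trainer.train()
```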

## Model Details

### Model Description
- **Model Type:** SetFit
- **Sentence Transformer body:** [sentence-transformers/all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
- **Classification head:** a [LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) instance
- **Maximum Sequence Length:** 384 tokens
- **Number of Classes:** 2 classes
<!-- - **Training Dataset:** [Unknown](https://huggingface.co/datasets/unknown) -->
<!-- - **Language:** Unknown -->
<!-- - **License:** Unknown -->

### Model Sources

- **Repository:** [SetFit on GitHub](https://github.com/huggingface/setfit)
- **Paper:** [Efficient Few-Shot Learning Without Prompts](https://arxiv.org/abs/2209.11055)
- **Blogpost:** [SetFit: Efficient Few-Shot Learning Without Prompts](https://huggingface.co/blog/setfit)

### Model Labels
| Label | Examples |
|:------|:---------|
| 1     | <ul><li>'<p>I\'m looking to use Tensorflow to train a neural network model for classification, and I want to read data from a CSV file, such as the Iris data set.</p>\n\n<p>The <a href="https://www.tensorflow.org/versions/r0.10/tutorials/tflearn/index.html#tf-contrib-learn-quickstart" rel="nofollow noreferrer">Tensorflow documentation</a> shows an example of loading the Iris data and building a prediction model, but the example uses the high-level <code>tf.contrib.learn</code> API. I want to use the low-level Tensorflow API and run gradient descent myself. How would I do that?</p>\n'</li><li>'<p>In the following code, I want dense matrix <code>B</code> to left multiply a sparse matrix <code>A</code>, but I got errors.</p>\n\n<pre><code>import tensorflow as tf\nimport numpy as np\n\nA = tf.sparse_placeholder(tf.float32)\nB = tf.placeholder(tf.float32, shape=(5,5))\nC = tf.matmul(B,A,a_is_sparse=False,b_is_sparse=True)\nsess = tf.InteractiveSession()\nindices = np.array([[3, 2], [1, 2]], dtype=np.int64)\nvalues = np.array([1.0, 2.0], dtype=np.float32)\nshape = np.array([5,5], dtype=np.int64)\nSparse_A = tf.SparseTensorValue(indices, values, shape)\nRandB = np.ones((5, 5))\nprint sess.run(C, feed_dict={A: Sparse_A, B: RandB})\n</code></pre>\n\n<p>The error message is as follows:</p>\n\n<pre><code>TypeError: Failed to convert object of type &lt;class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'&gt; \nto Tensor. Contents: SparseTensor(indices=Tensor("Placeholder_4:0", shape=(?, ?), dtype=int64), values=Tensor("Placeholder_3:0", shape=(?,), dtype=float32), dense_shape=Tensor("Placeholder_2:0", shape=(?,), dtype=int64)). \nConsider casting elements to a supported type.\n</code></pre>\n\n<p>What\'s wrong with my code?</p>\n\n<p>I\'m doing this following the <a href="https://www.tensorflow.org/api_docs/python/tf/matmul" rel="nofollow noreferrer">documentation</a> and it says we should use <code>a_is_sparse</code> to denote whether the first matrix is sparse, and similarly with <code>b_is_sparse</code>. 
Why is my code wrong?</p>\n\n<p>As is suggested by vijay, I should use <code>C = tf.matmul(B,tf.sparse_tensor_to_dense(A),a_is_sparse=False,b_is_sparse=True)</code></p>\n\n<p>I tried this but I met with another error saying:</p>\n\n<pre><code>Caused by op u\'SparseToDense\', defined at:\n  File "a.py", line 19, in &lt;module&gt;\n    C = tf.matmul(B,tf.sparse_tensor_to_dense(A),a_is_sparse=False,b_is_sparse=True)\n  File "/home/fengchao.pfc/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/sparse_ops.py", line 845, in sparse_tensor_to_dense\n    name=name)\n  File "/home/mypath/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/sparse_ops.py", line 710, in sparse_to_dense\n    name=name)\n  File "/home/mypath/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_sparse_ops.py", line 1094, in _sparse_to_dense\n    validate_indices=validate_indices, name=name)\n  File "/home/mypath/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op\n    op_def=op_def)\n  File "/home/mypath/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2506, in create_op\n    original_op=self._default_original_op, op_def=op_def)\n  File "/home/mypath/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1269, in __init__\n    self._traceback = _extract_stack()\n\nInvalidArgumentError (see above for traceback): indices[1] = [1,2] is out of order\n[[Node: SparseToDense = SparseToDense[T=DT_FLOAT, Tindices=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_Placeholder_4_0_2, _arg_Placeholder_2_0_0, _arg_Placeholder_3_0_1, SparseToDense/default_value)]]\n</code></pre>\n\n<p>Thank you all for helping me!</p>\n'</li><li>"<p>I am using <code>tf.estimator.train_and_evaluate</code> and <code>tf.data.Dataset</code> to feed data to the estimator:</p>\n\n<p>Input Data function:</p>\n\n<pre><code>    def data_fn(data_dict, batch_size, mode, num_epochs=10):\n        dataset = {}\n        if mode == tf.estimator.ModeKeys.TRAIN:\n            dataset = tf.data.Dataset.from_tensor_slices(data_dict['train_data'].astype(np.float32))\n            dataset = dataset.cache()\n            dataset = dataset.shuffle(buffer_size= batch_size * 10).repeat(num_epochs).batch(batch_size)\n        else:\n            dataset = tf.data.Dataset.from_tensor_slices(data_dict['valid_data'].astype(np.float32))\n            dataset = dataset.cache()\n            dataset = dataset.batch(batch_size)\n\n        iterator = dataset.make_one_shot_iterator()\n        next_element = iterator.get_next()\n\n    return next_element\n</code></pre>\n\n<p>Train Function:</p>\n\n<pre><code>def train_model(data):\n    tf.logging.set_verbosity(tf.logging.INFO)\n    config = tf.ConfigProto(allow_soft_placement=True,\n                            log_device_placement=False)\n    config.gpu_options.allow_growth = True\n    run_config = tf.contrib.learn.RunConfig(\n        save_checkpoints_steps=10,\n        keep_checkpoint_max=10,\n        session_config=config\n    )\n\n    train_input = lambda: data_fn(data, 100, tf.estimator.ModeKeys.TRAIN, num_epochs=1)\n    eval_input = lambda: data_fn(data, 1000, tf.estimator.ModeKeys.EVAL)\n    estimator = tf.estimator.Estimator(model_fn=model_fn, params=hps, config=run_config)\n    train_spec = tf.estimator.TrainSpec(train_input, max_steps=100)\n    eval_spec = tf.estimator.EvalSpec(eval_input,\n                                      steps=None,\n                       
               throttle_secs = 30)\n\n    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)\n</code></pre>\n\n<p>The training goes fine, but when it comes to evaluation I get this error:</p>\n\n<pre><code>OutOfRangeError (see above for traceback): End of sequence \n</code></pre>\n\n<p>If I don't use <code>Dataset.batch</code> on evaluation dataset (by omitting the line <code>dataset[name] = dataset[name].batch(batch_size)</code> in <code>data_fn</code>) I get the same error but after a much longer time.</p>\n\n<p>I can only avoid this error if I don't batch the data and use <code>steps=1</code> for evaluation, but does that perform the evaluation on the whole dataset?</p>\n\n<p>I don't understand what causes this error as the documentation suggests I should be able to evaluate on batches too.</p>\n\n<p>Note: I get the same error when using <code>tf.estimator.evaluate</code> on data batches.</p>\n"</li></ul>                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
                                                                                                                                                                              |
| 0     | <ul><li>'<p>I\'m working on a project where I have trained a series of binary classifiers with <strong>Keras</strong>, with <strong>Tensorflow</strong> as the backend engine. The input data I have is a series of images, where each binary classifier must make the prediction on the images, later I save the predictions on a CSV file.</p>\n<p>The problem I have is when I get the predictions from the first series of binary classifiers there isn\'t any warning, but when the 5th or 6th binary classifier calls the method <strong>predict</strong> on the input data I get the following warning:</p>\n<blockquote>\n<p>WARNING:tensorflow:5 out of the last 5 calls to &lt;function\nModel.make_predict_function..predict_function at\n0x2b280ff5c158&gt; triggered tf.function retracing. Tracing is expensive\nand the excessive number of tracings could be due to (1) creating\[email protected] repeatedly in a loop, (2) passing tensors with different\nshapes, (3) passing Python objects instead of tensors. For (1), please\ndefine your @tf.function outside of the loop. For (2), @tf.function\nhas experimental_relax_shapes=True option that relaxes argument shapes\nthat can avoid unnecessary retracing. For (3), please refer to\n<a href="https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args" rel="noreferrer">https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args</a>\nand <a href="https://www.tensorflow.org/api_docs/python/tf/function" rel="noreferrer">https://www.tensorflow.org/api_docs/python/tf/function</a> for  more\ndetails.</p>\n</blockquote>\n<p>To answer each point in the parenthesis, here are my answers:</p>\n<ol>\n<li>The <strong>predict</strong> method is called inside a for loop.</li>\n<li>I don\'t pass tensors but a list of <strong>NumPy arrays</strong> of gray scale images, all of them with the same size in width and height. The only thing that can change is the batch size because the list can have only 1 image or more than one.</li>\n<li>As I wrote in point 2, I pass a list of NumPy arrays.</li>\n</ol>\n<p>I have debugged my program and found that this warning always happens when the method predict is called. To summarize the code I have written is the following:</p>\n<pre><code>import cv2 as cv\nimport tensorflow as tf\nfrom tensorflow.keras.models import load_model\n# Load the models\nbinary_classifiers = [load_model(path) for path in path2models]\n# Get the images\nimages = [#Load the images with OpenCV]\n# Apply the resizing and reshapes on the images.\nmy_list = list()\nfor image in images:\n    image_reworked = # Apply the resizing and reshaping on images\n    my_list.append(image_reworked)\n\n# Get the prediction from each model\n# This is where I get the warning\npredictions = [model.predict(x=my_list,verbose=0) for model in binary_classifiers]\n</code></pre>\n<h3>What I have tried</h3>\n<p>I have defined a function as tf.function and putted the code of the predictions inside the tf.function like this</p>\n<pre><code>@tf.function\ndef testing(models, faces):\n    return [model.predict(x=faces,verbose=0) for model in models]\n</code></pre>\n<p>But I ended up getting the following error:</p>\n<blockquote>\n<p>RuntimeError: Detected a call to <code>Model.predict</code> inside a\n<code>tf.function</code>. Model.predict is a high-level endpoint that manages\nits own <code>tf.function</code>. Please move the call to <code>Model.predict</code> outside\nof all enclosing <code>tf.function</code>s. 
Note that you can call a <code>Model</code>\ndirectly on Tensors inside a <code>tf.function</code> like: <code>model(x)</code>.</p>\n</blockquote>\n<p>So calling the method <code>predict</code> is basically already a tf.function. So it\'s useless to define a tf.function when the warning I get it\'s from that method.</p>\n<p>I have also checked those other two questions:</p>\n<ol>\n<li><a href="https://stackoverflow.com/questions/61647404/tensorflow-2-getting-warningtensorflow9-out-of-the-last-9-calls-to-function">Tensorflow 2: Getting &quot;WARNING:tensorflow:9 out of the last 9 calls to  triggered tf.function retracing. Tracing is expensive&quot;</a></li>\n<li><a href="https://stackoverflow.com/questions/65563185/loading-multiple-saved-tensorflow-keras-models-for-prediction">Loading multiple saved tensorflow/keras models for prediction</a></li>\n</ol>\n<p>But neither of the two questions answers my question about how to avoid this warning. Plus I have also checked the links in the warning message but I couldn\'t solve my problem.</p>\n<h3>What I want</h3>\n<p>I simply want to avoid this warning. While I\'m still getting the predictions from the models I noticed that the python program takes way too much time on doing predictions for a list of images.</p>\n<h3>What I\'m using</h3>\n<ul>\n<li>Python 3.6.13</li>\n<li>Tensorflow 2.3.0</li>\n</ul>\n<h3>Solution</h3>\n<p>After some tries to suppress the warning from the <code>predict</code> method, I have checked the documentation of Tensorflow and in one of the first tutorials on how to use Tensorflow it is explained that, by default, Tensorflow is executed in eager mode, which is useful for testing and debugging the network models. Since I have already tested my models many times, it was only required to disable the eager mode by writing this single python line of code:</p>\n<p><code>tf.compat.v1.disable_eager_execution()</code></p>\n<p>Now the warning doesn\'t show up anymore.</p>\n'</li><li>'<p>I try to export a Tensorflow model but I can not find the best way to add the exogenous feature to the <code>tf.contrib.timeseries.StructuralEnsembleRegressor.build_raw_serving_input_receiver_fn</code>. 
</p>\n\n<p>I use the sample from the Tensorflow contrib: <a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/timeseries/examples/known_anomaly.py" rel="nofollow noreferrer">https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/timeseries/examples/known_anomaly.py</a> and I just try to save the model.</p>\n\n<pre><code># this is the exogenous column \nstring_feature = tf.contrib.layers.sparse_column_with_keys(\n      column_name="is_changepoint", keys=["no", "yes"])\n\none_hot_feature = tf.contrib.layers.one_hot_column(\n      sparse_id_column=string_feature)\n\nestimator = tf.contrib.timeseries.StructuralEnsembleRegressor(\n      periodicities=12,    \n      cycle_num_latent_values=3,\n      num_features=1,\n      exogenous_feature_columns=[one_hot_feature],\n      exogenous_update_condition=\n      lambda times, features: tf.equal(features["is_changepoint"], "yes"))\n\nreader = tf.contrib.timeseries.CSVReader(\n      csv_file_name,\n\n      column_names=(tf.contrib.timeseries.TrainEvalFeatures.TIMES,\n                    tf.contrib.timeseries.TrainEvalFeatures.VALUES,\n                    "is_changepoint"),\n\n      column_dtypes=(tf.int64, tf.float32, tf.string),\n\n      skip_header_lines=1)\n\ntrain_input_fn = tf.contrib.timeseries.RandomWindowInputFn(reader, batch_size=4, window_size=64)\nestimator.train(input_fn=train_input_fn, steps=train_steps)\nevaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)\nevaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)\n\nexport_directory = tempfile.mkdtemp()\n\n###################################################### \n# the exogenous column must be provided to the build_raw_serving_input_receiver_fn. \n# But How ?\n######################################################\n\ninput_receiver_fn = estimator.build_raw_serving_input_receiver_fn()\n# -&gt; error missing \'is_changepoint\' key    \n\n#input_receiver_fn = estimator.build_raw_serving_input_receiver_fn({\'is_changepoint\' : string_feature}) \n# -&gt; cast exception\n\nexport_location = estimator.export_savedmodel(export_directory, input_receiver_fn)\n</code></pre>\n\n<p>According to the <a href="https://www.tensorflow.org/api_docs/python/tf/contrib/timeseries/StructuralEnsembleRegressor" rel="nofollow noreferrer">documentation</a>, build_raw_serving_input_receiver_fn <strong>exogenous_features</strong> parameter : <em>A dictionary mapping feature keys to exogenous features (either Numpy arrays or Tensors). Used to determine the shapes of placeholders for these features</em>.</p>\n\n<p>So what is the best way to transform the <em>one_hot_column</em> or <em>sparse_column_with_keys</em> to a <em>Tensor</em> object ?</p>\n'</li><li>"<p>I am currently working on an optical flow project and I come across a strange error. </p>\n\n<p>I have uint16 images stored in bytes in my TFrecords. When I read the TFrecords from my local machine it is giving me uint16 values, but when I deploy the same code and read it from the docker I am getting uint8 values eventhough my dtype is uint16. 
I mean the uint16 values are getting reduced to uint8 like 32768 --> 128.</p>\n\n<p>What is causing this error?</p>\n\n<p>My local machine has: Tensorflow 1.10.1 and python 3.6\nMy Docker Image has: Tensorflow 1.12.0 and python 3.5</p>\n\n<p>I am working on tensorflow object detection API\nWhile creating the TF records I use:</p>\n\n<pre><code>with tf.gfile.GFile(flows, 'rb') as fid:\n    flow_images = fid.read()\n</code></pre>\n\n<p>While reading it back I am using: tf.image.decoderaw</p>\n\n<p>Dataset: KITTI FLOW 2015</p>\n"</li></ul> |

## Evaluation

### Metrics
| Label   | Accuracy | Precision | Recall | F1     |
|:--------|:---------|:----------|:-------|:-------|
| **all** | 0.85     | 0.8535    | 0.85   | 0.8496 |
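
These aggregate scores can be reproduced on your own labelled test split. Below is a minimal sketch using scikit-learn; `test_texts` and `test_labels` are placeholders for your data, and the `"macro"` averaging mode is an assumption rather than the exact setting used for the table above.

```python
# Minimal sketch of recomputing the metrics above on a held-out split.
# `test_texts` / `test_labels` are placeholders; "macro" averaging is an assumption.
from setfit import SetFitModel
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

model = SetFitModel.from_pretrained("sharukat/sbert-questionclassifier")

test_texts = ["<p>How do I checkpoint a custom training loop?</p>"]  # your posts
test_labels = [1]                                                    # your gold labels

preds = model.predict(test_texts)
accuracy = accuracy_score(test_labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    test_labels, preds, average="macro", zero_division=0
)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```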

## Uses

### Direct Use for Inference

First install the SetFit library:

```bash
pip install setfit
```

Then you can load this model and run inference.

```python
from setfit import SetFitModel

# Download from the 🤗 Hub
model = SetFitModel.from_pretrained("sharukat/sbert-questionclassifier")
# Run inference
preds = model("<p>In the documentation it seems they focus on how to save and restore tf.keras.models, but i was wondering how do you save and restore models trained customly through some basic iteration loop?</p>

<p>Now that there isnt a graph or a session, how do we save structure defined in a tf function that is customly built without using layer abstractions?</p>
")
```
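
Here `preds` holds the predicted label id (0 or 1 for this model). Continuing from the snippet above, you can also pass a list of posts for batch inference; the example posts below are purely illustrative.

```python
# Batch inference: one predicted label id (0 or 1) per input post.
posts = [
    "<p>How do I save a model trained with a custom training loop?</p>",
    "<p>Why does tf.function warn about excessive retracing?</p>",
]
preds = model(posts)  # exact label values depend on the inputs
```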

<!--
### Downstream Use

*List how someone could finetune this model on their own dataset.*
-->

<!--
### Out-of-Scope Use

*List how the model may foreseeably be misused and address what users ought not to do with the model.*
-->

<!--
## Bias, Risks and Limitations

*What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
-->

<!--
### Recommendations

*What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
-->

## Training Details

### Training Set Metrics
| Training set | Min | Median   | Max  |
|:-------------|:----|:---------|:-----|
| Word count   | 15  | 330.0667 | 3755 |

| Label | Training Sample Count |
|:------|:----------------------|
| 0     | 450                   |
| 1     | 450                   |
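
A balanced per-label split like the one above (450 examples per label) could be drawn with SetFit's `sample_dataset` helper. The sketch below uses a tiny placeholder dataset and a smaller `num_samples` so it runs as-is; 450 is the value reflected in the table.

```python
# Sketch: drawing a balanced per-label sample with setfit.sample_dataset.
# `full_train_dataset` is a placeholder; this model's split used 450 examples per label.
from datasets import Dataset
from setfit import sample_dataset

full_train_dataset = Dataset.from_dict({
    "text": ["<p>post a</p>", "<p>post b</p>", "<p>post c</p>", "<p>post d</p>"],
    "label": [0, 0, 1, 1],
})

train_dataset = sample_dataset(full_train_dataset, label_column="label", num_samples=2)
```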

### Training Hyperparameters
- batch_size: (16, 2)
- num_epochs: (1, 16)
- max_steps: -1
- sampling_strategy: unique
- body_learning_rate: (2e-05, 1e-05)
- head_learning_rate: 0.01
- loss: CosineSimilarityLoss
- distance_metric: cosine_distance
- margin: 0.25
- end_to_end: False
- use_amp: False
- warmup_proportion: 0.1
- max_length: 256
- seed: 42
- eval_max_steps: -1
- load_best_model_at_end: True
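
These values map onto `setfit.TrainingArguments`. Below is a minimal sketch of a comparable training run with SetFit 1.0.x: the base encoder and the toy datasets are placeholders rather than the ones actually used, and only a subset of the hyperparameters listed above is set explicitly.

```python
# Sketch of a comparable SetFit 1.0.x training run; the base encoder and the
# toy datasets are placeholders, and only some of the hyperparameters above are set.
from datasets import Dataset
from setfit import SetFitModel, Trainer, TrainingArguments

train_dataset = Dataset.from_dict({
    "text": [
        "<p>example API question</p>",
        "<p>another API question</p>",
        "<p>example other post</p>",
        "<p>another other post</p>",
    ],
    "label": [1, 1, 0, 0],
})
eval_dataset = train_dataset

model = SetFitModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

args = TrainingArguments(
    batch_size=(16, 2),
    num_epochs=(1, 16),
    sampling_strategy="unique",
    body_learning_rate=(2e-05, 1e-05),
    head_learning_rate=0.01,
    warmup_proportion=0.1,
    max_length=256,
    seed=42,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
print(trainer.evaluate())
```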

### Training Results
| Epoch   | Step      | Training Loss | Validation Loss |
|:-------:|:---------:|:-------------:|:---------------:|
| 0.0000  | 1         | 0.2951        | -               |
| **1.0** | **25341** | **0.0**       | **0.2473**      |

* The bold row denotes the saved checkpoint.

### Framework Versions
- Python: 3.10.13
- SetFit: 1.0.3
- Sentence Transformers: 2.5.0
- Transformers: 4.38.1
- PyTorch: 2.1.2
- Datasets: 2.17.1
- Tokenizers: 0.15.2

## Citation

### BibTeX
```bibtex
@article{https://doi.org/10.48550/arxiv.2209.11055,
    doi = {10.48550/ARXIV.2209.11055},
    url = {https://arxiv.org/abs/2209.11055},
    author = {Tunstall, Lewis and Reimers, Nils and Jo, Unso Eun Seo and Bates, Luke and Korat, Daniel and Wasserblat, Moshe and Pereg, Oren},
    keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
    title = {Efficient Few-Shot Learning Without Prompts},
    publisher = {arXiv},
    year = {2022},
    copyright = {Creative Commons Attribution 4.0 International}
}
```

<!--
## Glossary

*Clearly define terms in order to be accessible across audiences.*
-->

<!--
## Model Card Authors

*Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
-->

<!--
## Model Card Contact

*Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
-->