|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for common.py.""" |
|
import copy |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab import common |
|
|
|
|
|
class CommonTest(tf.test.TestCase):
  """Tests for `common.ModelOptions`."""

  def testOutputsToNumClasses(self):
    """The class count given at construction is read back unchanged."""
    expected_count = 21
    options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: expected_count})
    observed_count = options.outputs_to_num_classes[common.OUTPUT_TYPE]
    self.assertEqual(observed_count, expected_count)

  def testDeepcopy(self):
    """Mutating a deep copy's class mapping must not affect the original."""
    output_key = common.OUTPUT_TYPE
    original_count = 21
    source = common.ModelOptions(
        outputs_to_num_classes={output_key: original_count})
    clone = copy.deepcopy(source)

    # Immediately after copying, the clone reports the same class count.
    self.assertEqual(clone.outputs_to_num_classes[output_key],
                     original_count)

    # Overwrite the clone's entry. If the mapping were shared (i.e. only a
    # shallow copy), the source would observe the change as well.
    updated_count = 22
    clone.outputs_to_num_classes[output_key] = updated_count
    self.assertEqual(source.outputs_to_num_classes[output_key],
                     original_count)
    self.assertEqual(clone.outputs_to_num_classes[output_key],
                     updated_count)
|
|
|
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand off to
  # TensorFlow's test runner to run the test cases defined above.
  tf.test.main()
|
|