"""Gradio demo that tags an image with a ResNet50 or ResNet101 tagger.

Both models put the same head on a headless Keras ResNet backbone:
1x1 conv (one channel per entry of ``tags.json``) -> batch norm ->
global average pooling -> sigmoid. The head yields one confidence per tag.
"""

import json
from pprint import pprint

import tensorflow as tf
import tensorflow.keras as keras
from gradio import Interface, inputs, outputs

RESNET50_SIZE = 256
RESNET101_SIZE = 360
DEVICE = "/cpu:0"

with open("./tags.json", "rt", encoding="utf-8") as f:
    tags = json.load(f)


def _build_tagger(backbone, size, weights_path):
    """Build a tagging model on top of ``backbone`` and load its weights.

    Args:
        backbone: Keras application constructor (e.g. ``ResNet50``). It is
            called without its classification top and without pretrained weights.
        size: square input side length in pixels.
        weights_path: path of the saved weights to load into the model.

    Returns:
        The ``keras.Sequential`` model with weights loaded.
    """
    with tf.device(DEVICE):
        model = keras.Sequential(
            [
                backbone(
                    include_top=False,
                    weights=None,
                    input_shape=(size, size, 3),
                ),
                # One output channel per tag.
                keras.layers.Conv2D(
                    filters=len(tags), kernel_size=(1, 1), padding="same"
                ),
                keras.layers.BatchNormalization(epsilon=1.001e-5),
                keras.layers.GlobalAveragePooling2D(name="avg_pool"),
                # Sigmoid gives each tag its own confidence in [0, 1]
                # rather than a distribution across tags.
                keras.layers.Activation("sigmoid"),
            ]
        )
    model.load_weights(weights_path)
    return model


def _predict(model, size, img):
    """Return ``model``'s per-tag confidence tensor for a single image.

    The image is resized with padding to ``size`` x ``size`` and standardized
    per image. It is then run through the model as a batch of one.
    """
    with tf.device(DEVICE):
        img = tf.image.resize_with_pad(img, size, size)
        img = tf.image.per_image_standardization(img)
        data = tf.expand_dims(img, 0)
        # Drop the batch dimension of the one-image batch.
        return model(data)[0]


# Built in this order (ResNet50 first) so auto-generated layer names match
# the original script's.
model_resnet50 = _build_tagger(
    keras.applications.resnet.ResNet50, RESNET50_SIZE, "./tf_model_resnet50.h5"
)
model_resnet101 = _build_tagger(
    keras.applications.resnet.ResNet101, RESNET101_SIZE, "./tf_model_resnet101.h5"
)


def predict_resnet50(img):
    """Return ResNet50 per-tag confidences for ``img``."""
    return _predict(model_resnet50, RESNET50_SIZE, img)


def predict_resnet101(img):
    """Return ResNet101 per-tag confidences for ``img``."""
    return _predict(model_resnet101, RESNET101_SIZE, img)


def main(img, hide: float, model: str) -> dict:
    """Tag ``img`` with the selected model.

    Args:
        img: image from the Gradio ``Image`` input. Presumably an HxWx3
            array; ``None`` when nothing was uploaded.
        hide: confidence threshold. Tags scoring below it are omitted.
        model: model name. Must end with "50" (ResNet50) or "101" (ResNet101).

    Returns:
        Mapping of tag -> confidence for every tag whose confidence is
        ``>= hide``.

    Raises:
        ValueError: if ``model`` names neither supported model.
    """
    if model.endswith("50"):
        predict = predict_resnet50
    elif model.endswith("101"):
        predict = predict_resnet101
    else:
        raise ValueError(f"Invalid model type: {model!r}")

    if img is None:
        # Nothing uploaded: show no tags instead of failing inside TensorFlow.
        return {}

    # Convert the whole score vector once rather than doing one
    # tensor -> numpy round-trip per tag.
    scores = predict(img).numpy()
    result = {
        tag: confidence
        for tag, score in zip(tags, scores)
        if (confidence := float(score)) >= hide
    }
    pprint(result)
    return result


image = inputs.Image(label="Upload your image here!")
hide_threshold = inputs.Slider(
    label="Hide confidence lower than", default=0.5, maximum=1, minimum=0
)
select_model = inputs.Radio(
    ["ResNet50", "ResNet101"], label="Select model", type="value"
)
labels = outputs.Label(label="Tags", type="confidences")

interface = Interface(
    main, inputs=[image, hide_threshold, select_model], outputs=[labels]
)
interface.launch()