import gc
import os

import gradio as gr
import tensorflow as tf
import timm
import torch

model_names = [
    "mobilenetv4_conv_small.e2400_r224_in1k",
    "mobilenetv4_conv_medium.e500_r224_in1k",
    "mobilenetv4_conv_blur_medium.e500_r224_in1k",
    "mobilenetv4_conv_medium.e500_r256_in1k",
    "mobilenetv4_conv_large.e500_r256_in1k",
    "mobilenetv4_conv_large.e600_r384_in1k",
]

# Map ImageNet-1k class indices to human-readable labels.
with open("imagenet_classes.txt", "r") as file:
    lines = file.readlines()
index_to_label = {index: line.strip() for index, line in enumerate(lines)}

# Lazily loaded model state, replaced whenever the user selects a different variant.
model, transforms = None, None
tfl_model, input_details, output_details = None, None, None
last_model = None


def load_models(timm_model):
    """Load the pretrained timm model and its matching float16 TFLite interpreter."""
    convert_dir = "tflite_models"
    tf_model_path = os.path.join(convert_dir, f"{timm_model}_float16.tflite")

    model = timm.create_model(timm_model, pretrained=True)
    model = model.eval()
    data_config = timm.data.resolve_data_config(model=model)
    transforms = timm.data.create_transform(**data_config, is_training=False)

    tfl_model = tf.lite.Interpreter(model_path=tf_model_path)
    tfl_model.allocate_tensors()
    input_details = tfl_model.get_input_details()
    output_details = tfl_model.get_output_details()

    return model, transforms, tfl_model, input_details, output_details


def classify(img, model_name):
    global model, transforms, tfl_model, input_details, output_details, last_model

    # Reload only when the selected variant changes, freeing the previous models first.
    if last_model is None or model_name != last_model:
        if model is not None:
            model = None
            gc.collect()
        if tfl_model is not None:
            tfl_model = None
            gc.collect()
        model, transforms, tfl_model, input_details, output_details = load_models(model_name)
        last_model = model_name

    processed_img = transforms(img).unsqueeze(0)

    # PyTorch inference: top-5 class probabilities.
    with torch.inference_mode():
        pt_output = model(processed_img)
    pt_top5_probs, pt_top5_indices = torch.topk(pt_output.softmax(dim=1), k=5)
    pt_index_list = pt_top5_indices[0].tolist()
    pt_probs_list = pt_top5_probs[0].tolist()
    pt_result_labels = {
        index_to_label[index]: prob
        for index, prob in zip(pt_index_list, pt_probs_list)
    }

    # TFLite inference: convert the BCHW torch tensor to a BHWC numpy array.
    img_tf = processed_img.permute(0, 2, 3, 1).numpy()
    tfl_model.set_tensor(input_details[0]["index"], img_tf)
    tfl_model.invoke()
    tfl_output = tfl_model.get_tensor(output_details[0]["index"])

    tfl_output_tensor = tf.convert_to_tensor(tfl_output)
    tfl_softmax_output = tf.nn.softmax(tfl_output_tensor, axis=1)
    tfl_top5_probs, tfl_top5_indices = tf.math.top_k(tfl_softmax_output, k=5)
    tfl_probs_list = tfl_top5_probs[0].numpy().tolist()
    tfl_index_list = tfl_top5_indices[0].numpy().tolist()
    tfl_result_labels = {
        index_to_label[index]: prob
        for index, prob in zip(tfl_index_list, tfl_probs_list)
    }

    return pt_result_labels, tfl_result_labels


iface = gr.Interface(
    fn=classify,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(choices=model_names, value=model_names[0], label="Model Variant"),
    ],
    outputs=[gr.Label(label="PyTorch Output"), gr.Label(label="TFLite Output")],
    title="MobileNetV4 PyTorch vs TFLite ImageNet-1k Classification",
    examples=[
        ["example_images/n01818515_macaw.JPEG", model_names[0]],
        ["example_images/n01828970_bee_eater.jpg", model_names[0]],
        ["example_images/n01833805_hummingbird.JPEG", model_names[0]],
    ],
)

iface.launch()