import os from gradio_client import Client import time import csv import matplotlib.pyplot as plt from matplotlib import rcParams from configs import * from PIL import Image client = Client("https://cycool29-spiralsense.hf.space/") list_of_times = [] rcParams["font.family"] = "Times New Roman" # Load the model model = MODEL.to(DEVICE) model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE)) model.eval() for disease in CLASSES: print("Processing", disease) for image_path in os.listdir(r"data\test\Task 1\{}".format(disease)): # print("Processing", image_path) image_path = r"data\test\Task 1\{}\{}".format(disease, image_path) start_time = time.time() result = client.predict( image_path, False, False, fn_index=0, ) time_taken = time.time() - start_time list_of_times.append(time_taken) print("Time taken:", time_taken) # Log to csv with open("log.csv", "a", newline="") as file: writer = csv.writer(file) writer.writerow([disease]) writer.writerow([image_path]) writer.writerow([time_taken]) print("Average time taken:", sum(list_of_times) / len(list_of_times)) print("Max time taken:", max(list_of_times)) print("Min time taken:", min(list_of_times)) print("Total time taken:", sum(list_of_times)) print("Median time taken:", sorted(list_of_times)[len(list_of_times) // 2]) # Plot the histogram plt.hist(list_of_times, bins=10) plt.xlabel("Time taken (s)") plt.ylabel("Frequency") plt.title("Time Taken to Process Each Image (Web)") plt.savefig("docs/evaluation/time_taken_for_web.png") # Now is local list_of_times = [] for disease in CLASSES: print("Processing", disease) for image_path in os.listdir(r"data\test\Task 1\{}".format(disease)): # print("Processing", image_path) image_path = r"data\test\Task 1\{}\{}".format(disease, image_path) start_time = time.time() image = Image.open(image_path).convert("RGB") image = preprocess(image).unsqueeze(0) image = image.to(DEVICE) output = model(image) time_taken = time.time() - start_time list_of_times.append(time_taken) print("Time taken:", time_taken) # Log to csv with open("log.csv", "a", newline="") as file: writer = csv.writer(file) writer.writerow([disease]) writer.writerow([image_path]) writer.writerow([time_taken]) print("Average time taken local:", sum(list_of_times) / len(list_of_times)) print("Max time taken local:", max(list_of_times)) print("Min time taken local:", min(list_of_times)) print("Total time taken local:", sum(list_of_times)) print("Median time taken local:", sorted(list_of_times)[len(list_of_times) // 2]) # Plot the histogram plt.hist(list_of_times, bins=10) plt.xlabel("Time taken (s)") plt.ylabel("Frequency") plt.title("Time taken to Process Each Image (Local)") plt.savefig("docs/evaluation/time_taken_for_local.png")