"""Detect the main objects in a list of images with CLIP.

Take a CSV file containing image file paths and write a copy of it with an
extra ``clip_recognized_objects`` column holding, for each image, the top-k
CIFAR-100 classes CLIP considers most similar (``(class_name, probability)``
pairs).

The input CSV file must contain an ``image_file`` column with the image paths.
Progress is checkpointed into the output file so an interrupted run can resume.
"""
import argparse
import os

import clip
import pandas as pd
import torch
from PIL import Image
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# This dataset gives us the object class names used as CLIP text prompts.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)


def save_checkpoint(checkpoint_path, df, object_list):
    """Write ``df`` plus a ``clip_recognized_objects`` column to ``checkpoint_path``.

    ``object_list`` must have one entry per row of ``df``.
    """
    output_df = df.copy()
    output_df['clip_recognized_objects'] = object_list
    output_df.to_csv(
        checkpoint_path,
        index=False,  # don't write a new 'Index' column
    )
    print("Saved checkpoint!")


def load_checkpoint(checkpoint_path):
    """Return a ``{image_file: recognized_objects}`` cache read from a checkpoint CSV.

    Values are whatever ``pd.read_csv`` yields for the ``clip_recognized_objects``
    column (the string form written by ``save_checkpoint``); rows whose value is
    missing are skipped so those images get recomputed. If the checkpoint
    cannot be read (missing file, empty/unparsable CSV, missing columns) an
    empty dict is returned.
    """
    print("reading checkpoint at ", checkpoint_path)
    try:
        df = pd.read_csv(checkpoint_path)
        cached_objects = {
            filepath: objects
            for filepath, objects in zip(df['image_file'], df['clip_recognized_objects'])
            if pd.notna(objects)
        }
    except (OSError, KeyError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        print(f"Checkpoint was not loaded ({err!r})")
        # Fresh dict: avoid sharing mutable state across calls via a global.
        return {}
    print(f"Checkpoint loaded succesfully to cache: {len(cached_objects)} processed files")
    return cached_objects


def get_checkpoint_path(output_path):
    """Return the checkpoint location for ``output_path``.

    The output file itself doubles as the checkpoint, so a finished run's
    output is also a valid resume point.
    """
    return output_path


# Kept for backward compatibility with code importing this name; the
# functions below receive their cache explicitly instead of using it.
cached_objects_dict = {}  # to avoid recomputing


def _encode_class_prompts(model, device):
    """Return L2-normalised CLIP text embeddings, one per CIFAR-100 class prompt."""
    text_inputs = torch.cat(
        [clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]
    ).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
    return text_features / text_features.norm(dim=-1, keepdim=True)


def get_objects(filepath, model, preprocess, device, cached_objects_dict, text_features=None):
    """Return the recognized objects for ``filepath``, using/updating the cache.

    ``text_features`` may carry precomputed class embeddings (see
    ``_encode_class_prompts``); when omitted they are computed per call.
    """
    objects = cached_objects_dict.get(filepath)
    if objects is None:
        objects = get_objects_in_image(
            filepath, model, preprocess, device, text_features=text_features
        )
        cached_objects_dict[filepath] = objects
    return objects


def get_objects_in_image(image_filepath, model, preprocess, device, text_features=None, top_k=5):
    """Return the ``top_k`` CIFAR-100 classes most similar to the image.

    Returns a list of ``(class_name, probability)`` tuples, where the
    probabilities come from a softmax over all CIFAR-100 classes.
    ``text_features`` may be precomputed normalised class embeddings; pass
    them when classifying many images to avoid re-encoding the prompts.
    """
    # Prepare the inputs. The context manager closes the file handle;
    # resize() loads the pixel data before the file is closed.
    with Image.open(image_filepath) as img:
        image = img.resize((600, 600))
    image_input = preprocess(image).unsqueeze(0).to(device)

    if text_features is None:
        text_features = _encode_class_prompts(model, device)

    # Calculate features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Pick the top_k most similar labels for the image
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(top_k)
    return [
        (cifar100.classes[int(index)], value.item())
        for value, index in zip(values, indices)
    ]


def clip_object_detection(input_csv, output_csv, checkpoint_every=50):
    """Recognize objects in every image listed in ``input_csv``.

    Results already present in the checkpoint (the output file) are reused.
    After every ``checkpoint_every`` newly computed images, the rows processed
    so far are saved to the checkpoint.

    Returns a ``pd.Series`` with one entry per input row, in input order.
    """
    checkpoint_path = get_checkpoint_path(output_csv)
    cache = load_checkpoint(checkpoint_path)

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)
    # The class prompts are identical for every image: encode them once.
    text_features = _encode_class_prompts(model, device)

    df = pd.read_csv(input_csv)
    recognized_objects_per_image = []
    newly_processed = 0
    for idx, filepath in enumerate(tqdm(df['image_file'], total=len(df))):
        is_new = cache.get(filepath) is None
        objects = get_objects(
            filepath, model, preprocess, device, cache, text_features=text_features
        )
        recognized_objects_per_image.append(objects)

        # Only fresh work counts toward the checkpoint interval, so replaying
        # cached rows never triggers repeated rewrites of the same prefix.
        if is_new:
            newly_processed += 1
            if checkpoint_every and newly_processed % checkpoint_every == 0:
                print(f"Images processed: {idx + 1}")
                save_checkpoint(
                    checkpoint_path, df.iloc[:idx + 1], recognized_objects_per_image
                )

    return pd.Series(recognized_objects_per_image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="CLIP object recognition",
        description='Recognizes the top 5 main objects per image in an image list',
    )
    parser.add_argument(
        "--input_csv", "-in", metavar='in', type=str, nargs=1, required=True,
        help='input file containing images-paths for object recognition.',
    )
    parser.add_argument(
        "--output_csv", "-out", metavar='out', type=str, nargs=1, required=True,
        help='output file containing images-paths + recognized objects',
    )
    args = parser.parse_args()
    input_csv_file = args.input_csv[0]
    output_csv_file = args.output_csv[0]
    print(">>> input file: ", input_csv_file)
    print(">>> output file: ", output_csv_file)

    # perform object recognition
    recognized_objects_per_image = clip_object_detection(input_csv_file, output_csv_file)

    # add a column with the recognized objects
    output_df = pd.read_csv(input_csv_file)
    output_df['clip_recognized_objects'] = recognized_objects_per_image
    output_df.to_csv(
        output_csv_file,
        index=False,  # don't write a new 'Index' column
    )