Spaces:
Runtime error
Runtime error
import torch | |
from torchvision import models, transforms | |
from PIL import Image | |
import pickle | |
import os | |
from tqdm import tqdm # Import tqdm for the progress bar | |
# Load an ImageNet-pretrained ResNet-50 and put it in evaluation mode so
# layers such as BatchNorm use their stored statistics at inference time.
#
# torchvision >= 0.13 replaced the boolean ``pretrained`` flag with the
# ``weights`` enum and deprecated the old flag. ``IMAGENET1K_V1`` is the
# checkpoint that ``pretrained=True`` maps to, so the extracted features are
# unchanged; the fallback keeps older torchvision releases working.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model = model.eval()
# NOTE(review): the classifier head is kept, so ``model(x)`` returns the
# 1000-way ImageNet logits rather than the pooled penultimate features. If the
# saved "embeddings" are meant to be penultimate-layer features, set
# ``model.fc = torch.nn.Identity()`` -- confirm with downstream consumers first,
# since that changes the size of every saved vector.

# Preprocessing: resize to a fixed 224x224 (aspect ratio is not preserved),
# convert to a float tensor, then normalize with the ImageNet per-channel
# mean/std that the pretrained weights expect.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Function to extract features from an image
def extract_features(image_path):
    """Run one image through the module-level ``model`` and return its output.

    The image is converted to RGB, transformed with the module-level
    ``preprocess`` pipeline, and fed to the network as a batch of one with
    gradient tracking disabled.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        numpy.ndarray: The model output for this image with the batch axis
        removed (1-D for the classifier network configured in this script;
        re-check if ``model`` is replaced).

    Raises:
        OSError: If the file is missing or PIL cannot decode it.
    """
    # Open via a context manager so the file handle is released immediately.
    # Otherwise PIL keeps it open until garbage collection, which can exhaust
    # file descriptors when processing thousands of images.
    with Image.open(image_path) as img:
        image = img.convert('RGB')  # returns an independent, fully loaded image
    input_batch = preprocess(image).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():  # inference only: don't build the autograd graph
        output = model(input_batch)
    # squeeze(0) drops only the batch axis; a bare squeeze() would also collapse
    # any other size-1 dimension of the model output.
    return output.squeeze(0).numpy()
# Directory containing your images
images_directory = "photos/"

# File extensions treated as images. Matching is case-insensitive so files such
# as "IMG_001.JPG" are not silently ignored; ".jpeg" is the other common JPEG
# suffix.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Filter before the loop so the progress bar's total counts only images, and
# sort because os.listdir() returns entries in arbitrary order -- this makes the
# processing order (and the pickle's key order) reproducible.
image_filenames = sorted(
    name
    for name in os.listdir(images_directory)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

# Process each image and collect its features, keyed by filename
image_features = {}
skipped = []
for filename in tqdm(image_filenames, desc="Processing Images"):
    image_path = os.path.join(images_directory, filename)
    try:
        features = extract_features(image_path)
    except OSError as err:
        # A single unreadable or corrupt file must not abort a long run and
        # discard every feature computed so far; report it and move on.
        tqdm.write(f"Skipping {filename}: {err}")
        skipped.append(filename)
        continue
    image_features[filename] = features

if skipped:
    tqdm.write(f"Skipped {len(skipped)} unreadable image(s).")

# Save the features to a pickle file
output_file = "unsplash-25k-embeddings.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(image_features, f)