from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
import os
import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
import pandas as pd
from PIL import Image
from tqdm import tqdm
import logging
import yaml

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Load settings
if not os.path.exists("settings.yaml"):
    raise FileNotFoundError("settings.yaml file is missing. Please add it with 'client_secrets_file'.")

with open('settings.yaml', 'r') as file:
    settings = yaml.safe_load(file)
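
# A minimal settings.yaml, assuming only the key this script actually reads
# ('client_secrets_file'); the filename below is illustrative:
#
#   client_secrets_file: client_secrets.json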

# Utility Functions
def safe_load_dataset(dataset_name):
    """Load a Hugging Face dataset; return (dataset, train split size), or (None, 0) on failure."""
    try:
        dataset = load_dataset(dataset_name)
        return dataset, len(dataset['train']) if 'train' in dataset else 0
    except Exception as e:
        logger.info(f"No existing dataset found. Starting fresh. Error: {str(e)}")
        return None, 0

def is_valid_image(file_path):
    """Check if a file is a valid image."""
    try:
        with Image.open(file_path) as img:
            # verify() checks file integrity without decoding the full image;
            # the file must be reopened before any further use.
            img.verify()
        return True
    except Exception as e:
        logger.error(f"Invalid image: {file_path}. Error: {str(e)}")
        return False

def validate_input(folder_id, naming_convention):
    """Validate user input."""
    if not folder_id or not folder_id.strip():
        return False, "Folder ID cannot be empty"
    if not naming_convention or not naming_convention.strip():
        return False, "Naming convention cannot be empty"
    if not naming_convention.replace('_', '').isalnum():
        return False, "Naming convention should only contain letters, numbers, and underscores"
    return True, ""
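
# Illustrative results (the folder ID value is a placeholder):
#   validate_input("1AbC-dEf", "sports_card")  -> (True, "")
#   validate_input("1AbC-dEf", "sports card!") -> (False, "Naming convention should only contain letters, numbers, and underscores")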

# DatasetManager Class
class DatasetManager:
    def __init__(self, local_images_dir="downloaded_cards"):
        self.local_images_dir = local_images_dir
        self.drive = None
        self.dataset_name = "GotThatData/sports-cards"
        os.makedirs(local_images_dir, exist_ok=True)

    def authenticate_drive(self):
        """Authenticate with Google Drive."""
        try:
            gauth = GoogleAuth()
            gauth.settings['client_config_file'] = settings['client_secrets_file']

            # Try to load saved credentials
            gauth.LoadCredentialsFile("credentials.txt")
            if gauth.credentials is None:
                gauth.LocalWebserverAuth()
            elif gauth.access_token_expired:
                gauth.Refresh()
            else:
                gauth.Authorize()
            gauth.SaveCredentialsFile("credentials.txt")

            self.drive = GoogleDrive(gauth)
            return True, "Successfully authenticated with Google Drive"
        except Exception as e:
            logger.error(f"Authentication failed: {str(e)}")
            return False, f"Authentication failed: {str(e)}"

    def download_and_rename_files(self, drive_folder_id, naming_convention):
        """Download image files from a Google Drive folder and rename them sequentially."""
        if not self.drive:
            return False, "Google Drive not authenticated", []

        try:
            query = f"'{drive_folder_id}' in parents and trashed=false"
            file_list = self.drive.ListFile({'q': query}).GetList()

            if not file_list:
                logger.warning(f"No files found in folder: {drive_folder_id}")
                return False, "No files found in the specified folder.", []

            # Continue numbering from the size of the existing dataset.
            existing_dataset, start_index = safe_load_dataset(self.dataset_name)

            renamed_files = []
            processed_count = 0
            error_count = 0

            for file in tqdm(file_list, desc="Downloading files", unit="file"):
                if 'mimeType' in file and 'image' in file['mimeType'].lower():
                    # Note: files are saved with a .jpg extension regardless of
                    # their original format.
                    new_filename = f"{naming_convention}_{start_index + processed_count + 1}.jpg"
                    file_path = os.path.join(self.local_images_dir, new_filename)

                    try:
                        file.GetContentFile(file_path)
                        if is_valid_image(file_path):
                            renamed_files.append({
                                'file_path': file_path,
                                'original_name': file['title'],
                                'new_name': new_filename
                            })
                            processed_count += 1
                            logger.info(f"Successfully processed: {file['title']} -> {new_filename}")
                        else:
                            error_count += 1
                            if os.path.exists(file_path):
                                os.remove(file_path)
                    except Exception as e:
                        error_count += 1
                        logger.error(f"Error processing file {file['title']}: {str(e)}")
                        if os.path.exists(file_path):
                            os.remove(file_path)

            status_message = f"Processed {processed_count} images successfully"
            if error_count > 0:
                status_message += f" ({error_count} files failed)"
            return True, status_message, renamed_files
        except Exception as e:
            logger.error(f"Download error: {str(e)}")
            return False, f"Error during download: {str(e)}", []

    def update_huggingface_dataset(self, renamed_files):
        """Update Hugging Face dataset with new images."""
        if not renamed_files:
            return False, "No files to update"

        try:
            # Rows store file metadata (local path, original and new names),
            # not the decoded image data itself.
            df = pd.DataFrame(renamed_files)
            new_dataset = Dataset.from_pandas(df)

            existing_dataset, _ = safe_load_dataset(self.dataset_name)
            if existing_dataset and 'train' in existing_dataset:
                combined_dataset = concatenate_datasets([existing_dataset['train'], new_dataset])
            else:
                combined_dataset = new_dataset

            combined_dataset.push_to_hub(self.dataset_name, split="train")
            return True, f"Successfully updated dataset '{self.dataset_name}' with {len(renamed_files)} new images."
        except Exception as e:
            logger.error(f"Dataset update error: {str(e)}")
            return False, f"Error updating Hugging Face dataset: {str(e)}"

def process_pipeline(folder_id, naming_convention):
    """Main pipeline for processing images and updating the dataset."""
    # Validate input
    is_valid, error_message = validate_input(folder_id, naming_convention)
    if not is_valid:
        return error_message, []

    manager = DatasetManager()

    # Step 1: Authenticate with Google Drive
    auth_success, auth_message = manager.authenticate_drive()
    if not auth_success:
        return auth_message, []

    # Step 2: Download and rename files
    success, message, renamed_files = manager.download_and_rename_files(folder_id, naming_convention)
    if not success:
        return message, []

    # Step 3: Update Hugging Face dataset
    success, hf_message = manager.update_huggingface_dataset(renamed_files)
    return f"{message}\n{hf_message}", renamed_files

def process_ui(folder_id, naming_convention):
    """UI handler for the process pipeline."""
    status, renamed_files = process_pipeline(folder_id, naming_convention)
    table_data = [[file['original_name'], file['new_name'], file['file_path']]
                  for file in renamed_files] if renamed_files else []
    return status, table_data

# Simplified Gradio interface
demo = gr.Interface(
    fn=process_ui,
    inputs=[
        gr.Textbox(
            label="Google Drive Folder ID",
            placeholder="Enter the folder ID from the URL"
        ),
        gr.Textbox(
            label="Naming Convention",
            placeholder="e.g., sports_card",
            value="sports_card"
        )
    ],
    outputs=[
        gr.Textbox(label="Status"),
        gr.Dataframe(
            headers=["Original Name", "New Name", "File Path"]
        )
    ],
    title="Sports Cards Dataset Processor",
    description="""
    Instructions:
    1. Enter the Google Drive folder ID (found in the folder's URL)
    2. Specify a naming convention for the files (e.g., 'sports_card')
    3. Click submit to start processing

    Note: Only image files will be processed. Invalid images will be skipped.
    """
)

if __name__ == "__main__":
    demo.launch()
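
# When running outside Spaces, a shareable public link can be requested with
# demo.launch(share=True); on Hugging Face Spaces the plain launch() above
# is sufficient.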