# Sports Cards Dataset Processor — Hugging Face Space app
import logging
import os

import gradio as gr
import pandas as pd
import yaml
from datasets import Dataset, concatenate_datasets, load_dataset
from PIL import Image
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from tqdm import tqdm
# Set up module-level logging so pipeline progress and errors show up in the
# Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load runtime settings (e.g. the Google OAuth client-secrets path) from the
# repo-local settings.yaml. safe_load avoids executing arbitrary YAML tags.
with open('settings.yaml', 'r') as file:
    settings = yaml.safe_load(file)
class DatasetManager:
    """Pull card images from Google Drive and append them to the
    'GotThatData/sports-cards' dataset on the Hugging Face Hub."""

    def __init__(self, local_images_dir="downloaded_cards"):
        """Set up local storage for downloaded images.

        Args:
            local_images_dir: Directory where downloaded card images are
                saved; created if it does not already exist.
        """
        self.local_images_dir = local_images_dir
        self.drive = None  # GoogleDrive client; populated by authenticate_drive()
        self.dataset_name = "GotThatData/sports-cards"
        os.makedirs(local_images_dir, exist_ok=True)

    def authenticate_drive(self):
        """Authenticate with Google Drive, reusing saved credentials when possible.

        Returns:
            Tuple of (success: bool, message: str).
        """
        try:
            gauth = GoogleAuth()
            # Point pydrive2 at the OAuth client config path from settings.yaml.
            gauth.settings['client_config_file'] = settings['client_secrets_file']
            # Reuse cached credentials if present; otherwise run the local
            # webserver OAuth flow (opens a browser) or refresh an expired token.
            gauth.LoadCredentialsFile("credentials.txt")
            if gauth.credentials is None:
                gauth.LocalWebserverAuth()
            elif gauth.access_token_expired:
                gauth.Refresh()
            else:
                gauth.Authorize()
            gauth.SaveCredentialsFile("credentials.txt")
            self.drive = GoogleDrive(gauth)
            return True, "Successfully authenticated with Google Drive"
        except Exception as e:
            return False, f"Authentication failed: {str(e)}"

    def download_and_rename_files(self, drive_folder_id, naming_convention):
        """Download image files from a Drive folder (or a single file) and rename them.

        Args:
            drive_folder_id: Google Drive ID of a folder, or of a single file.
            naming_convention: Filename prefix; downloads are saved as
                "<prefix>_<index>.jpg", numbered after any existing rows in
                the Hub dataset so names don't collide.

        Returns:
            Tuple of (success: bool, message: str, renamed_files: list[dict]),
            where each dict has 'file_path', 'original_name', 'new_name',
            and 'image' keys.
        """
        if not self.drive:
            return False, "Google Drive not authenticated", []
        try:
            # First treat the ID as a folder and list its non-trashed children.
            query = f"'{drive_folder_id}' in parents and trashed=false"
            file_list = self.drive.ListFile({'q': query}).GetList()
            if not file_list:
                # Fall back to treating the ID as a single file. CreateFile
                # never hits the network, so the original truthiness check
                # could never fail; fetch metadata to confirm the ID resolves.
                single = self.drive.CreateFile({'id': drive_folder_id})
                try:
                    single.FetchMetadata()
                    file_list = [single]
                except Exception:
                    return False, "No files found with the specified ID", []
            renamed_files = []
            # Continue numbering after the existing 'train' split so new
            # filenames never clash with previously uploaded ones.
            try:
                existing_dataset = load_dataset(self.dataset_name)
                logger.info(f"Loaded existing dataset: {self.dataset_name}")
                start_index = len(existing_dataset['train']) if 'train' in existing_dataset else 0
            except Exception as e:
                logger.info(f"No existing dataset found, starting fresh: {str(e)}")
                start_index = 0
            for i, file in enumerate(tqdm(file_list, desc="Downloading files")):
                # Skip anything that isn't an image (docs, archives, etc.).
                if file['mimeType'].startswith('image/'):
                    new_filename = f"{naming_convention}_{start_index + i + 1}.jpg"
                    file_path = os.path.join(self.local_images_dir, new_filename)
                    file.GetContentFile(file_path)
                    try:
                        # verify() catches truncated/corrupt downloads without
                        # fully decoding the image.
                        with Image.open(file_path) as img:
                            img.verify()
                        renamed_files.append({
                            'file_path': file_path,
                            'original_name': file['title'],
                            'new_name': new_filename,
                            'image': file_path
                        })
                    except Exception as e:
                        logger.error(f"Error processing image {file['title']}: {str(e)}")
                        # Drop the corrupt download so it never reaches the dataset.
                        if os.path.exists(file_path):
                            os.remove(file_path)
            return True, f"Successfully processed {len(renamed_files)} images", renamed_files
        except Exception as e:
            return False, f"Error downloading files: {str(e)}", []

    def update_huggingface_dataset(self, renamed_files):
        """Append newly downloaded images to the Hub dataset's 'train' split.

        Args:
            renamed_files: List of dicts produced by download_and_rename_files().

        Returns:
            Tuple of (success: bool, message: str).
        """
        try:
            df = pd.DataFrame(renamed_files)
            new_dataset = Dataset.from_pandas(df)
            # Merge with the existing split if there is one. Relies on
            # concatenate_datasets being imported at module level; previously
            # it was undefined, the NameError was silently swallowed below,
            # and the push overwrote the existing data instead of appending.
            try:
                existing_dataset = load_dataset(self.dataset_name)
                if 'train' in existing_dataset:
                    new_dataset = concatenate_datasets([existing_dataset['train'], new_dataset])
            except Exception:
                logger.info("Creating new dataset")
            new_dataset.push_to_hub(self.dataset_name, split="train")
            return True, f"Successfully updated dataset '{self.dataset_name}' with {len(renamed_files)} new images"
        except Exception as e:
            return False, f"Error updating Hugging Face dataset: {str(e)}"
def process_pipeline(folder_id, naming_convention):
    """Run the full pipeline: authenticate, download/rename, update dataset.

    Args:
        folder_id: Google Drive folder or file ID to pull images from.
        naming_convention: Filename prefix for the renamed images.

    Returns:
        A human-readable status string for the Gradio UI; on failure, the
        message from the first stage that failed.
    """
    manager = DatasetManager()

    # Stage 1: Google Drive authentication.
    auth_success, auth_message = manager.authenticate_drive()
    if not auth_success:
        return auth_message

    # Stage 2: download the images and rename them locally.
    success, message, renamed_files = manager.download_and_rename_files(folder_id, naming_convention)
    if not success:
        return message

    # Stage 3: push the new images to the Hugging Face dataset.
    success, hf_message = manager.update_huggingface_dataset(renamed_files)
    return f"{message}\n{hf_message}"
# Custom CSS for web-safe fonts and clean styling of the Gradio UI.
# (Reconstructed: the scraped source had ' | |' artifacts embedded inside
# this string literal, which would have been served to the browser verbatim.)
custom_css = """
.gradio-container {
    font-family: Arial, sans-serif !important;
}
h1, h2, h3 {
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif !important;
    font-weight: 600 !important;
}
.gr-button {
    font-family: Arial, sans-serif !important;
}
.gr-input {
    font-family: 'Courier New', Courier, monospace !important;
}
.gr-box {
    border-radius: 8px !important;
    border: 1px solid #e5e5e5 !important;
}
.gr-padded {
    padding: 16px !important;
}
"""
# Gradio interface with the custom theme.
# NOTE(review): gr.Box was removed in Gradio 4.x — if this Space pins
# Gradio >= 4, swap gr.Box for gr.Group. Kept as-is to preserve behavior
# on the version the original code targeted.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Sports Cards Dataset Processor")

    with gr.Box():
        gr.Markdown("""
        ### Instructions
        1. Enter the Google Drive folder/file ID
        2. Choose a naming convention for your cards
        3. Click Process to start
        """)

    with gr.Row():
        with gr.Column():
            folder_id = gr.Textbox(
                label="Google Drive File/Folder ID",
                placeholder="Enter the ID from your Google Drive URL",
                value="151VOxPO91mg0C3ORiioGUd4hogzP1ujm"
            )
            naming = gr.Textbox(
                label="Naming Convention",
                placeholder="e.g., sports_card",
                value="sports_card"
            )
            process_btn = gr.Button("Process Images", variant="primary")

    with gr.Box():
        output = gr.Textbox(
            label="Processing Status",
            show_label=True,
            lines=5
        )

    # Wire the button to the pipeline; status text streams into the output box.
    process_btn.click(
        fn=process_pipeline,
        inputs=[folder_id, naming],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()