from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
import os
import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
import pandas as pd
from PIL import Image
from tqdm import tqdm
import logging
import yaml
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load settings
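# settings.yaml is expected to define client_secrets_file, the path to the
# Google OAuth client configuration consumed by PyDrive2 below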
with open('settings.yaml', 'r') as file:
    settings = yaml.safe_load(file)


class DatasetManager:
    def __init__(self, local_images_dir="downloaded_cards"):
        self.local_images_dir = local_images_dir
        self.drive = None
        self.dataset_name = "GotThatData/sports-cards"
        # Create local directory if it doesn't exist
        os.makedirs(local_images_dir, exist_ok=True)

    def authenticate_drive(self):
        """Authenticate with Google Drive"""
        try:
            gauth = GoogleAuth()
            gauth.settings['client_config_file'] = settings['client_secrets_file']
            # Try to load saved credentials
            gauth.LoadCredentialsFile("credentials.txt")
            if gauth.credentials is None:
                gauth.LocalWebserverAuth()
            elif gauth.access_token_expired:
                gauth.Refresh()
            else:
                gauth.Authorize()
            gauth.SaveCredentialsFile("credentials.txt")
            self.drive = GoogleDrive(gauth)
            return True, "Successfully authenticated with Google Drive"
        except Exception as e:
            return False, f"Authentication failed: {str(e)}"

    def download_and_rename_files(self, drive_folder_id, naming_convention):
        """Download files from Google Drive and rename them"""
        if not self.drive:
            return False, "Google Drive not authenticated", []
        try:
            query = f"'{drive_folder_id}' in parents and trashed=false"
            file_list = self.drive.ListFile({'q': query}).GetList()
            if not file_list:
                # The ID may point to a single file rather than a folder.
                # FetchMetadata raises if the ID is invalid, which makes the
                # error path below reachable (CreateFile alone always returns
                # an object and never fails).
                try:
                    file = self.drive.CreateFile({'id': drive_folder_id})
                    file.FetchMetadata()
                    file_list = [file]
                except Exception:
                    return False, "No files found with the specified ID", []
            renamed_files = []
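            # Resume numbering from the size of the existing train split so
            # new filenames extend the dataset instead of reusing indices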
            try:
                existing_dataset = load_dataset(self.dataset_name)
                logger.info(f"Loaded existing dataset: {self.dataset_name}")
                start_index = len(existing_dataset['train']) if 'train' in existing_dataset else 0
            except Exception as e:
                logger.info(f"No existing dataset found, starting fresh: {str(e)}")
                start_index = 0
            for file in tqdm(file_list, desc="Downloading files"):
                # Skip non-image files; index by len(renamed_files) rather than
                # the loop counter so skipped or rejected files leave no gaps
                # in the numbering
                if file['mimeType'].startswith('image/'):
                    new_filename = f"{naming_convention}_{start_index + len(renamed_files) + 1}.jpg"
                    file_path = os.path.join(self.local_images_dir, new_filename)
                    file.GetContentFile(file_path)
                    try:
                        # Verify the download is a readable image before keeping it
                        with Image.open(file_path) as img:
                            img.verify()
                        renamed_files.append({
                            'file_path': file_path,
                            'original_name': file['title'],
                            'new_name': new_filename,
                            'image': file_path
                        })
                    except Exception as e:
                        logger.error(f"Error processing image {file['title']}: {str(e)}")
                        if os.path.exists(file_path):
                            os.remove(file_path)
            return True, f"Successfully processed {len(renamed_files)} images", renamed_files
        except Exception as e:
            return False, f"Error downloading files: {str(e)}", []

    def update_huggingface_dataset(self, renamed_files):
        """Update the sports-cards dataset with new images"""
        try:
            df = pd.DataFrame(renamed_files)
            # NOTE: the 'image' column holds file paths as strings; casting it
            # to a datasets.Image() feature would be needed for the Hub to
            # render previews, which this script does not do
            new_dataset = Dataset.from_pandas(df)
            try:
                existing_dataset = load_dataset(self.dataset_name)
                if 'train' in existing_dataset:
                    new_dataset = concatenate_datasets([existing_dataset['train'], new_dataset])
            except Exception:
                logger.info("Creating new dataset")
            # push_to_hub assumes the runtime is already authenticated with the
            # Hugging Face Hub (e.g. via HF_TOKEN or `huggingface-cli login`)
            new_dataset.push_to_hub(self.dataset_name, split="train")
            return True, f"Successfully updated dataset '{self.dataset_name}' with {len(renamed_files)} new images"
        except Exception as e:
            return False, f"Error updating Hugging Face dataset: {str(e)}"


def process_pipeline(folder_id, naming_convention):
    """Main pipeline to process images and update dataset"""
    manager = DatasetManager()
    auth_success, auth_message = manager.authenticate_drive()
    if not auth_success:
        return auth_message
    success, message, renamed_files = manager.download_and_rename_files(folder_id, naming_convention)
    if not success:
        return message
    success, hf_message = manager.update_huggingface_dataset(renamed_files)
    return f"{message}\n{hf_message}"
# Custom CSS for web-safe fonts and clean styling
custom_css = """
.gradio-container {
font-family: Arial, sans-serif !important;
}
h1, h2, h3 {
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif !important;
font-weight: 600 !important;
}
.gr-button {
font-family: Arial, sans-serif !important;
}
.gr-input {
font-family: 'Courier New', Courier, monospace !important;
}
.gr-box {
border-radius: 8px !important;
border: 1px solid #e5e5e5 !important;
}
.gr-padded {
padding: 16px !important;
}
"""
# Gradio interface with custom theme
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Sports Cards Dataset Processor")
    with gr.Box():
        gr.Markdown("""
        ### Instructions
        1. Enter the Google Drive folder/file ID
        2. Choose a naming convention for your cards
        3. Click Process to start
        """)
    with gr.Row():
        with gr.Column():
            folder_id = gr.Textbox(
                label="Google Drive File/Folder ID",
                placeholder="Enter the ID from your Google Drive URL",
                value="151VOxPO91mg0C3ORiioGUd4hogzP1ujm"
            )
            naming = gr.Textbox(
                label="Naming Convention",
                placeholder="e.g., sports_card",
                value="sports_card"
            )
            process_btn = gr.Button("Process Images", variant="primary")
    with gr.Box():
        output = gr.Textbox(
            label="Processing Status",
            show_label=True,
            lines=5
        )
    process_btn.click(
        fn=process_pipeline,
        inputs=[folder_id, naming],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()