import os |
import logging |
import asyncio |
import requests |
import fal_client |
import json |
from typing import Optional |
# Configure the root logger once, at import time: INFO and above, with a
# second-resolution timestamp in front of every record.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
async def remove_background_birefnet(
    image_path: str,
    *,
    poll_interval: float = 1.0,
    timeout: Optional[float] = None,
) -> Optional[str]:
    """Remove an image's background with the ``fal-ai/birefnet/v2`` endpoint.

    The request is submitted through ``fal_client``, its status is polled
    until it completes, and the URL of the resulting image is returned.

    Args:
        image_path: Value sent to the API as ``image_url``.
        poll_interval: Seconds to wait between two status polls.
        timeout: Maximum number of seconds to keep polling. ``None`` (the
            default) polls until the request reaches a final state.

    Returns:
        The URL of the processed image, or ``None`` on any failure. Errors
        are logged and never propagated to the caller.
    """
    endpoint = "fal-ai/birefnet/v2"
    logging.info(f"Starting BiRefNet processing for: {image_path}")
    try:
        logging.info("Submitting request to BiRefNet API...")
        handler = await fal_client.submit_async(
            endpoint,
            arguments={
                "image_url": image_path,
                "model": "General Use (Heavy)",
                "operating_resolution": "1024x1024",
                "output_format": "png",
                "refine_foreground": True,
            },
        )
        request_id = handler.request_id
        logging.info(f"🔄 Request submitted with ID: {request_id}")

        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout

        while True:
            status = await fal_client.status_async(endpoint, request_id, with_logs=True)

            # Not every status object carries logs, and the attribute may be None.
            for entry in getattr(status, "logs", None) or ():
                level = entry.get("level", "INFO")
                message = entry.get("message", "")
                logging.info(f"🔄 BiRefNet {level}: {message}")

            if isinstance(status, fal_client.Queued):
                logging.info("⏳ Request in queue")
            elif isinstance(status, fal_client.InProgress):
                logging.info("🔄 Request is being processed...")
            elif isinstance(status, fal_client.Completed):
                logging.info("✅ Request completed")
                break
            elif isinstance(status, fal_client.Failed):
                logging.error(f"❌ Request failed: {status.error}")
                return None
            else:
                logging.error(f"❌ Unknown status type: {type(status)}")
                return None

            # Without a deadline a request stuck in the queue would block forever.
            if deadline is not None and loop.time() >= deadline:
                logging.error(f"❌ Timed out after {timeout}s waiting for request {request_id}")
                return None
            await asyncio.sleep(poll_interval)

        result = await fal_client.result_async(endpoint, request_id)
        if not result or not isinstance(result, dict):
            logging.error("❌ Invalid result from BiRefNet")
            return None

        image_data = result.get("image", {})
        if not image_data or not isinstance(image_data, dict):
            logging.error(f"❌ Missing or invalid image data in result: {result}")
            return None

        image_url = image_data.get("url")
        if not image_url:
            logging.error(f"❌ Missing image URL in result: {image_data}")
            return None

        # `file_size` can be absent *or* present with a None value; dividing
        # None would raise and throw away an otherwise valid result URL.
        size_bytes = image_data.get("file_size") or 0
        logging.info(
            f"✅ Got image: {image_data.get('width')}x{image_data.get('height')} "
            f"({size_bytes / 1024 / 1024:.1f}MB)"
        )
        return image_url
    except Exception as e:
        # Boundary of the API interaction: callers only expect a URL or None.
        logging.error(f"❌ Unexpected error using BiRefNet API: {str(e)}", exc_info=True)
        return None
async def process_single_image(input_path: str, output_path: str) -> bool:
    """Upload one image, remove its background and save the result.

    Args:
        input_path: Local path of the image to upload.
        output_path: Local path the processed image is written to.

    Returns:
        ``True`` when the processed image was saved to ``output_path``;
        ``False`` on any failure. Errors are logged, never raised.
    """
    download_timeout = 60  # seconds; without it a stalled download never returns

    def _download(url: str):
        # Executed in a worker thread: `requests` is blocking and would
        # otherwise freeze the event loop (and every sibling task) meanwhile.
        response = requests.get(url, timeout=download_timeout)
        response.raise_for_status()
        return response

    try:
        logging.info("📤 Uploading to temporary storage...")
        image_url = await fal_client.upload_file_async(input_path)
        logging.info(f"✅ Upload successful: {image_url}")

        result_url = await remove_background_birefnet(image_url)
        if not result_url:
            return False

        logging.info("📥 Downloading result...")
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _download, result_url)

        content_type = response.headers.get('content-type', '')
        if 'image' not in content_type:
            logging.error(f"❌ Invalid content type: {content_type}")
            return False

        # Write to a sibling file and rename it into place, so an interrupted
        # write never leaves a truncated `output_path` behind.
        partial_path = f"{output_path}.part"
        try:
            with open(partial_path, 'wb') as out_file:
                out_file.write(response.content)
            os.replace(partial_path, output_path)
        except OSError:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

        logging.info(f"✅ Successfully saved to {output_path}")
        return True
    except Exception as e:
        logging.error(f"❌ Error processing image: {str(e)}", exc_info=True)
        return False
async def iterate_over_directory(input_dir: str, output_dir: str, batch_size: int = 3) -> dict:
    """Process all images in a directory using BiRefNet with async processing.

    Files in ``input_dir`` ending in .png/.jpg/.jpeg/.webp (case-insensitive)
    are handled in name order, ``batch_size`` at a time. A file is skipped
    when a file of the same name already exists in ``output_dir``.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory for the results; created if missing.
        batch_size: Number of images processed concurrently per batch.

    Returns:
        A dict with the integer counters ``total``, ``processed``,
        ``skipped`` and ``failed``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logging.info(f"🚀 Starting BiRefNet processing for directory: {input_dir}")
    logging.info(f"📁 Output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    image_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
    # os.listdir order is arbitrary; sorting keeps runs and logs reproducible.
    files = sorted(
        name for name in os.listdir(input_dir)
        if name.lower().endswith(image_suffixes)
    )
    total_files = len(files)
    processed = 0
    skipped = 0
    failed = 0
    logging.info(f"📊 Found {total_files} images to process")

    for start in range(0, total_files, batch_size):
        # Names are tracked next to their tasks: skipped files create no task,
        # so pairing results with the raw batch would mislabel them.
        submitted_names = []
        tasks = []
        batch = files[start:start + batch_size]
        for position, filename in enumerate(batch, start=start + 1):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)
            logging.info(f"\n{'='*50}")
            logging.info(f"Processing [{position}/{total_files}]: {filename}")
            if os.path.exists(output_path):
                logging.info(f"⏭️ Skipping {filename} - already processed")
                skipped += 1
                continue
            submitted_names.append(filename)
            tasks.append(process_single_image(input_path, output_path))

        if not tasks:
            continue

        results = await asyncio.gather(*tasks, return_exceptions=True)
        for filename, result in zip(submitted_names, results):
            # BaseException also covers CancelledError, which is not an
            # Exception subclass and would otherwise count as a success.
            if isinstance(result, BaseException):
                logging.error(f"❌ Failed to process {filename}: {str(result)}")
                failed += 1
            elif result:
                processed += 1
            else:
                failed += 1

        # Brief pause between batches that submitted work; pointless after the last one.
        if start + batch_size < total_files:
            await asyncio.sleep(1)

    logging.info(f"\n{'='*50}")
    logging.info("📊 Processing Summary:")
    logging.info(f"✅ Successfully processed: {processed}")
    logging.info(f"⏭️ Skipped (already existed): {skipped}")
    logging.info(f"❌ Failed: {failed}")
    logging.info(f"📁 Total files: {total_files}")
    return {
        "total": total_files,
        "processed": processed,
        "skipped": skipped,
        "failed": failed,
    }
def process_directory(input_dir: str, output_dir: str):
    """Blocking entry point: run ``iterate_over_directory`` to completion.

    Starts a fresh event loop via ``asyncio.run`` and drives the directory
    job on it; nothing is returned.
    """
    directory_job = iterate_over_directory(input_dir, output_dir)
    asyncio.run(directory_job)