|
import os |
|
import logging |
|
import asyncio |
|
import requests |
|
import fal_client |
|
import json |
|
from typing import Optional |
|
|
|
|
|
# Configure the root logger at import time: records at INFO and above,
# rendered as "timestamp - LEVEL - message" with a second-resolution clock.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
|
|
|
def _describe_image(image_data: dict) -> str:
    """Format an image result as ``"WxH (N.NMB)"`` for logging.

    ``file_size`` may be absent or present-but-``None`` (presumably optional
    in the API schema — confirm); both are treated as 0 so a missing size can
    never turn a successful call into a failure.
    """
    size_mb = (image_data.get('file_size') or 0) / 1024 / 1024
    return f"{image_data.get('width')}x{image_data.get('height')} ({size_mb:.1f}MB)"


def _unseen_log_entries(logs, seen: set) -> list:
    """Return the entries of *logs* not already recorded in *seen*, and record them.

    Status is polled repeatedly with ``with_logs=True``; NOTE(review): each
    poll presumably returns the full log so far — confirm against the fal
    docs. Keying on (timestamp, level, message) avoids re-emitting the same
    line on every poll whether logs are cumulative or incremental.
    """
    fresh = []
    for entry in logs:
        key = (entry.get('timestamp'), entry.get('level'), entry.get('message'))
        if key in seen:
            continue
        seen.add(key)
        fresh.append(entry)
    return fresh


async def remove_background_birefnet(
    image_path: str,
    *,
    poll_interval: float = 1.0,
    timeout: Optional[float] = None,
) -> Optional[str]:
    """Remove an image's background via the fal.ai BiRefNet v2 queue endpoint.

    Submits a job, polls its status until it completes, and returns the URL
    of the resulting PNG.

    Args:
        image_path: URL of the input image, passed to the API as
            ``image_url`` (e.g. the value returned by
            ``fal_client.upload_file_async``), despite the parameter name.
        poll_interval: Seconds to wait between status polls.
        timeout: Maximum seconds to wait for the job to finish, or ``None``
            (default) to wait indefinitely.

    Returns:
        The result image URL, or ``None`` on any failure: failed or unknown
        status, timeout, malformed result, or an exception. Failures are
        logged rather than raised.
    """
    endpoint = "fal-ai/birefnet/v2"
    logging.info(f"Starting BiRefNet processing for: {image_path}")

    try:
        logging.info("Submitting request to BiRefNet API...")
        handler = await fal_client.submit_async(
            endpoint,
            arguments={
                "image_url": image_path,
                "model": "General Use (Heavy)",
                "operating_resolution": "1024x1024",
                "output_format": "png",
                "refine_foreground": True,
            },
        )
        request_id = handler.request_id
        logging.info(f"🔄 Request submitted with ID: {request_id}")

        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        seen_logs: set = set()
        # NOTE(review): not every fal_client version may export a ``Failed``
        # status class; looking it up lazily keeps an AttributeError from
        # masquerading as the request's real outcome.
        failed_cls = getattr(fal_client, "Failed", None)

        while True:
            status = await fal_client.status_async(endpoint, request_id, with_logs=True)

            # Forward only log lines that were not already forwarded by a previous poll.
            for log in _unseen_log_entries(getattr(status, 'logs', None) or [], seen_logs):
                level = log.get('level', 'INFO')
                message = log.get('message', '')
                logging.info(f"🔄 BiRefNet {level}: {message}")

            if isinstance(status, fal_client.Queued):
                logging.info("⏳ Request in queue")
            elif isinstance(status, fal_client.InProgress):
                logging.info("🔄 Request is being processed...")
            elif isinstance(status, fal_client.Completed):
                logging.info("✅ Request completed")
                break
            elif failed_cls is not None and isinstance(status, failed_cls):
                logging.error(f"❌ Request failed: {status.error}")
                return None
            else:
                logging.error(f"❌ Unknown status type: {type(status)}")
                return None

            if deadline is not None and loop.time() >= deadline:
                logging.error(f"❌ Timed out after {timeout}s waiting for request {request_id}")
                return None

            await asyncio.sleep(poll_interval)

        result = await fal_client.result_async(endpoint, request_id)
        if not result or not isinstance(result, dict):
            logging.error("❌ Invalid result from BiRefNet")
            return None

        image_data = result.get('image', {})
        if not image_data or not isinstance(image_data, dict):
            logging.error(f"❌ Missing or invalid image data in result: {result}")
            return None

        image_url = image_data.get('url')
        if not image_url:
            logging.error(f"❌ Missing image URL in result: {image_data}")
            return None

        logging.info(f"✅ Got image: {_describe_image(image_data)}")
        return image_url

    except Exception as e:
        # Boundary of the API interaction: report and signal failure via None.
        logging.error(f"❌ Unexpected error using BiRefNet API: {str(e)}", exc_info=True)
        return None
|
|
|
async def process_single_image(
    input_path: str,
    output_path: str,
    *,
    download_timeout: float = 60.0,
) -> bool:
    """Upload one image, remove its background with BiRefNet, and save the result.

    The result download runs in a worker thread so the blocking ``requests``
    call does not stall the event loop, and with it any other images being
    processed concurrently.

    Bytes are written to ``<output_path>.part`` first and then moved into
    place with ``os.replace``. An interrupted write therefore never leaves a
    truncated file at *output_path* that a later run (which skips existing
    outputs) would mistake for a finished one.

    Args:
        input_path: Local path of the image to process.
        output_path: Destination path for the downloaded result.
        download_timeout: ``requests`` timeout in seconds for the result
            download. It bounds connecting and each read, not the total time.

    Returns:
        ``True`` if the result was saved, ``False`` otherwise. Errors are
        logged, not raised.
    """
    try:
        logging.info("📤 Uploading to temporary storage...")
        image_url = await fal_client.upload_file_async(input_path)
        logging.info(f"✅ Upload successful: {image_url}")

        result_url = await remove_background_birefnet(image_url)
        if not result_url:
            # remove_background_birefnet already logged why.
            return False

        logging.info("📥 Downloading result...")
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: requests.get(result_url, timeout=download_timeout)
        )
        response.raise_for_status()

        content_type = response.headers.get('content-type', '')
        if 'image' not in content_type:
            logging.error(f"❌ Invalid content type: {content_type}")
            return False

        tmp_path = f"{output_path}.part"
        try:
            with open(tmp_path, 'wb') as f:
                f.write(response.content)
            os.replace(tmp_path, output_path)
        except BaseException:
            # Never leave a half-written temp file behind; re-raise the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

        logging.info(f"✅ Successfully saved to {output_path}")
        return True

    except Exception as e:
        logging.error(f"❌ Error processing image: {str(e)}", exc_info=True)
        return False
|
|
|
async def iterate_over_directory(
    input_dir: str,
    output_dir: str,
    *,
    batch_size: int = 3,
    batch_delay: float = 1.0,
):
    """Remove backgrounds from every image in *input_dir*, in small concurrent batches.

    Images (``.png``/``.jpg``/``.jpeg``/``.webp``, case-insensitive) are
    handled in sorted filename order, ``batch_size`` at a time. Each result
    is written to *output_dir* under the same filename. Files whose output
    already exists are skipped, which makes an interrupted run resumable.
    Counts of processed, skipped and failed files are logged at the end.

    Args:
        input_dir: Directory scanned (non-recursively) for images.
        output_dir: Destination directory; created if missing.
        batch_size: Maximum number of images processed concurrently.
        batch_delay: Seconds to pause between batches that actually sent work
            to the API (presumably to stay under rate limits — confirm).

    Raises:
        ValueError: If *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logging.info(f"🚀 Starting BiRefNet processing for directory: {input_dir}")
    logging.info(f"📁 Output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Sorted for a deterministic, reproducible processing order.
    # NOTE(review): BiRefNet is asked for PNG output, yet results keep the
    # input's filename, so e.g. "photo.jpg" will end up containing PNG bytes.
    files = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
        and os.path.isfile(os.path.join(input_dir, f))
    )
    total_files = len(files)
    processed = skipped = failed = 0

    logging.info(f"📊 Found {total_files} images to process")

    for start in range(0, total_files, batch_size):
        batch = files[start:start + batch_size]
        # Pair each coroutine with its own filename so results stay correctly
        # attributed even when some files in the batch were skipped.
        pending = []

        for offset, filename in enumerate(batch):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)

            logging.info(f"\n{'='*50}")
            logging.info(f"Processing [{start + offset + 1}/{total_files}]: {filename}")

            if os.path.exists(output_path):
                logging.info(f"⏭️ Skipping {filename} - already processed")
                skipped += 1
                continue

            pending.append((filename, process_single_image(input_path, output_path)))

        if not pending:
            # Nothing was sent to the API, so there is nothing to wait for.
            continue

        results = await asyncio.gather(*(coro for _, coro in pending), return_exceptions=True)

        for (filename, _), result in zip(pending, results):
            # BaseException, not Exception: a cancelled task comes back as
            # CancelledError, which must count as a failure, not a success.
            if isinstance(result, BaseException):
                logging.error(f"❌ Failed to process {filename}: {str(result)}")
                failed += 1
            elif result:
                processed += 1
            else:
                failed += 1

        if start + batch_size < total_files:
            # Brief pause between API-hitting batches; not needed after the last one.
            await asyncio.sleep(batch_delay)

    logging.info(f"\n{'='*50}")
    logging.info("📊 Processing Summary:")
    logging.info(f"✅ Successfully processed: {processed}")
    logging.info(f"⏭️ Skipped (already existed): {skipped}")
    logging.info(f"❌ Failed: {failed}")
    logging.info(f"📁 Total files: {total_files}")
|
|
|
def process_directory(input_dir: str, output_dir: str):
    """Blocking entry point: run the async directory pipeline to completion.

    Builds the ``iterate_over_directory`` coroutine and drives it on a fresh
    event loop via ``asyncio.run``. As with any ``asyncio.run`` call, it
    cannot be used from a thread that already has a running event loop.
    """
    pipeline = iterate_over_directory(input_dir, output_dir)
    asyncio.run(pipeline)