# Scraped from a Hugging Face Space page (page status banner: "Spaces: Paused").
import json
import os
import shutil
import subprocess
import tempfile
import zipfile

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel
app = FastAPI()

# Hugging Face Hub id of the InstructPix2Pix checkpoint.  The weights are
# published under the "timbrooks" account (not the "timm" organisation, which
# hosts PyTorch Image Models checkpoints).
PIX2PIX_MODEL_ID = "timbrooks/instruct-pix2pix"

# Run on the GPU in half precision when CUDA is available; otherwise fall back
# to the CPU in full precision (half precision is typically unsupported or very
# slow on CPU) so the app can still start instead of crashing on .to("cuda").
_PIPE_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_PIPE_DTYPE = torch.float16 if _PIPE_DEVICE == "cuda" else torch.float32

# Load the InstructPix2Pix model once at import time; every request shares this
# single pipeline object.  The safety checker is disabled.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    PIX2PIX_MODEL_ID,
    torch_dtype=_PIPE_DTYPE,
    safety_checker=None,
).to(_PIPE_DEVICE)
class DatasetRequest(BaseModel):
    """Request body describing the synthetic dataset to build.

    The dataset endpoint forwards the first three fields to
    ``blender_script.py`` and uses ``augmentation_prompts`` to drive the
    InstructPix2Pix augmentation loop.
    """

    # Object names to render.  They are joined with "," into a single CLI
    # argument for the Blender script, so a name that itself contains a comma
    # will be split into two names.
    objects: list[str]

    # Environment/scene identifier, forwarded verbatim to the Blender script.
    environment: str

    # Forwarded to the Blender script as a string (presumably the number of
    # base renders).  NOTE(review): unbounded -- each base image costs one
    # diffusion run per prompt; consider enforcing an upper limit.
    num_images: int

    # Edit instructions for InstructPix2Pix.  One augmented image is produced
    # per (base annotation, prompt) pair; an empty list yields no images.
    augmentation_prompts: list[str]
def augment_image(image_path, prompt, *, num_inference_steps=20, image_guidance_scale=1.5):
    """Apply an InstructPix2Pix text edit to one image file.

    Args:
        image_path: Path (or any file object accepted by ``PIL.Image.open``)
            of the source image.  It is converted to RGB before editing.
        prompt: Natural-language edit instruction passed to the pipeline.
        num_inference_steps: Denoising steps per edit (default 20).
        image_guidance_scale: Pipeline ``image_guidance_scale``.  Per the
            diffusers docs, higher values keep the output closer to the
            source image.  Default 1.5.

    Returns:
        The first image produced by the module-level ``pipe``.
    """
    # Context manager closes the source file handle deterministically instead
    # of leaving it to garbage collection; convert() returns an independent copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    result = pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
    )
    return result.images[0]
def _render_base_images(request: DatasetRequest, base_dir: str) -> list[dict]:
    """Render base images with Blender into *base_dir* and return their annotations.

    Runs ``blender_script.py`` headless, then loads the ``annotations.json``
    the script is expected to write into *base_dir*.  The annotations are
    presumably a list of records with ``file_name`` and ``labels`` keys; those
    are the keys the augmentation step reads.

    Raises:
        subprocess.CalledProcessError: if Blender exits with a non-zero status.
        FileNotFoundError: if the script did not write ``annotations.json``.
    """
    os.makedirs(base_dir)
    # Blender stops parsing its own options at "--"; the remaining arguments
    # are left for the script.  NOTE(review): "blender_script.py" is resolved
    # against the current working directory, not this file's directory.
    subprocess.run(
        [
            "blender", "--background", "--python", "blender_script.py", "--",
            ",".join(request.objects),
            request.environment,
            str(request.num_images),
            base_dir,
        ],
        check=True,
    )
    with open(os.path.join(base_dir, "annotations.json"), "r", encoding="utf-8") as f:
        return json.load(f)


def _augment_base_images(
    base_annotations: list[dict],
    base_dir: str,
    image_dir: str,
    prompts: list[str],
) -> list[dict]:
    """Augment every base image once per prompt, saving PNGs into *image_dir*.

    Returns one annotation per augmented image.  Each has a sequential
    ``image_id`` starting at 0, the new ``file_name`` (``image_<id>.png``) and
    the ``labels`` copied unchanged from its base image.
    """
    os.makedirs(image_dir)
    annotations = []
    image_id = 0
    for base_anno in base_annotations:
        base_image_path = os.path.join(base_dir, base_anno["file_name"])
        for prompt in prompts:
            new_filename = f"image_{image_id}.png"
            augmented = augment_image(base_image_path, prompt)
            augmented.save(os.path.join(image_dir, new_filename))
            annotations.append({
                "image_id": image_id,
                "file_name": new_filename,
                # Base labels are reused as-is.  NOTE(review): this assumes the
                # edits do not move or remove the labelled objects.
                "labels": base_anno["labels"],
            })
            image_id += 1
    return annotations


def _write_dataset_zip(zip_path: str, image_dir: str, annotations: list[dict], anno_file: str) -> None:
    """Write ``images/<file_name>`` for each annotated image plus ``annotations.json``."""
    with zipfile.ZipFile(zip_path, "w") as zipf:
        # Iterate the annotations rather than os.walk so the archive order is
        # deterministic and matches annotations.json.
        for anno in annotations:
            name = anno["file_name"]
            zipf.write(os.path.join(image_dir, name), f"images/{name}")
        zipf.write(anno_file, "annotations.json")


async def generate_dataset(request: DatasetRequest):
    """Render base images with Blender, augment them, and return a zip archive.

    Steps:
      1. Render base images of ``request.objects`` in ``request.environment``
         with ``blender_script.py``.
      2. Produce one InstructPix2Pix edit per (base image, prompt) pair.
         Each edit inherits the base image's labels.
      3. Zip ``images/`` and ``annotations.json`` and stream the archive back.

    NOTE(review): no route decorator is visible on this function, so as shown
    it is not registered with ``app``; confirm the intended ``@app.post(...)``.
    NOTE(review): Blender and the diffusion pipeline run synchronously inside
    this coroutine, so they block the event loop (including health checks)
    for the whole request.  Offloading to a thread would likely need a lock
    around the shared ``pipe``.

    Args:
        request: The dataset specification.

    Returns:
        A ``FileResponse`` streaming ``dataset.zip``.  The scratch directory is
        deleted by a background task after the response has been sent.

    Raises:
        HTTPException: status 500 wrapping any failure while building the dataset.
    """
    # FileResponse reads the file only *after* this coroutine returns, so the
    # scratch directory must outlive the function body.  A TemporaryDirectory
    # context would delete the zip before it is streamed.
    tmpdirname = None
    try:
        tmpdirname = tempfile.mkdtemp()
        base_dir = os.path.join(tmpdirname, "base")
        output_dir = os.path.join(tmpdirname, "output")
        image_dir = os.path.join(output_dir, "images")

        base_annotations = _render_base_images(request, base_dir)
        annotations = _augment_base_images(
            base_annotations, base_dir, image_dir, request.augmentation_prompts
        )

        anno_file = os.path.join(output_dir, "annotations.json")
        with open(anno_file, "w", encoding="utf-8") as f:
            json.dump(annotations, f)

        zip_path = os.path.join(tmpdirname, "dataset.zip")
        _write_dataset_zip(zip_path, image_dir, annotations, anno_file)
    except Exception as e:
        if tmpdirname is not None:
            shutil.rmtree(tmpdirname, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Remove the scratch directory once the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, tmpdirname, ignore_errors=True)
    return FileResponse(
        zip_path,
        media_type="application/zip",
        filename="dataset.zip",
        background=cleanup,
    )
async def health_check():
    """Liveness probe that unconditionally reports the service as healthy.

    NOTE(review): no route decorator is visible here, so as shown this
    coroutine is not registered with ``app``; confirm the intended
    ``@app.get(...)`` path.
    """
    state = "healthy"
    return dict(status=state)
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=7860) |