File size: 3,411 Bytes
475ca6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import subprocess
import os
import shutil
import tempfile
import zipfile
from diffusers import StableDiffusionInstructPix2PixPipeline
import torch
from PIL import Image
import json
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Load the InstructPix2Pix model once, at import time, so that every request
# reuses the same pipeline instead of reloading the weights.
#
# The checkpoint id was previously "timm/instruct-pix2pix"; the checkpoint
# used throughout the diffusers documentation for this pipeline is
# "timbrooks/instruct-pix2pix".
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float16,
    # No safety checker is attached to the pipeline.
    safety_checker=None,
).to("cuda")  # the pipeline is moved to the CUDA device, so a GPU is required

class DatasetRequest(BaseModel):
    """Request body accepted by the ``/generate_dataset`` endpoint."""
    # Object names; joined with "," and handed to the Blender script as one argument.
    objects: list[str]
    # Environment identifier, forwarded verbatim to the Blender script.
    environment: str
    # Image count forwarded (as a string) to the Blender script.
    num_images: int
    # Edit instructions; every base image is augmented once per prompt.
    augmentation_prompts: list[str]

def augment_image(image_path, prompt):
    """Run the module-level InstructPix2Pix pipeline on one image.

    Args:
        image_path: Path of the image file to load; it is converted to RGB
            before being passed to the pipeline.
        prompt: Text instruction passed to the pipeline as ``prompt``.

    Returns:
        The first image of the pipeline's output.
    """
    source_rgb = Image.open(image_path).convert("RGB")
    result = pipe(
        prompt=prompt,
        image=source_rgb,
        num_inference_steps=20,
        image_guidance_scale=1.5,
    )
    return result.images[0]

@app.post("/generate_dataset")
async def generate_dataset(request: DatasetRequest):
    """Render base images with Blender, augment them, and return a zip archive.

    Steps:
      1. Run ``blender_script.py`` in a background Blender process; it is
         expected to write the base images and an ``annotations.json`` file
         into the ``base`` working directory.
      2. Apply every augmentation prompt to every base image, saving each
         result as ``image_<id>.png`` and carrying over the base image's
         ``labels``.
      3. Pack the augmented images (under ``images/``) and the new
         ``annotations.json`` into ``dataset.zip``.

    Args:
        request: Validated request body describing the dataset to build.

    Returns:
        A ``FileResponse`` that streams ``dataset.zip`` as an attachment.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            any step fails.
    """
    # Imported locally so this handler is self-contained; fastapi is already
    # a dependency of this module.
    from fastapi import BackgroundTasks

    workdir = None
    try:
        # mkdtemp rather than TemporaryDirectory: FileResponse only reads the
        # zip after this handler has returned, so the directory must outlive
        # the function. A context manager would delete the archive before it
        # could be sent.
        workdir = tempfile.mkdtemp()

        # Step 1: Generate base images with Blender
        base_dir = os.path.join(workdir, "base")
        os.makedirs(base_dir)
        subprocess.run([
            "blender", "--background", "--python", "blender_script.py", "--",
            ",".join(request.objects), request.environment, str(request.num_images), base_dir
        ], check=True)

        # Load base annotations
        with open(os.path.join(base_dir, "annotations.json"), "r", encoding="utf-8") as f:
            base_annotations = json.load(f)

        # Step 2: Augment images
        output_dir = os.path.join(workdir, "output/images")
        os.makedirs(output_dir)
        annotations = []
        image_id = 0
        for base_anno in base_annotations:
            base_image_path = os.path.join(base_dir, base_anno["file_name"])
            for prompt in request.augmentation_prompts:
                augmented = augment_image(base_image_path, prompt)
                new_filename = f"image_{image_id}.png"
                augmented.save(os.path.join(output_dir, new_filename))
                annotations.append({
                    "image_id": image_id,
                    "file_name": new_filename,
                    "labels": base_anno["labels"]
                })
                image_id += 1

        # Save annotations
        anno_file = os.path.join(workdir, "output/annotations.json")
        with open(anno_file, "w", encoding="utf-8") as f:
            json.dump(annotations, f)

        # Step 3: Create zip file
        zip_path = os.path.join(workdir, "dataset.zip")
        with zipfile.ZipFile(zip_path, "w") as zipf:
            for root, _, files in os.walk(output_dir):
                for file in files:
                    zipf.write(os.path.join(root, file), os.path.join("images", file))
            zipf.write(anno_file, "annotations.json")

        # Remove the working directory only after the response has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(shutil.rmtree, workdir, ignore_errors=True)
        return FileResponse(
            zip_path,
            media_type="application/zip",
            filename="dataset.zip",
            background=cleanup,
        )
    except Exception as e:
        # Nothing will be streamed on failure, so clean up right away.
        if workdir is not None:
            shutil.rmtree(workdir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health")
async def health_check():
    """Return a fixed status payload showing that the service is responding.

    The handler performs no checks of its own (it does not touch the model or
    Blender); it only returns the constant body below.
    """
    payload = {"status": "healthy"}
    return payload


if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on all interfaces,
    # port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)