File size: 3,068 Bytes
023c8c3
 
 
bceaa96
023c8c3
 
 
bceaa96
023c8c3
bceaa96
023c8c3
bceaa96
023c8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bceaa96
 
 
023c8c3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import tempfile
import torch
import numpy as np
import gradio as gr
from PIL import Image
import cv2
from diffusers import DiffusionPipeline
from script import SatelliteModelGenerator

# Choose the compute device once, when the module is imported.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Precision the pipeline weights are loaded in; applied on CPU as well as CUDA.
dtype = torch.bfloat16

# Load the FLUX satellite-imagery pipeline a single time at import so every
# request reuses the same weights, then move it onto the selected device.
flux_pipe = DiffusionPipeline.from_pretrained("jbilcke-hf/flux-satellite", torch_dtype=dtype)
flux_pipe = flux_pipe.to(device)

def generate_and_process_map(prompt: str) -> str | None:
    """Generate satellite image from prompt and convert to 3D model."""
    try:
        # Set dimensions
        width = height = 1024
        
        # Generate random seed
        seed = np.random.randint(0, np.iinfo(np.int32).max)
        
        # Set random seeds
        torch.manual_seed(seed)
        np.random.seed(seed)
        
        # Generate satellite image using FLUX
        generator = torch.Generator(device=device).manual_seed(seed)
        generated_image = flux_pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=30,
            generator=generator,
            guidance_scale=7.5
        ).images[0]
        
        # Convert PIL Image to OpenCV format
        cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)
        
        # Initialize SatelliteModelGenerator
        generator = SatelliteModelGenerator(building_height=0.09)
        
        # Process image
        print("Segmenting image...")
        segmented_img = generator.segment_image(cv_image, window_size=5)
        
        print("Estimating heights...")
        height_map = generator.estimate_heights(cv_image, segmented_img)
        
        # Generate mesh
        print("Generating mesh...")
        mesh = generator.generate_mesh(height_map, cv_image, add_walls=True)
        
        # Export to GLB
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')
        mesh.export(output_path)
        
        return output_path
        
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# Create Gradio interface.
# Layout follows creation order inside the Blocks/Row contexts:
# title + description, prompt row, button row, 3D viewer row.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Map")
    gr.Markdown("Generate 3D maps from text descriptions using FLUX and mesh generation.")
    
    # Free-text prompt; its value is the single input to generate_and_process_map.
    with gr.Row():
        prompt_input = gr.Text(
            label="Enter your prompt",
            placeholder="eg. satellite view of downtown Manhattan"
        )
    
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")
    
    # Displays the .glb path returned by generate_and_process_map.
    # clear_color is an RGBA background with alpha 0 — presumably a transparent
    # scene background; confirm against the Gradio Model3D docs.
    with gr.Row():
        model_output = gr.Model3D(
            label="Generated 3D Map",
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )
    
    # Event handler: a button click runs the text -> image -> GLB pipeline.
    # api_name also exposes this event under the "generate" API endpoint name.
    generate_btn.click(
        fn=generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output],
        api_name="generate"
    )

if __name__ == "__main__":
    # Turn on Gradio's request queue for the app, then start serving it.
    demo.queue()
    demo.launch()