# Source: Hugging Face Space app file (Space status: Paused; file size: 3,068 bytes).
# Revisions shown in the page's blame view: 023c8c3, bceaa96.
import os
import tempfile
import torch
import numpy as np
import gradio as gr
from PIL import Image
import cv2
from diffusers import DiffusionPipeline
from script import SatelliteModelGenerator
# Choose the compute device once, at import time: CUDA if a GPU is visible,
# otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Pipeline weights are loaded in bfloat16 on either device.
dtype = torch.bfloat16

# Load the FLUX satellite-imagery pipeline by repo id (diffusers resolves this
# against the Hugging Face Hub unless a local directory of that name exists),
# then move it onto the selected device. Runs once per process, on import.
_FLUX_REPO_ID = "jbilcke-hf/flux-satellite"
flux_pipe = DiffusionPipeline.from_pretrained(_FLUX_REPO_ID, torch_dtype=dtype)
flux_pipe = flux_pipe.to(device)
def _generate_satellite_image(prompt: str, seed: int, width: int, height: int) -> "Image.Image":
    """Render one satellite-style image for *prompt* with the FLUX pipeline.

    A dedicated ``torch.Generator`` seeded with *seed* is handed to the
    pipeline, so the sampling noise is determined by *seed*.
    """
    rng = torch.Generator(device=device).manual_seed(seed)
    result = flux_pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=30,
        generator=rng,
        guidance_scale=7.5,
    )
    return result.images[0]


def _image_to_glb(image: "Image.Image") -> str:
    """Segment *image*, estimate heights, build a mesh, export it as GLB, and return the file path."""
    # np.array(image) is treated as RGB; the model generator is fed the
    # OpenCV-style BGR channel order.
    cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    model_generator = SatelliteModelGenerator(building_height=0.09)

    print("Segmenting image...")
    segmented_img = model_generator.segment_image(cv_image, window_size=5)
    print("Estimating heights...")
    height_map = model_generator.estimate_heights(cv_image, segmented_img)

    print("Generating mesh...")
    mesh = model_generator.generate_mesh(height_map, cv_image, add_walls=True)

    # A fresh directory per request keeps concurrent outputs from clobbering
    # each other's "output.glb". It is deliberately not deleted here: the
    # caller needs the file to still exist after we return.
    # NOTE(review): mkdtemp() leaves cleanup to the caller and nothing in this
    # file removes these directories; consider periodic cleanup.
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "output.glb")
    mesh.export(output_path)
    return output_path


def generate_and_process_map(prompt: str) -> str | None:
    """Generate a satellite image from *prompt* and convert it into a 3D model.

    Args:
        prompt: Text description passed to the FLUX pipeline.

    Returns:
        Path to the exported ``.glb`` file, or ``None`` if any step failed.
        Failures are printed with a traceback rather than raised, since this is
        the top-level Gradio handler.
    """
    try:
        width = height = 1024

        # Draw a fresh seed per request and print it so a given result can be
        # reproduced later.
        seed = np.random.randint(0, np.iinfo(np.int32).max)
        # Also seed the global RNGs, in case downstream processing draws from
        # them (SatelliteModelGenerator's internals are not visible here).
        torch.manual_seed(seed)
        np.random.seed(seed)
        print(f"Generating image with seed {seed}...")

        image = _generate_satellite_image(prompt, seed, width, height)
        return _image_to_glb(image)
    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        return None
# Assemble the web UI: a title and description, a prompt box, a "Generate"
# button, and a 3D viewer showing the GLB path returned by
# generate_and_process_map.
with gr.Blocks() as demo:
    for markdown_text in (
        "# Text to Map",
        "Generate 3D maps from text descriptions using FLUX and mesh generation.",
    ):
        gr.Markdown(markdown_text)

    with gr.Row():
        prompt_input = gr.Text(
            label="Enter your prompt",
            placeholder="eg. satellite view of downtown Manhattan",
        )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row():
        # clear_color is the viewer background as RGBA floats in [0, 1]
        # (per the Gradio Model3D docs); alpha 0 makes it transparent.
        transparent_rgba = [0.0, 0.0, 0.0, 0.0]
        model_output = gr.Model3D(label="Generated 3D Map", clear_color=transparent_rgba)

    # Clicking the button runs the text -> image -> mesh pipeline; the same
    # handler is exposed to API clients under the endpoint name "generate".
    generate_btn.click(
        fn=generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output],
        api_name="generate",
    )
if __name__ == "__main__":
    # Route events through Gradio's request queue before serving; generation
    # is long-running, so requests wait their turn instead of piling up.
    # (The stray trailing "|" from the scraped source was a SyntaxError.)
    demo.queue().launch()