text-to-map / app.py
jbilcke-hf's picture
jbilcke-hf HF Staff
Update app.py
023c8c3 verified
raw
history blame
3.07 kB
import os
import tempfile
import torch
import numpy as np
import gradio as gr
from PIL import Image
import cv2
from diffusers import DiffusionPipeline
from script import SatelliteModelGenerator
# Pick the compute device once, at import time: CUDA when a GPU is visible,
# otherwise the CPU. Model weights are loaded as bfloat16 (2 bytes/param).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
dtype = torch.bfloat16

# Load the FLUX checkpoint used for satellite imagery (by repo id) and move
# the whole pipeline onto the selected device.
flux_pipe = DiffusionPipeline.from_pretrained(
    "jbilcke-hf/flux-satellite",
    torch_dtype=dtype,
)
flux_pipe = flux_pipe.to(device)
def generate_and_process_map(prompt: str) -> str | None:
"""Generate satellite image from prompt and convert to 3D model."""
try:
# Set dimensions
width = height = 1024
# Generate random seed
seed = np.random.randint(0, np.iinfo(np.int32).max)
# Set random seeds
torch.manual_seed(seed)
np.random.seed(seed)
# Generate satellite image using FLUX
generator = torch.Generator(device=device).manual_seed(seed)
generated_image = flux_pipe(
prompt=prompt,
width=width,
height=height,
num_inference_steps=30,
generator=generator,
guidance_scale=7.5
).images[0]
# Convert PIL Image to OpenCV format
cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)
# Initialize SatelliteModelGenerator
generator = SatelliteModelGenerator(building_height=0.09)
# Process image
print("Segmenting image...")
segmented_img = generator.segment_image(cv_image, window_size=5)
print("Estimating heights...")
height_map = generator.estimate_heights(cv_image, segmented_img)
# Generate mesh
print("Generating mesh...")
mesh = generator.generate_mesh(height_map, cv_image, add_walls=True)
# Export to GLB
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, 'output.glb')
mesh.export(output_path)
return output_path
except Exception as e:
print(f"Error during generation: {str(e)}")
import traceback
traceback.print_exc()
return None
# Assemble the web UI as three stacked rows: prompt box, Generate button,
# and a 3D viewer that displays the GLB file produced by the handler.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Map")
    gr.Markdown("Generate 3D maps from text descriptions using FLUX and mesh generation.")

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="eg. satellite view of downtown Manhattan",
        )

    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")

    with gr.Row():
        # RGBA background colour of the 3D scene; all zeros = transparent.
        model_output = gr.Model3D(
            label="Generated 3D Map",
            clear_color=[0.0, 0.0, 0.0, 0.0],
        )

    # Wire the button to the text -> image -> mesh handler; the same
    # endpoint is published in the API under the name "generate".
    generate_btn.click(
        generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output],
        api_name="generate",
    )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()