Spaces:
Paused
Paused
import os | |
import tempfile | |
import torch | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import cv2 | |
from diffusers import DiffusionPipeline | |
import cupy as cp | |
from cupyx.scipy.ndimage import label as cp_label | |
from cupyx.scipy.ndimage import binary_dilation | |
from sklearn.cluster import DBSCAN | |
import trimesh | |
class GPUSatelliteModelGenerator:
    """Hold the reference palettes and tuning constants for satellite-image processing.

    The constructor uploads RGB reference samples for shadows, roads and
    water to the GPU (cupy arrays), derives normalized HSV versions of them,
    and stores the tolerances, output colors and thresholds read by the
    processing methods (elided from this listing, see the placeholder at the
    end of the class).
    """

    def __init__(self, building_height: float = 0.05):
        """Set up reference colors, tolerances and output palette.

        Args:
            building_height: Stored as ``self.building_height`` for use by
                the processing methods.
        """
        self.building_height = building_height

        # Reference RGB samples, kept on the GPU as cupy arrays.
        self.shadow_colors = cp.array([
            [31, 42, 76],
            [58, 64, 92],
            [15, 27, 56],
            [21, 22, 50],
            [76, 81, 99],
        ])
        self.road_colors = cp.array([
            [187, 182, 175],
            [138, 138, 138],
            [142, 142, 129],
            [202, 199, 189],
        ])
        self.water_colors = cp.array([
            [167, 225, 217],
            [67, 101, 97],
            [53, 83, 84],
            [47, 94, 100],
            [73, 131, 135],
        ])

        # Reference colors in normalized HSV (hue in degrees, sat/val in
        # [0, 1]) so they are directly comparable with the tolerances below.
        self.shadow_colors_hsv = self._reference_hsv(self.shadow_colors)
        self.road_colors_hsv = self._reference_hsv(self.road_colors)
        self.water_colors_hsv = self._reference_hsv(self.water_colors)

        # Color tolerances
        self.shadow_tolerance = {'hue': 15, 'sat': 0.15, 'val': 0.12}
        self.road_tolerance = {'hue': 10, 'sat': 0.12, 'val': 0.15}
        self.water_tolerance = {'hue': 20, 'sat': 0.15, 'val': 0.20}

        # Output colors (BGR for OpenCV)
        self.colors = {
            'black': cp.array([0, 0, 0]),        # Shadows
            'blue': cp.array([255, 0, 0]),       # Water
            'green': cp.array([0, 255, 0]),      # Vegetation
            'gray': cp.array([128, 128, 128]),   # Roads
            'brown': cp.array([0, 140, 255]),    # Terrain
            'white': cp.array([255, 255, 255]),  # Buildings
        }

        # Tuning constants read by the processing methods.
        self.min_area_for_clustering = 1000
        self.residential_height_factor = 0.6
        self.isolation_threshold = 0.6

    @staticmethod
    def _normalize_hsv(colors_hsv):
        """Return a float32 copy of 8-bit OpenCV HSV rows in normalized units.

        Column 0 (hue, stored by OpenCV as degrees / 2) is doubled; columns
        1 and 2 (saturation, value) are divided by 255. The cast to float
        happens *before* the arithmetic: doing it in place on the uint8
        array wraps hues above 127 and truncates every saturation/value
        to 0 or 1.

        Works on any array exposing ``astype`` (numpy or cupy); the input is
        left unmodified.
        """
        normalized = colors_hsv.astype(np.float32)
        normalized[:, 0] *= 2
        normalized[:, 1:] /= 255
        return normalized

    @staticmethod
    def _reference_hsv(rgb_colors):
        """Convert an ``(N, 3)`` cupy array of RGB samples to normalized HSV on the GPU."""
        # cv2 works on host memory, so round-trip through numpy for the
        # color-space conversion and upload the result again.
        hsv_u8 = cv2.cvtColor(
            rgb_colors.get().reshape(-1, 1, 3).astype(np.uint8),
            cv2.COLOR_RGB2HSV,
        ).reshape(-1, 3)
        return GPUSatelliteModelGenerator._normalize_hsv(cp.asarray(hsv_u8))

    # ... [Previous methods remain unchanged] ...
def generate_and_process_map(prompt: str) -> tuple[str | None, str | None]:
    """Generate a satellite image from *prompt* and convert it into a 3D model.

    The image is produced by the module-level ``flux_pipe`` on ``device``
    (both are assigned in the ``__main__`` block), then segmented, turned
    into a height map and meshed by ``GPUSatelliteModelGenerator``. The mesh
    is exported as ``output.glb`` and the segmentation as ``segmented.png``
    in a fresh temporary directory. The directory is not removed here
    because the returned paths must stay readable by the caller.

    Args:
        prompt: Free-text scene description; it is appended to a fixed
            "satellite view" prefix.

    Returns:
        ``(glb_path, segmented_png_path)`` on success, or ``(None, None)``
        if any step raised. The error and traceback are printed.
    """
    try:
        # Set dimensions
        width = height = 1024

        # Draw one random seed and use it for every RNG involved.
        seed = np.random.randint(0, np.iinfo(np.int32).max)
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Generate satellite image using FLUX
        torch_generator = torch.Generator(device=device).manual_seed(seed)
        generated_image = flux_pipe(
            prompt=f"satellite view in the style of TOK, {prompt}",
            width=width,
            height=height,
            num_inference_steps=25,
            generator=torch_generator,
            guidance_scale=7.5,
        ).images[0]

        # Convert PIL Image to OpenCV format
        cv_image = cv2.cvtColor(np.array(generated_image), cv2.COLOR_RGB2BGR)

        # Kept under a separate name from the torch RNG above so the two
        # unrelated objects are never confused.
        model_generator = GPUSatelliteModelGenerator(building_height=0.09)

        print("Segmenting image using GPU...")
        segmented_img = model_generator.segment_image_gpu(cv_image)

        print("Estimating heights using GPU...")
        height_map = model_generator.estimate_heights_gpu(cv_image, segmented_img)

        print("Generating mesh using GPU...")
        mesh = model_generator.generate_mesh_gpu(height_map, cv_image)

        # Export to GLB
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')
        mesh.export(output_path)

        # Save segmented image to a temporary file. cv2.imwrite reports
        # failure through its return value rather than raising, so check it
        # instead of handing back a path to a file that does not exist.
        segmented_path = os.path.join(temp_dir, 'segmented.png')
        if not cv2.imwrite(segmented_path, segmented_img.get()):
            raise OSError(f"Failed to write segmented image to {segmented_path}")

        return output_path, segmented_path
    except Exception as e:
        # Top-level boundary for the UI callback: report and return empty
        # outputs instead of propagating.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
# Build the Gradio UI: a prompt box and button on top, the 3D model and the
# segmentation preview side by side below.
with gr.Blocks() as demo:
    # Page header
    gr.Markdown("# GPU-Accelerated Text to Map")
    gr.Markdown("Generate 3D maps and segmentation maps from text descriptions using FLUX and GPU-accelerated processing.")

    # Prompt entry
    with gr.Row():
        prompt_input = gr.Text(label="Enter your prompt", placeholder="classic american town")

    # Trigger
    with gr.Row():
        generate_btn = gr.Button("Generate", variant="primary")

    # Results: one column per output of generate_and_process_map
    with gr.Row():
        with gr.Column():
            model_output = gr.Model3D(label="Generated 3D Map", clear_color=[0.0, 0.0, 0.0, 0.0])
        with gr.Column():
            # The handler returns a path to a PNG on disk, hence "filepath".
            segmented_output = gr.Image(label="Segmented Map", type="filepath")

    # Wire the button to the generation function.
    generate_btn.click(
        generate_and_process_map,
        inputs=[prompt_input],
        outputs=[model_output, segmented_output],
        api_name="generate",
    )
if __name__ == "__main__":
    # Initialize the FLUX pipeline. `device` and `flux_pipe` are module
    # globals read by generate_and_process_map, so they must be assigned
    # before the app starts serving requests.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    repo_id = "black-forest-labs/FLUX.1-dev"
    adapter_id = "jbilcke-hf/flux-satellite"

    # Use the `dtype` chosen above rather than repeating the literal, so the
    # precision is configured in exactly one place.
    flux_pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
    flux_pipe.load_lora_weights(adapter_id)
    flux_pipe = flux_pipe.to(device)

    # Launch Gradio app
    demo.queue().launch()