MESA / app.py
mikonvergence's picture
Create app.py
21202e6 verified
raw
history blame
4.84 kB
# Environment setup. The leading "!" is IPython/notebook shell syntax, not
# plain Python, so these lines only work when run inside a notebook cell.

# Install the Hugging Face Hub client (with the hf_transfer extra and the CLI)
# plus the libraries used below for the UI and for mesh construction.
!pip install "huggingface_hub[hf_transfer]"
!pip install -U "huggingface_hub[cli]"
!pip install gradio trimesh scipy
# NOTE(review): each "!" command presumably runs in its own subshell, so this
# assignment likely does not persist to the download command below — confirm.
!HF_HUB_ENABLE_HF_TRANSFER=1
# Fetch the model code; it is imported below as the `MESA` package.
!git clone https://github.com/PaulBorneP/MESA.git
# NOTE(review): for the same subshell reason this `cd` presumably does not
# change the working directory for later lines; the code below loads from
# "./weights" relative to the starting directory — confirm.
!cd MESA
# Download the pretrained weights into ./weights.
!mkdir weights
!huggingface-cli download NewtNewt/MESA --local-dir weights
import torch
from MESA.pipeline_terrain import TerrainDiffusionPipeline
import sys
import gradio as gr
import numpy as np
import trimesh
import tempfile
import torch
from scipy.spatial import Delaunay
# Put the cloned repository directory on the module search path.
sys.path.append('MESA/')

# Load the pretrained pipeline from the local weights directory in half
# precision, then move it to the GPU. A CUDA device is required.
pipe = TerrainDiffusionPipeline.from_pretrained(
    "./weights",
    torch_dtype=torch.float16,
)
pipe.to("cuda")
def generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate terrain data (RGB and elevation) from a text prompt.

    Args:
        prompt: Terrain description; ``None`` is treated as an empty string.
        num_inference_steps: Number of inference steps forwarded to the
            pipeline; coerced to ``int``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the CUDA random generator; coerced to ``int``.
        crop_size: Side length of the square centre crop; coerced to ``int``
            and clamped to ``[1, min(height, width)]`` of the generated image.
        prefix: Text prepended to ``prompt``. A trailing space is added when
            missing; ``None`` or an empty string means no prefix.

    Returns:
        A ``(rgb, elevation)`` tuple. ``rgb`` is the centre crop of the first
        generated image multiplied by 255 and cast to ``uint8``.
        ``elevation`` is the centre crop of the first generated DEM, averaged
        over its last axis and multiplied by 500.
    """
    prompt = prompt or ""
    prefix = prefix or ""
    if prefix and not prefix.endswith(' '):
        prefix += ' '  # Ensure prefix ends with a space
    full_prompt = prefix + prompt

    # UI widgets may hand over floats (e.g. 6378.0); the generator seed, the
    # step count and slice bounds all need real integers.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    image, dem = pipe(
        full_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
    )

    rgb_full = image[0]
    dem_full = dem[0]

    # Center crop the image and dem. Clamp the crop so that an oversized
    # request cannot produce negative offsets, which would silently select
    # the wrong region (or an empty one).
    h, w = rgb_full.shape[:2]
    crop_size = max(1, min(int(crop_size), h, w))
    start_h = (h - crop_size) // 2
    start_w = (w - crop_size) // 2
    end_h = start_h + crop_size
    end_w = start_w + crop_size

    cropped_image = rgb_full[start_h:end_h, start_w:end_w, :]
    cropped_dem = dem_full[start_h:end_h, start_w:end_w, :]

    # NOTE(review): the factor 500 is presumably a vertical scale chosen for
    # display purposes — confirm against the model's elevation units.
    return (255 * cropped_image).astype(np.uint8), 500 * cropped_dem.mean(-1)
def create_3d_mesh(rgb, elevation):
    """Create a 3D mesh from RGB and elevation data.

    Every pixel becomes one vertex at ``(column, row, elevation)`` and every
    2x2 pixel cell is split into two triangles.

    The faces are built directly from grid indices rather than with a
    Delaunay triangulation: on a perfectly regular lattice every cell is
    co-circular, so Delaunay is degenerate (the diagonal choice is arbitrary)
    and needlessly slow for the hundreds of thousands of points involved.

    Args:
        rgb: Per-pixel colours, reshaped to ``(-1, 3)``; must contain one
            colour per elevation sample.
        elevation: 2D array of heights with shape ``(rows, cols)``.

    Returns:
        A ``trimesh.Trimesh`` with per-vertex colours.
    """
    rows, cols = elevation.shape
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    vertices = np.stack([x.ravel(), y.ravel(), elevation.ravel()], axis=-1)

    # Vertex index of pixel (row, col) is row * cols + col, matching the
    # row-major flattening used for the vertices above.
    index = np.arange(rows * cols).reshape(rows, cols)
    top_left = index[:-1, :-1].ravel()
    top_right = index[:-1, 1:].ravel()
    bottom_left = index[1:, :-1].ravel()
    bottom_right = index[1:, 1:].ravel()

    # Two triangles per cell, both wound counter-clockwise in the x/y plane
    # so that all face normals point towards +z.
    faces = np.concatenate(
        [
            np.stack([top_left, top_right, bottom_right], axis=-1),
            np.stack([top_left, bottom_right, bottom_left], axis=-1),
        ],
        axis=0,
    )

    return trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=rgb.reshape(-1, 3),
    )
def generate_and_display(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix):
    """Generate terrain and return the path of an OBJ file containing its mesh.

    Args:
        prompt: Terrain description.
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        seed: Random seed.
        crop_size: Side length of the centre crop.
        prefix: Text prepended to the prompt.

    Returns:
        Path to a temporary ``.obj`` file. The file is created with
        ``delete=False`` and is therefore not removed automatically.
    """
    rgb, elevation = generate_terrain(prompt, num_inference_steps, guidance_scale, seed, crop_size, prefix)
    mesh = create_3d_mesh(rgb, elevation)

    # Only reserve a unique path here and close the handle before exporting:
    # re-opening a NamedTemporaryFile by name while it is still open is not
    # portable (it fails on Windows).
    with tempfile.NamedTemporaryFile(suffix=".obj", delete=False) as temp_file:
        file_path = temp_file.name
    mesh.export(file_path)
    return file_path
# Gradio UI: header, abstract, prompt row, 3D viewer and advanced options.
theme = gr.themes.Soft(primary_hue="red", secondary_hue="red", font=['arial'])

with gr.Blocks(theme=theme) as demo:
    # Header section
    with gr.Column(elem_classes="header"):
        gr.Markdown("# MESA: Text-Driven Terrain Generation Using Latent Diffusion and Global Copernicus Data")
        gr.Markdown("### Paul Borne–Pons, Mikolaj Czerkawski, Rosalie Martin, Romain Rouffet")
        gr.Markdown('[[GitHub](https://github.com/PaulBorneP/MESA)] [[Model](https://huggingface.co/NewtNewt/MESA)] [[Dataset](https://huggingface.co/datasets/Major-TOM/Core-DEM)]')

    # Abstract Section
    with gr.Column(elem_classes="abstract"):
        gr.Markdown("MESA is a novel generative model based on latent denoising diffusion capable of generating 2.5D representations of terrain based on the text prompt conditioning supplied via natural language. The model produces two co-registered modalities of optical and depth maps.")
        gr.Markdown("This is a test version of the demo app. Please be aware that MESA supports primarily complex, mountainous terrains as opposed to flat land")
        gr.Markdown("The generated image is quite large, so for the full resolution (768) it might take a while to load the surface")

    # Prompt entry and trigger button
    with gr.Row():
        prompt_input = gr.Textbox(lines=2, placeholder="Enter a terrain description...")
        generate_button = gr.Button("Generate Terrain", variant="primary")

    # Viewer for the exported mesh file
    model_output = gr.Model3D(
        camera_position=[90, 180, 512]
    )

    with gr.Accordion("Advanced Options", open=False) as advanced_options:
        num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
        guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
        # precision=0 makes the component deliver an int, which is what the
        # random generator's seed requires.
        seed_number = gr.Number(value=6378, label="Seed", precision=0)
        crop_size_slider = gr.Slider(minimum=128, maximum=768, step=64, value=512, label="Crop Size")
        prefix_textbox = gr.Textbox(label="Prompt Prefix", value="A Sentinel-2 image of ")

    # The input order must match the parameter order of generate_and_display.
    generate_button.click(
        fn=generate_and_display,
        inputs=[prompt_input, num_inference_steps_slider, guidance_scale_slider, seed_number, crop_size_slider, prefix_textbox],
        outputs=model_output,
    )
if __name__ == "__main__":
    # Start the app only when this file is executed directly, with a
    # shareable link and debug mode both enabled.
    demo.launch(share=True, debug=True)