# Spaces: Sleeping
# (page-status text captured together with the file; not part of the program)
from fastapi import FastAPI, File, UploadFile | |
import uvicorn | |
from typing import List | |
from io import BytesIO | |
import numpy as np | |
import rasterio | |
from pydantic import BaseModel | |
import torch | |
from huggingface_hub import hf_hub_download | |
from mmcv import Config | |
from mmseg.apis import init_segmentor | |
import gradio as gr | |
from functools import partial | |
import time | |
import os | |
# Initialize the FastAPI app
app = FastAPI()

# Fetch the fine-tuned model's config and checkpoint from the Hugging Face Hub.
# Both files come from the same repository and are requested with the same
# token, read from the "token" environment variable (None when it is unset).
_hub_kwargs = dict(
    repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
    token=os.environ.get("token"),
)
config_path = hf_hub_download(
    filename="multi_temporal_crop_classification_Prithvi_100M.py",
    **_hub_kwargs,
)
ckpt = hf_hub_download(
    filename='multi_temporal_crop_classification_Prithvi_100M.pth',
    **_hub_kwargs,
)

# The backbone's `pretrained` entry is cleared before the segmentor is built
# from the downloaded checkpoint; the model is placed on the CPU.
config = Config.fromfile(config_path)
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device='cpu')

# Use the test pipeline directly
custom_test_pipeline = model.cfg.data.test.pipeline
# Define the input/output model for FastAPI
class PredictionOutput(BaseModel):
    # Field names mirror the keys of the dictionary returned by `predict`.
    # One entry per input time step (first, second, third).
    t1: List[float]
    t2: List[float]
    t3: List[float]
    # The model's prediction.
    prediction: List[float]
# Define the inference function
def inference_on_file(file_path, model, custom_test_pipeline):
    """Run the model on a single raster file.

    Args:
        file_path: Path of the raster, opened with ``rasterio``.
        model: Object exposing an ``inference`` method.
        custom_test_pipeline: Pipeline handed to ``apply_pipeline``.

    Returns:
        A 4-tuple: entries 0, 1 and 2 of the model output, each passed
        through ``postprocess_output``, followed by the raw model output.
    """
    # Read the whole raster; the dataset is closed once the data is in memory.
    with rasterio.open(file_path) as dataset:
        pixels = dataset.read()

    # Apply preprocessing using the custom pipeline
    model_input = apply_pipeline(custom_test_pipeline, pixels)

    # NOTE(review): the method is called with the array only; confirm this
    # matches the signature of the model's ``inference`` method.
    output = model.inference(model_input)

    # Post-process the first three entries of the output, in order.
    first, second, third = (postprocess_output(output[idx]) for idx in range(3))
    return first, second, third, output
def apply_pipeline(pipeline, img):
    """Preprocessing hook applied to the image before inference.

    Currently a pass-through placeholder: ``pipeline`` is not consulted and
    ``img`` is handed back untouched. Custom steps such as normalization or
    resizing belong here once implemented.
    """
    return img
def postprocess_output(output):
    """Post-processing hook applied to one entry of the model output.

    Currently a pass-through placeholder that returns ``output`` unchanged.
    Conversion to an RGB image (or any other format) belongs here once
    implemented.
    """
    return output
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run inference on an uploaded raster file.

    Registered on ``app`` as ``POST /predict``. The upload is written to a
    temporary ``.tif`` file, because ``inference_on_file`` takes a path, and
    that file is removed again whether or not inference succeeds.

    Args:
        file: The uploaded file (required form field).

    Returns:
        A dict with keys ``"t1"``, ``"t2"``, ``"t3"`` and ``"prediction"``,
        each holding the ``.tolist()`` form of the corresponding value
        returned by ``inference_on_file``.
    """
    import tempfile  # local import: only this handler needs it

    # Read the uploaded file
    payload = await file.read()

    # A unique temporary file per request keeps concurrent uploads from
    # overwriting each other. delete=False lets the file be reopened by path
    # after this handle is closed; it is removed explicitly below.
    with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name

    try:
        # Run the prediction
        rgb1, rgb2, rgb3, output = inference_on_file(
            tmp_path, model, custom_test_pipeline
        )
    finally:
        os.remove(tmp_path)

    # Return the results
    return {
        "t1": rgb1.tolist(),
        "t2": rgb2.tolist(),
        "t3": rgb3.tolist(),
        "prediction": output.tolist(),
    }
# Optional: Serve the Gradio interface (if you still want to use it with FastAPI)
def run_gradio_interface():
    """Build the Gradio demo and launch it.

    The page has a file input with a submit button, four image outputs
    (labelled T1, T2, T3 and "Model prediction"), a set of example files and
    a legend image. Submitting a file, or selecting an example, calls
    ``inference_on_file`` with the module-level ``model`` and
    ``custom_test_pipeline`` bound in advance.

    The function ends with ``demo.launch()``.
    """
    func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)

    with gr.Blocks() as demo:
        gr.Markdown(value='# Prithvi multi temporal crop classification')
        gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
The user needs to provide an HLS geotiff image, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order.''')

        # Input row: file picker and submit button.
        with gr.Row():
            with gr.Column():
                inp = gr.File()
                btn = gr.Button("Submit")

        # Output row: one image per time step plus the prediction.
        with gr.Row():
            inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
            inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
            inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
            out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')

        btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])

        with gr.Row():
            with gr.Column():
                # `preprocess` is a boolean flag. The previous value was a
                # name that is not defined in this file, which raised
                # NameError as soon as this function ran.
                gr.Examples(examples=["chip_102_345_merged.tif",
                                      "chip_104_104_merged.tif",
                                      "chip_109_421_merged.tif"],
                            inputs=inp,
                            outputs=[inp1, inp2, inp3, out],
                            preprocess=True,
                            fn=func,
                            cache_examples=True)
            with gr.Column():
                gr.Markdown(value='### Model prediction legend')
                gr.Image(value='Legend.png', image_mode='RGB', show_label=False)

    demo.launch()
if __name__ == "__main__":
    import threading

    # run_gradio_interface() ends in demo.launch(), which blocks while the
    # Gradio server is running, so a uvicorn.run() placed after it would not
    # be reached. Start the API server first, on its own thread; as a daemon
    # thread it goes away together with the main (Gradio) thread.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()

    run_gradio_interface()