Spaces:
Sleeping
Sleeping
File size: 5,105 Bytes
0d91e34 634d7ce 0d91e34 634d7ce 0d91e34 634d7ce 0d91e34 2b313f8 634d7ce 0d91e34 634d7ce 0d91e34 634d7ce 0d91e34 634d7ce 06355df 0d91e34 03cfd4a 0d91e34 734393f a413a55 0d91e34 634d7ce 0d91e34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import tempfile
import threading
import time
from functools import partial
from io import BytesIO
from typing import List

import gradio as gr
import numpy as np
import rasterio
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from huggingface_hub import hf_hub_download
from mmcv import Config
from mmseg.apis import init_segmentor
from pydantic import BaseModel
# Initialize the FastAPI app
app = FastAPI()

# Hugging Face Hub repository that hosts both the mmseg config and the checkpoint.
_HF_REPO_ID = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"


def _download_from_hub(filename):
    """Download ``filename`` from the model repository and return its local path.

    The Hub access token, if any, is read from the ``token`` environment variable.
    """
    return hf_hub_download(repo_id=_HF_REPO_ID,
                           filename=filename,
                           token=os.environ.get("token"))


# Load the model and config
config_path = _download_from_hub("multi_temporal_crop_classification_Prithvi_100M.py")
ckpt = _download_from_hub('multi_temporal_crop_classification_Prithvi_100M.pth')

config = Config.fromfile(config_path)
# Skip the backbone's separate pretrained-weight loading; presumably the fine-tuned
# checkpoint passed to init_segmentor supplies all weights — confirm if changing ckpt.
config.model.backbone.pretrained = None
model = init_segmentor(config, ckpt, device='cpu')

# Use the test pipeline directly
custom_test_pipeline = model.cfg.data.test.pipeline
# Define the input/output model for FastAPI
class PredictionOutput(BaseModel):
    """Response schema of the ``/predict/`` endpoint.

    Each field holds the ``.tolist()`` conversion of an array returned by
    ``inference_on_file``. Those arrays are also rendered as images in the
    Gradio UI (T1, T2, T3, "Model prediction"), so they are at least 2-D and
    their list form is nested. A flat ``List[float]`` would make pydantic
    reject every response, so the fields accept arbitrarily nested lists.
    """

    t1: list  # first time-step image, as nested lists
    t2: list  # second time-step image, as nested lists
    t3: list  # third time-step image, as nested lists
    prediction: list  # raw model output, as nested lists
# Define the inference function
def inference_on_file(file_path, model, custom_test_pipeline):
    """Run the segmentor on a raster file and return display arrays plus raw output.

    Args:
        file_path: Path (``str``/``os.PathLike``) or file-like object readable by
            ``rasterio.open``. Gradio's ``gr.File`` may pass a temp-file wrapper
            instead of a path; in that case its ``.name`` (the on-disk path) is used.
        model: Segmentor created by ``init_segmentor``.
        custom_test_pipeline: Preprocessing pipeline forwarded to ``apply_pipeline``.

    Returns:
        Tuple ``(rgb1, rgb2, rgb3, output)``. The first three are
        ``postprocess_output`` applied to ``output[0]``, ``output[1]`` and
        ``output[2]``. ``output`` is the raw model output.
    """
    # Unwrap Gradio-style file wrappers, but leave real paths alone
    # (a pathlib.Path also has ``.name``, which is only its basename).
    if not isinstance(file_path, (str, bytes, os.PathLike)) and hasattr(file_path, "name"):
        file_path = file_path.name

    with rasterio.open(file_path) as src:
        img = src.read()

    # Apply preprocessing using the custom pipeline
    processed_img = apply_pipeline(custom_test_pipeline, img)

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        # NOTE(review): mmseg 0.x's EncoderDecoder.inference expects
        # (img, img_meta, rescale) and a tensor input. This one-argument call on
        # the raw array likely fails; mmseg's inference_segmentor may be the
        # intended entry point — TODO confirm against the installed mmseg.
        output = model.inference(processed_img)

    # Post-process the first three entries of the output into display images.
    rgb1, rgb2, rgb3 = (postprocess_output(output[i]) for i in range(3))
    return rgb1, rgb2, rgb3, output
def apply_pipeline(pipeline, img):
    """Preprocess ``img`` with ``pipeline`` before inference.

    Placeholder: ``pipeline`` is currently ignored and ``img`` is returned
    unchanged. Custom processing (normalization, resizing, etc.) belongs here.
    """
    processed = img
    return processed
def postprocess_output(output):
    """Turn one entry of the model output into a displayable array.

    Placeholder: currently the identity function. Conversion to an RGB image
    or another format belongs here.
    """
    result = output
    return result
@app.post("/predict/", response_model=PredictionOutput)
async def predict(file: UploadFile = File(...)):
    """Run inference on an uploaded GeoTIFF and return the results as lists.

    The upload is written to a private temporary ``.tif`` file, because
    ``inference_on_file`` reads from a path via rasterio. The file is removed
    afterwards, even if inference fails.

    Returns:
        Dict matching ``PredictionOutput``: the three time-step arrays and the
        model prediction, each converted to (possibly nested) Python lists.
    """
    # Read the uploaded file
    data = await file.read()

    # A unique file per request: a shared fixed name would let concurrent requests
    # overwrite each other's input, and it was never cleaned up.
    tmp = tempfile.NamedTemporaryFile(suffix=".tif", delete=False)
    try:
        # Close before reading so rasterio can reopen it (required on Windows).
        with tmp:
            tmp.write(data)
        # NOTE(review): inference is synchronous, so it blocks the event loop
        # while it runs.
        rgb1, rgb2, rgb3, output = inference_on_file(tmp.name, model, custom_test_pipeline)
    finally:
        os.remove(tmp.name)

    # np.asarray is a no-op for ndarrays and also handles list outputs,
    # which have no ``.tolist()`` method of their own.
    return {
        "t1": np.asarray(rgb1).tolist(),
        "t2": np.asarray(rgb2).tolist(),
        "t3": np.asarray(rgb3).tolist(),
        "prediction": np.asarray(output).tolist(),
    }
# Optional: Serve the Gradio interface (if you still want to use it with FastAPI)
def run_gradio_interface():
    """Build the Gradio demo for the crop-classification model and launch it.

    The UI takes an uploaded file and runs ``inference_on_file`` on it. It shows
    the four returned arrays as images (T1, T2, T3 and the model prediction),
    next to cached example chips and a legend image.

    Note: in a non-interactive script, ``demo.launch()`` blocks the calling
    thread until the Gradio server stops.
    """
    # Bind the module-level model and pipeline so Gradio only supplies the file.
    func = partial(inference_on_file, model=model, custom_test_pipeline=custom_test_pipeline)
    with gr.Blocks() as demo:
        gr.Markdown(value='# Prithvi multi temporal crop classification')
        gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. This demo showcases how the model was finetuned to classify crop and other land use categories using multi temporal data. More details can be found [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification).\n
The user needs to provide an HLS geotiff image, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order.''')
        with gr.Row():
            with gr.Column():
                inp = gr.File()
                btn = gr.Button("Submit")
        with gr.Row():
            inp1 = gr.Image(image_mode='RGB', scale=10, label='T1')
            inp2 = gr.Image(image_mode='RGB', scale=10, label='T2')
            inp3 = gr.Image(image_mode='RGB', scale=10, label='T3')
            out = gr.Image(image_mode='RGB', scale=10, label='Model prediction')
        btn.click(fn=func, inputs=inp, outputs=[inp1, inp2, inp3, out])
        with gr.Row():
            with gr.Column():
                # `preprocess` is a boolean flag in gr.Examples. It is left at its
                # default so example files go through the same gr.File preprocessing
                # as user uploads.
                gr.Examples(examples=["chip_102_345_merged.tif",
                                      "chip_104_104_merged.tif",
                                      "chip_109_421_merged.tif"],
                            inputs=inp,
                            outputs=[inp1, inp2, inp3, out],
                            fn=func,
                            cache_examples=True)
            with gr.Column():
                gr.Markdown(value='### Model prediction legend')
                gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
    demo.launch()
if __name__ == "__main__":
    # run_gradio_interface() blocks on demo.launch(). Calling it before
    # uvicorn.run would keep the REST API offline until Gradio exits. Serve the
    # FastAPI app from a background daemon thread first, then let Gradio own the
    # main thread. The daemon thread ends when the process exits.
    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": "0.0.0.0", "port": 8000},
        daemon=True,
    )
    api_thread.start()
    run_gradio_interface()
|