Update app.py
app.py
CHANGED
@@ -1,92 +1,43 @@
-
-import
-from typing import List
-from io import BytesIO
-import numpy as np
-import rasterio
-from pydantic import BaseModel
-import torch
+######### pull files
+import os
 from huggingface_hub import hf_hub_download
+config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
+                            filename="multi_temporal_crop_classification_Prithvi_100M.py",
+                            token=os.environ.get("token"))
+ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
+                     filename='multi_temporal_crop_classification_Prithvi_100M.pth',
+                     token=os.environ.get("token"))
+##########
+import argparse
 from mmcv import Config
-from mmseg.apis import init_segmentor
-import gradio as gr
-from functools import partial
-import time
-import os

-
-app = FastAPI()
+from mmseg.models import build_segmentor

-
-config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
-                              filename="multi_temporal_crop_classification_Prithvi_100M.py",
-                              token=os.environ.get("token"))
-ckpt = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
-                       filename='multi_temporal_crop_classification_Prithvi_100M.pth',
-                       token=os.environ.get("token"))
+from mmseg.datasets.pipelines import Compose, LoadImageFromFile

-
-
-model = init_segmentor(config, ckpt, device='cpu')
+import rasterio
+import torch

-
-custom_test_pipeline = model.cfg.data.test.pipeline
+from mmseg.apis import init_segmentor

-
-class PredictionOutput(BaseModel):
-    t1: List[float]
-    t2: List[float]
-    t3: List[float]
-    prediction: List[float]
+from mmcv.parallel import collate, scatter

-
-
-
-    img = src.read()
+import numpy as np
+import glob
+import os

-
-    processed_img = apply_pipeline(custom_test_pipeline, img)
-
-    # Run inference
-    output = model.inference(processed_img)
-
-    # Post-process the output to get the RGB and prediction images
-    rgb1 = postprocess_output(output[0])
-    rgb2 = postprocess_output(output[1])
-    rgb3 = postprocess_output(output[2])
-
-    return rgb1, rgb2, rgb3, output
+import time

-
-
-
-    return img
+import numpy as np
+import gradio as gr
+from functools import partial

-
-    # Convert the model's output into an RGB image or other formats as needed
-    return output
+import pdb

-
-
-
-    target_image = BytesIO(await file.read())
-
-    # Save the file temporarily if needed
-    with open("temp_image.tif", "wb") as f:
-        f.write(target_image.getvalue())
+import matplotlib.pyplot as plt
+
+from skimage import exposure

-    # Run the prediction
-    rgb1, rgb2, rgb3, output = inference_on_file("temp_image.tif", model, custom_test_pipeline)
-
-    # Return the results
-    return {
-        "t1": rgb1.tolist(),
-        "t2": rgb2.tolist(),
-        "t3": rgb3.tolist(),
-        "prediction": output.tolist()
-    }
-
-# Optional: Serve the Gradio interface (if you still want to use it with FastAPI)
 cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
                  {'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
                  {'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
@@ -316,8 +267,4 @@ with gr.Blocks() as demo:
     gr.Image(value='Legend.png', image_mode='RGB', show_label=False)


-demo.launch()
-
-if __name__ == "__main__":
-    run_gradio_interface()
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+demo.launch()
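The new header in the first hunk only downloads the Prithvi config and checkpoint and gathers imports; the model construction itself sits outside this hunk. Below is a rough sketch of how those two files are typically wired into mmseg, mirroring the removed model = init_segmentor(config, ckpt, device='cpu') and model.cfg.data.test.pipeline lines. Config.fromfile is added here as an assumption, since the old code never actually defined config; this is not the Space's actual code.

# Sketch only: load the crop-classification segmentor from the two files
# pulled with hf_hub_download above, on CPU.
import os

from huggingface_hub import hf_hub_download
from mmcv import Config
from mmseg.apis import init_segmentor

REPO = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification"

config_path = hf_hub_download(repo_id=REPO,
                              filename="multi_temporal_crop_classification_Prithvi_100M.py",
                              token=os.environ.get("token"))
ckpt = hf_hub_download(repo_id=REPO,
                       filename="multi_temporal_crop_classification_Prithvi_100M.pth",
                       token=os.environ.get("token"))

config = Config.fromfile(config_path)                # assumption: parse the config before use
model = init_segmentor(config, ckpt, device="cpu")   # build the segmentor and load the weights
custom_test_pipeline = model.cfg.data.test.pipeline  # preprocessing steps expected at test time

The removed version passed an undefined name (config) straight to init_segmentor; parsing config_path with Config.fromfile, or passing the config path itself, is one way to make that call valid.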
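The second hunk keeps the existing gr.Blocks interface and simply launches it with demo.launch() instead of running uvicorn. The Blocks body is outside the diff, so the following is only a hedged sketch of that pattern with placeholder widgets and a placeholder handler, not the Space's real UI.

import gradio as gr

def predict(file_obj):
    # Placeholder handler: the Space's real function runs the mmseg model
    # on the uploaded multi-temporal GeoTIFF and returns RGB composites.
    return "prediction placeholder"

with gr.Blocks() as demo:
    inp = gr.File(label="Input GeoTIFF")  # hypothetical input widget
    out = gr.Textbox(label="Result")      # hypothetical output widget
    # Legend image as in the Space; assumes Legend.png is present in the repo
    gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
    gr.Button("Run").click(fn=predict, inputs=inp, outputs=out)

demo.launch()  # replaces the old uvicorn.run(app, host="0.0.0.0", port=8000) entry point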
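For context, the lines deleted in the first hunk came from a FastAPI variant of this Space. The route decorators and function definitions are not visible in the diff, so the sketch below is only a guess at how the surviving fragments (the PredictionOutput model, the BytesIO upload, temp_image.tif, inference_on_file) might have fit together; the route name, the endpoint signature, and the stubbed helpers are hypothetical.

from io import BytesIO
from typing import List

from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel

app = FastAPI()

model = None                 # would be the segmentor built as in the loading sketch above
custom_test_pipeline = None  # would be model.cfg.data.test.pipeline

class PredictionOutput(BaseModel):
    t1: List[float]
    t2: List[float]
    t3: List[float]
    prediction: List[float]

def inference_on_file(path, model, custom_test_pipeline):
    # Stub: the removed helper preprocessed the GeoTIFF, ran the model,
    # and returned three RGB composites plus the class prediction.
    raise NotImplementedError

@app.post("/predict")  # hypothetical route; not shown in the diff
async def predict(file: UploadFile = File(...)):
    target_image = BytesIO(await file.read())

    # Save the upload so the raster/mmseg code can open it by path
    with open("temp_image.tif", "wb") as f:
        f.write(target_image.getvalue())

    rgb1, rgb2, rgb3, output = inference_on_file("temp_image.tif", model, custom_test_pipeline)

    return {
        "t1": rgb1.tolist(),
        "t2": rgb2.tolist(),
        "t3": rgb3.tolist(),
        "prediction": output.tolist(),
    }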