sanket09 commited on
Commit
06d70d7
·
verified ·
1 Parent(s): 015cdcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -81
app.py CHANGED
@@ -1,92 +1,43 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- import uvicorn
3
- from typing import List
4
- from io import BytesIO
5
- import numpy as np
6
- import rasterio
7
- from pydantic import BaseModel
8
- import torch
9
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
10
  from mmcv import Config
11
- from mmseg.apis import init_segmentor
12
- import gradio as gr
13
- from functools import partial
14
- import time
15
- import os
16
 
17
- # Initialize the FastAPI app
18
- app = FastAPI()
19
 
20
- # Load the model and config
21
- config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
22
- filename="multi_temporal_crop_classification_Prithvi_100M.py",
23
- token=os.environ.get("token"))
24
- ckpt = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
25
- filename='multi_temporal_crop_classification_Prithvi_100M.pth',
26
- token=os.environ.get("token"))
27
 
28
- config = Config.fromfile(config_path)
29
- config.model.backbone.pretrained = None
30
- model = init_segmentor(config, ckpt, device='cpu')
31
 
32
- # Use the test pipeline directly
33
- custom_test_pipeline = model.cfg.data.test.pipeline
34
 
35
- # Define the input/output model for FastAPI
36
- class PredictionOutput(BaseModel):
37
- t1: List[float]
38
- t2: List[float]
39
- t3: List[float]
40
- prediction: List[float]
41
 
42
- # Define the inference function
43
- def inference_on_file(file_path, model, custom_test_pipeline):
44
- with rasterio.open(file_path) as src:
45
- img = src.read()
46
 
47
- # Apply preprocessing using the custom pipeline
48
- processed_img = apply_pipeline(custom_test_pipeline, img)
49
-
50
- # Run inference
51
- output = model.inference(processed_img)
52
-
53
- # Post-process the output to get the RGB and prediction images
54
- rgb1 = postprocess_output(output[0])
55
- rgb2 = postprocess_output(output[1])
56
- rgb3 = postprocess_output(output[2])
57
-
58
- return rgb1, rgb2, rgb3, output
59
 
60
- def apply_pipeline(pipeline, img):
61
- # Implement your custom pipeline processing here
62
- # This could include normalization, resizing, etc.
63
- return img
64
 
65
- def postprocess_output(output):
66
- # Convert the model's output into an RGB image or other formats as needed
67
- return output
68
 
69
- @app.post("/predict/", response_model=PredictionOutput)
70
- async def predict(file: UploadFile = File(...)):
71
- # Read the uploaded file
72
- target_image = BytesIO(await file.read())
73
-
74
- # Save the file temporarily if needed
75
- with open("temp_image.tif", "wb") as f:
76
- f.write(target_image.getvalue())
77
 
78
- # Run the prediction
79
- rgb1, rgb2, rgb3, output = inference_on_file("temp_image.tif", model, custom_test_pipeline)
80
-
81
- # Return the results
82
- return {
83
- "t1": rgb1.tolist(),
84
- "t2": rgb2.tolist(),
85
- "t3": rgb3.tolist(),
86
- "prediction": output.tolist()
87
- }
88
-
89
- # Optional: Serve the Gradio interface (if you still want to use it with FastAPI)
90
  cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
91
  {'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
92
  {'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
@@ -316,8 +267,4 @@ with gr.Blocks() as demo:
316
  gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
317
 
318
 
319
- demo.launch()
320
-
321
- if __name__ == "__main__":
322
- run_gradio_interface()
323
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ ######### pull files
2
+ import os
 
 
 
 
 
 
3
  from huggingface_hub import hf_hub_download
4
+ config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
5
+ filename="multi_temporal_crop_classification_Prithvi_100M.py",
6
+ token=os.environ.get("token"))
7
+ ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
8
+ filename='multi_temporal_crop_classification_Prithvi_100M.pth',
9
+ token=os.environ.get("token"))
10
+ ##########
11
+ import argparse
12
  from mmcv import Config
 
 
 
 
 
13
 
14
+ from mmseg.models import build_segmentor
 
15
 
16
+ from mmseg.datasets.pipelines import Compose, LoadImageFromFile
 
 
 
 
 
 
17
 
18
+ import rasterio
19
+ import torch
 
20
 
21
+ from mmseg.apis import init_segmentor
 
22
 
23
+ from mmcv.parallel import collate, scatter
 
 
 
 
 
24
 
25
+ import numpy as np
26
+ import glob
27
+ import os
 
28
 
29
+ import time
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ import numpy as np
32
+ import gradio as gr
33
+ from functools import partial
 
34
 
35
+ import pdb
 
 
36
 
37
+ import matplotlib.pyplot as plt
38
+
39
+ from skimage import exposure
 
 
 
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  cdl_color_map = [{'value': 1, 'label': 'Natural vegetation', 'rgb': (233,255,190)},
42
  {'value': 2, 'label': 'Forest', 'rgb': (149,206,147)},
43
  {'value': 3, 'label': 'Corn', 'rgb': (255,212,0)},
 
267
  gr.Image(value='Legend.png', image_mode='RGB', show_label=False)
268
 
269
 
270
+ demo.launch()