File size: 3,005 Bytes
0c1bad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb6f4b
0c1bad2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fdb6f4b
6d70836
 
4e73bee
0c1bad2
 
 
 
 
 
fdb6f4b
0c1bad2
6d70836
0c1bad2
3f65192
6d70836
3f65192
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

# REQUIREMENTS
"""
!python -m pip -q install torchvision torch
!python -m pip -q install rasterio
!python -m pip -q install git+https://github.com/PatBall1/detectree2.git # in order for this to work, you must have installed gdal
!python -m pip install opencv-python
!python -m pip install requests
"""
from detectree2.preprocessing.tiling import tile_data
from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
from detectree2.models.predict import predict_on_data
from detectree2.models.train import setup_cfg
from detectron2.engine import DefaultPredictor
import rasterio
import os
import requests

# NOTE: tile_path (the directory where the tiles are stored) only works when it
# is given as an absolute path; do not use a relative path.

# Make sure tile_path ends with '/'; otherwise predict_on_data() will fail later.

def create_tiles(input_path, tile_width, tile_height, tile_buffer):
    """Tile a raster into ``<cwd>/tiles/`` using detectree2's ``tile_data``.

    Args:
        input_path: Path to the input raster, opened with ``rasterio.open``.
        tile_width: Tile width, passed straight to ``tile_data`` (units are
            whatever ``tile_data`` expects -- presumably map units; confirm).
        tile_height: Tile height, passed straight to ``tile_data``.
        tile_buffer: Buffer around each tile, passed straight to ``tile_data``.

    Returns:
        Absolute path of the tiles directory, ending with '/'. Downstream
        prediction code needs both the absolute path and the trailing slash.
    """
    # Absolute path with a trailing '/' because predict_on_data() relies on it.
    tiles_directory = os.path.join(os.getcwd(), "tiles/")
    os.makedirs(tiles_directory, exist_ok=True)

    # The context manager closes the raster handle even if tiling raises.
    with rasterio.open(input_path) as data:
        tile_data(data, tiles_directory, tile_buffer, tile_width, tile_height, dtype_bool=True)

    return tiles_directory

def download_file(url, local_filename, timeout=60):
    """Stream the resource at ``url`` into ``local_filename``.

    Args:
        url: URL to fetch.
        local_filename: Destination path; overwritten if it already exists.
        timeout: Seconds ``requests`` waits for the connection and between
            received bytes. Without a timeout, a stalled server would block
            this call indefinitely.

    Returns:
        ``local_filename``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server stalls for longer than ``timeout``.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(local_filename, "wb") as f:
            # Write in chunks so large files (e.g. model weights) are not held in memory.
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    return local_filename

def predict(tile_path, overlap_threshold, confidence_threshold, simplify_value, store_path, *,
            model_url="https://zenodo.org/records/10522461/files/230103_randresize_full.pth",
            model_path="./230103_randresize_full.pth",
            device=None):
    """Run detectree2 on a tile directory and write cleaned crowns as GeoJSON.

    Pipeline steps:
    - download the model weights;
    - run ``predict_on_data`` on the tiles;
    - convert the raw predictions with ``project_to_geojson``;
    - stitch them with ``stitch_crowns``;
    - filter them with ``clean_crowns``;
    - simplify the geometries and write ``detectree2_delin.geojson`` into
      ``store_path``.

    Args:
        tile_path: Tile directory (e.g. from ``create_tiles``). It should be
            an absolute path. A trailing separator is appended if missing.
        overlap_threshold: Second positional argument of ``clean_crowns``.
        confidence_threshold: ``confidence`` argument of ``clean_crowns``.
        simplify_value: Tolerance passed to ``GeoSeries.simplify``.
        store_path: Output directory; created if it does not exist.
        model_url: Where to download the model weights from.
        model_path: Local file the weights are written to (re-downloaded on
            every call).
        device: Optional detectron2 device string (e.g. ``"cpu"`` on
            machines without a GPU). ``None`` keeps the config's default.
    """
    # predict_on_data() needs a trailing separator, so guarantee one.
    tile_path = os.path.join(tile_path, "")
    predictions_dir = os.path.join(tile_path, "predictions", "")
    predictions_geo_dir = os.path.join(tile_path, "predictions_geo", "")

    download_file(url=model_url, local_filename=model_path)

    # Ensure the output location exists before anything writes into it.
    os.makedirs(store_path, exist_ok=True)
    cfg = setup_cfg(update_model=model_path, out_dir=store_path)
    if device is not None:
        cfg.MODEL.DEVICE = device
    predict_on_data(tile_path, predictor=DefaultPredictor(cfg))

    project_to_geojson(tile_path, predictions_dir, predictions_geo_dir)
    # NOTE(review): the literal 1 is presumably stitch_crowns' `shift` argument -- confirm.
    crowns = stitch_crowns(predictions_geo_dir, 1)
    clean = clean_crowns(crowns, overlap_threshold, confidence=confidence_threshold)
    clean = clean.set_geometry(clean.simplify(simplify_value))
    clean.to_file(os.path.join(store_path, "detectree2_delin.geojson"))

def run_detectree2(
    tif_input_path,
    store_path,
    tile_width=20,
    tile_height=20,
    tile_buffer=20,
    overlap_threshold=0.35,
    confidence_threshold=0.2,
    simplify_value=0.2,
):
    """Run the full detectree2 crown-delineation pipeline on one raster.

    The input raster is tiled into ``<cwd>/tiles/`` (the tile directory is
    printed). Prediction then runs on those tiles, and the cleaned crowns end
    up in ``detectree2_delin.geojson`` inside ``store_path``.

    Args:
        tif_input_path: Raster to process.
        store_path: Directory receiving the output GeoJSON.
        tile_width, tile_height, tile_buffer: Forwarded to ``create_tiles``.
        overlap_threshold, confidence_threshold, simplify_value: Forwarded
            to ``predict``.
    """
    tiles_dir = create_tiles(
        input_path=tif_input_path,
        tile_width=tile_width,
        tile_height=tile_height,
        tile_buffer=tile_buffer,
    )
    print(tiles_dir)
    predict(
        tile_path=tiles_dir,
        overlap_threshold=overlap_threshold,
        confidence_threshold=confidence_threshold,
        simplify_value=simplify_value,
        store_path=store_path,
    )