File size: 6,648 Bytes
625dedf
 
 
 
 
 
 
 
 
 
 
8d082c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625dedf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import streamlit as st
import torch
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm.auto import tqdm
import io
import zipfile

# Third-party dependencies for preprocessing and the segmentation model
import torch
import numpy as np
from PIL import Image
import albumentations as albu
import segmentation_models_pytorch as smp
from albumentations.pytorch.transforms import ToTensorV2



DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'

# Streamlit re-executes this script on every widget interaction, so an
# uncached torch.load would re-read the checkpoint from disk each time.
# st.cache_resource keeps a single loaded model per server process; on
# Streamlit releases that do not provide it the decorator degrades to a
# no-op, i.e. the model is simply loaded on every run.
_cache_model = getattr(st, "cache_resource", lambda func: func)


@_cache_model
def _load_model():
    """Load the serialized segmentation model onto DEVICE in float32 eval mode."""
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only ship a checkpoint that comes from a trusted source.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which may reject a fully pickled model like this
    # one -- confirm against the pinned torch version.
    model = torch.load('deeplabv3+ v15.pth', map_location=DEVICE)
    return model.eval().float()


# Load and prepare the model
best_model = _load_model()

def to_tensor(x, **kwargs):
    """Return ``x`` cast to float32.

    Extra keyword arguments are accepted and ignored. The axis order of ``x``
    is left unchanged.
    """
    as_float32 = x.astype('float32')
    return as_float32

# Input preprocessing function that segmentation_models_pytorch provides for
# the configured encoder / pretrained-weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)


def get_preprocessing():
    """Build the albumentations pipeline applied to every image before inference.

    Steps, in order: resize to 512x512, apply the encoder preprocessing
    function, cast image (and mask, if given) to float32, then convert to a
    torch tensor with ToTensorV2.
    """
    resize_step = albu.Resize(512, 512)
    encoder_step = albu.Lambda(image=preprocessing_fn)
    cast_step = albu.Lambda(image=to_tensor, mask=to_tensor)
    tensor_step = ToTensorV2()
    return albu.Compose([resize_step, encoder_step, cast_step, tensor_step])


# Single pipeline instance shared by the inference helpers in this module.
preprocess = get_preprocessing()

@torch.no_grad()
def process_and_predict(image, model, threshold=0.6):
    """Segment a single image and return the binary mask as a PIL image.

    Args:
        image: A PIL image or a numpy array. Arrays of shape (H, W) and
            (H, W, 1) are replicated to three channels; for (H, W, 4) the
            fourth channel is dropped. Other layouts are passed through
            unchanged.
        model: Callable that maps the preprocessed batch tensor to logits.
            Presumably the output is a single-channel map -- the result is
            squeezed to 2-D before conversion; confirm for other models.
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        A PIL image holding 0 (background) and 255 (foreground), at the
        resolution produced by the module-level ``preprocess`` pipeline.
    """
    # Convert PIL Image to numpy array if necessary
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Ensure image is 3-channel
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Preprocess, add the batch dimension and move to the inference device.
    preprocessed = preprocess(image=image)['image']
    input_tensor = preprocessed.unsqueeze(0).to(DEVICE)

    # Logits -> probabilities -> boolean foreground mask.
    probabilities = torch.sigmoid(model(input_tensor))
    foreground = (probabilities > threshold).squeeze().cpu().numpy()

    # Scale {0, 1} to {0, 255} so the mask is visible as an 8-bit image.
    return Image.fromarray(foreground.astype(np.uint8) * 255)

# Example: segment a single image file on disk.
# NOTE: this definition is shadowed by the Streamlit ``main()`` defined later
# in this file, so after import the name ``main`` no longer refers to it.
def main(image_path):
    """Open ``image_path`` and return its predicted mask from ``best_model``."""
    # The context manager releases the file handle once prediction is done.
    with Image.open(image_path) as image:
        return process_and_predict(image, best_model)


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4,
                  threshold=0.6):
    """Tile a raster, segment every tile once, and collect the non-empty masks.

    The raster is covered by a row-major grid of ``tile_size`` x ``tile_size``
    windows whose origins are ``tile_size - overlap`` pixels apart. Tiles are
    sent through ``model`` in batches of at most ``batch_size``.

    Args:
        map_file: Anything accepted by ``rasterio.open``.
        model: Callable mapping a (batch, 3, H, W) float tensor to logits.
            Assumed to return a 4-D (batch, 1, H, W) tensor -- TODO confirm
            for models other than the one loaded in this module.
        tile_size: Edge length of each tile, in source pixels.
        overlap: Pixels shared between neighbouring tiles; must be smaller
            than ``tile_size``.
        batch_size: Maximum number of tiles per forward pass (at least 1).
        threshold: Probability above which a pixel counts as foreground.

    Returns:
        A list of ``[mask_array, meta]`` pairs, one per tile that contains at
        least one foreground pixel. ``mask_array`` is a float32 array of
        shape (tile_size, tile_size) holding 0.0 / 1.0; ``meta`` is a rasterio
        profile describing that single-band tile (GTiff driver, tile
        height/width, window transform).

    Raises:
        ValueError: If ``overlap >= tile_size``, ``batch_size < 1``, or the
            raster has a band count other than 1 or 3.
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than tile_size")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    tiles = []

    with rasterio.open(map_file) as src:
        if src.count not in (1, 3):
            raise ValueError("The number of channels in the image is not supported")

        base_meta = src.meta.copy()

        # Enumerate each tile origin exactly once. Batching is done over this
        # flat list so that no tile is read or inferred more than once.
        origins = [
            (x, y)
            for y in range(0, src.height, stride)
            for x in range(0, src.width, stride)
        ]

        for start in tqdm(range(0, len(origins), batch_size)):
            batch_images = []
            batch_metas = []

            for x, y in origins[start:start + batch_size]:
                window = Window(x, y, tile_size, tile_size)
                # A boundless read zero-pads windows that run past the raster
                # edge, so every tile really is tile_size x tile_size and the
                # height/width/transform stored below describe it correctly.
                out_image = src.read(window=window, boundless=True, fill_value=0)

                if out_image.shape[0] == 1:
                    out_image = np.repeat(out_image, 3, axis=0)

                # (bands, H, W) -> (H, W, bands) uint8, as the pipeline expects.
                tile_image = np.ascontiguousarray(
                    np.transpose(out_image, (1, 2, 0)).astype(np.uint8)
                )

                out_meta = base_meta.copy()
                out_meta.update({
                    "driver": "GTiff",
                    "height": tile_size,
                    "width": tile_size,
                    # The profile travels with the mask, which is a single
                    # float32 band regardless of the source raster layout.
                    "count": 1,
                    "dtype": "float32",
                    "transform": rasterio.windows.transform(window, src.transform),
                })

                batch_images.append(preprocess(image=tile_image)['image'])
                batch_metas.append(out_meta)

            batch_tensor = torch.stack(batch_images).to(DEVICE)
            with torch.no_grad():
                probabilities = torch.sigmoid(model(batch_tensor))

            # Bring the predictions back to tile resolution before
            # thresholding so the resulting masks stay strictly binary.
            if tuple(probabilities.shape[-2:]) != (tile_size, tile_size):
                probabilities = torch.nn.functional.interpolate(
                    probabilities,
                    size=(tile_size, tile_size),
                    mode='bilinear',
                    align_corners=False,
                )

            batch_masks = (probabilities > threshold).float().cpu().numpy()

            for mask, meta in zip(batch_masks, batch_metas):
                mask_array = np.squeeze(mask)
                # Tiles without any foreground pixel are dropped.
                if mask_array.any():
                    tiles.append([mask_array, meta])

    return tiles

def main():
    """Streamlit entry point: upload a GeoTIFF, segment it, offer the tiles as a zip.

    The uploaded raster and every generated tile are written inside a
    temporary directory that is removed when processing finishes, so
    concurrent sessions do not overwrite each other's files and nothing is
    left behind in the working directory.
    """
    import tempfile
    from pathlib import Path

    st.title("TIF File Processor")

    uploaded_file = st.file_uploader("Choose a TIF file", type=["tif", "tiff"])
    if uploaded_file is None:
        return

    st.write("File uploaded successfully!")

    # Process button
    if not st.button("Process File"):
        return

    st.write("Processing...")

    zip_buffer = io.BytesIO()
    with tempfile.TemporaryDirectory() as workdir:
        workdir = Path(workdir)

        # rasterio needs a real file, so persist the upload for this run only.
        input_path = workdir / "input.tif"
        input_path.write_bytes(uploaded_file.getbuffer())

        # Process the file
        best_model.float()
        tiles = extract_tiles(str(input_path), best_model,
                              tile_size=512, overlap=15, batch_size=4)

        st.write("Processing complete!")

        # Prepare zip file for download
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = workdir / tile_name

                # Only one band is written, so make the profile match the
                # mask instead of the source raster's band count and dtype.
                tile_meta = {**meta, "count": 1, "dtype": mask_array.dtype.name}
                with rasterio.open(str(tile_path), 'w', **tile_meta) as dst:
                    dst.write(mask_array, 1)

                # arcname keeps the archive flat (no temp-directory prefix).
                zip_file.write(tile_path, arcname=tile_name)

    # Offer the zip file for download
    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip"
    )


if __name__ == "__main__":
    main()