Spaces:
Sleeping
Sleeping
File size: 6,648 Bytes
625dedf 8d082c2 625dedf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import streamlit as st
import torch
import numpy as np
from PIL import Image
import rasterio
from rasterio.windows import Window
from tqdm.auto import tqdm
import io
import zipfile
# Assuming you have these functions defined elsewhere
import torch
import numpy as np
from PIL import Image
import albumentations as albu
import segmentation_models_pytorch as smp
from albumentations.pytorch.transforms import ToTensorV2
# Inference device: CUDA when torch reports it as available, CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Encoder name / pretrained-weights tag used to look up the matching
# input preprocessing function.
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'

# Load the serialized model onto DEVICE.
# NOTE(review): torch.load unpickles the checkpoint, so this file must come
# from a trusted source.
best_model = torch.load('deeplabv3+ v15.pth', map_location=DEVICE)
# Switch to inference mode and make sure the weights are float32.
best_model.eval()
best_model.float()
def to_tensor(x, **kwargs):
    """Return ``x`` cast to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    plugged into a transform pipeline that passes additional parameters.
    """
    as_float = x.astype('float32')
    return as_float
# Preprocessing: ask segmentation_models_pytorch for the input transform that
# belongs to the configured ENCODER / ENCODER_WEIGHTS pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
def get_preprocessing():
    """Assemble the albumentations pipeline applied to every model input.

    Steps, in order: resize to 512x512, apply the encoder-specific
    ``preprocessing_fn``, cast image (and mask, if given) to float32 with
    ``to_tensor``, then pass the result through ``ToTensorV2``.
    """
    return albu.Compose([
        albu.Resize(512, 512),
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
        ToTensorV2(),
    ])
# Build the pipeline once at import time; the inference functions in this
# file call it as ``preprocess(image=...)``.
preprocess = get_preprocessing()
@torch.no_grad()
def process_and_predict(image, model):
    """Run the segmentation model on one image and return a binary mask.

    Args:
        image: A ``PIL.Image.Image`` or a numpy array shaped (H, W),
            (H, W, 1), (H, W, 3) or (H, W, 4).
        model: Callable taking a (1, C, H, W) tensor on ``DEVICE`` and
            returning logits.

    Returns:
        A ``PIL.Image.Image`` holding the thresholded mask (0 or 255). The
        mask has the resolution of the model output; it is not resized back
        to the size of the input image.
    """
    if isinstance(image, Image.Image):
        # Modes whose raw array is not plain grey/RGB(A) data: palette images
        # yield colour indices, CMYK/YCbCr yield wrong colour planes, and
        # "LA" yields a 2-channel array. Let Pillow resolve them to RGB.
        if image.mode in ("1", "P", "LA", "CMYK", "YCbCr"):
            image = image.convert("RGB")
        image = np.array(image)

    # Ensure the array is 3-channel.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        # Drop the fourth (alpha) channel.
        image = image[:, :, :3]

    preprocessed = preprocess(image=image)['image']

    # Add the batch dimension and move to the inference device.
    input_tensor = preprocessed.unsqueeze(0).to(DEVICE)

    # Logits -> probabilities -> hard mask at the 0.6 threshold.
    probabilities = torch.sigmoid(model(input_tensor))
    mask = (probabilities > 0.6).float()

    mask_array = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(mask_array)
# Example: single-image inference.
def main(image_path):
    """Open the image at ``image_path`` and return its predicted mask.

    NOTE: a zero-argument ``main`` defined later in this file rebinds the
    name, so this version is only reachable before that definition runs.
    """
    return process_and_predict(Image.open(image_path), best_model)
def _read_rgb_tile(src, window):
    """Read ``window`` from the open raster ``src`` as an (H, W, 3) uint8 array."""
    # boundless=True pads windows that overhang the raster edge with zeros, so
    # every tile really has the window's size and matches the georeferencing
    # metadata that is later attached to it.
    bands = src.read(window=window, boundless=True, fill_value=0)
    if bands.shape[0] == 1:
        bands = np.repeat(bands, 3, axis=0)
    elif bands.shape[0] != 3:
        raise ValueError("The number of channels in the image is not supported")
    # (bands, rows, cols) -> (rows, cols, bands), contiguous for the
    # image-processing pipeline.
    return np.ascontiguousarray(np.transpose(bands, (1, 2, 0)).astype(np.uint8))


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
    """Tile a raster, segment every tile, and return the non-empty masks.

    The raster is covered by ``tile_size`` x ``tile_size`` windows laid out
    on a grid with stride ``tile_size - overlap``. Each window is visited
    exactly once, in row-major order, and windows are sent to the model in
    groups of ``batch_size``.

    Args:
        map_file: Path (or anything ``rasterio.open`` accepts) of the raster.
        model: Callable mapping an (N, 3, H, W) tensor on ``DEVICE`` to
            (N, 1, H', W') logits.
        tile_size: Edge length of each tile in pixels.
        overlap: Number of pixels shared by neighbouring tiles.
        batch_size: Tiles per forward pass; values below 1 are treated as 1.

    Returns:
        A list of ``[mask_array, meta]`` pairs, one per tile whose mask has
        at least one non-zero pixel. ``mask_array`` is a 2-D float array of
        shape (tile_size, tile_size); ``meta`` is a rasterio profile dict
        describing a single-band GeoTIFF positioned at the tile's window.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``tile_size``, or if
            the raster has a band count other than 1 or 3.
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than tile_size")
    batch_size = max(1, batch_size)

    tiles = []
    with rasterio.open(map_file) as src:
        # Enumerate every window once up front; batching over this flat list
        # guarantees no tile row is processed more than once.
        windows = [
            Window(x, y, tile_size, tile_size)
            for y in range(0, src.height, step)
            for x in range(0, src.width, step)
        ]

        for start in tqdm(range(0, len(windows), batch_size)):
            batch_windows = windows[start:start + batch_size]
            batch_tensor = torch.stack([
                preprocess(image=_read_rgb_tile(src, window))['image']
                for window in batch_windows
            ]).to(DEVICE)

            with torch.no_grad():
                batch_masks = torch.sigmoid(model(batch_tensor))
                batch_masks = (batch_masks > 0.6).float()
                # Bring the model-resolution masks back to the tile size.
                batch_masks = torch.nn.functional.interpolate(
                    batch_masks,
                    size=(tile_size, tile_size),
                    mode='bilinear',
                    align_corners=False,
                )
            mask_arrays = batch_masks.cpu().numpy()

            for window, mask_array in zip(batch_windows, mask_arrays):
                mask_array = mask_array.squeeze()
                if not mask_array.any():
                    continue  # nothing detected in this tile
                # Copy so the kept mask does not pin the whole batch in memory.
                mask_array = mask_array.copy()
                out_meta = src.meta.copy()
                out_meta.update({
                    "driver": "GTiff",
                    "height": tile_size,
                    "width": tile_size,
                    "transform": rasterio.windows.transform(window, src.transform),
                    # The payload is a single-band mask, not the source bands.
                    "count": 1,
                    "dtype": mask_array.dtype.name,
                })
                tiles.append([mask_array, out_meta])
    return tiles
def main():
    """Streamlit page: upload a TIF, segment it tile by tile, offer a zip.

    The uploaded raster and the per-tile GeoTIFFs are written inside a
    private temporary directory that is removed when processing finishes,
    so concurrent sessions cannot overwrite each other's files and nothing
    is left behind in the working directory.
    """
    import os
    import tempfile

    st.title("TIF File Processor")
    uploaded_file = st.file_uploader("Choose a TIF file", type="tif")
    if uploaded_file is None:
        return
    st.write("File uploaded successfully!")

    if not st.button("Process File"):
        return
    st.write("Processing...")

    zip_buffer = io.BytesIO()
    with tempfile.TemporaryDirectory() as workdir:
        # rasterio needs a real path, so spill the upload to disk first.
        input_path = os.path.join(workdir, "temp.tif")
        with open(input_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        best_model.float()
        tiles = extract_tiles(input_path, best_model, tile_size=512, overlap=15, batch_size=4)
        st.write("Processing complete!")

        # Pack every mask tile into an in-memory zip archive.
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = os.path.join(workdir, tile_name)
                # Only band 1 is written, so describe the file as a
                # single-band raster of the mask's own dtype.
                tile_meta = {**meta, "count": 1, "dtype": mask_array.dtype.name}
                with rasterio.open(tile_path, 'w', **tile_meta) as dst:
                    dst.write(mask_array, 1)
                zip_file.write(tile_path, arcname=tile_name)
                # Free disk space as soon as the tile is archived.
                os.remove(tile_path)

    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip"
    )
# Run the Streamlit page when the file is executed as a script.
if __name__ == "__main__":
    main()