"""Streamlit app that segments a GeoTIFF tile by tile.

The segmentation model is loaded from ``deeplabv3+ v15.pth``. Tiles containing
any foreground are written as single-band GeoTIFF masks and offered as a zip.
"""

import io
import tempfile
import zipfile
from pathlib import Path

import albumentations as albu
import numpy as np
import rasterio
import segmentation_models_pytorch as smp
import streamlit as st
import torch
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from rasterio.windows import Window
from tqdm.auto import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ENCODER = "se_resnext50_32x4d"
ENCODER_WEIGHTS = "imagenet"
# Side length every image/tile is resized to before it is fed to the model.
MODEL_INPUT_SIZE = 512
# Sigmoid probability above which a pixel is classified as foreground.
MASK_THRESHOLD = 0.6

# Load and prepare the model.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# checkpoints. Newer torch releases may also need weights_only=False to load a
# whole pickled module; confirm against the installed version.
best_model = torch.load("deeplabv3+ v15.pth", map_location=DEVICE)
best_model.eval().float()


def to_tensor(x, **kwargs):
    """Cast *x* to float32.

    Used as an albumentations ``Lambda`` target. The HWC -> CHW transpose is
    done afterwards by ``ToTensorV2``.
    """
    return x.astype("float32")


# Input normalisation matching the encoder's pretrained weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)


def get_preprocessing():
    """Build the albumentations pipeline: resize, normalise, cast, to tensor."""
    _transform = [
        albu.Resize(MODEL_INPUT_SIZE, MODEL_INPUT_SIZE),
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
        ToTensorV2(),
    ]
    return albu.Compose(_transform)


preprocess = get_preprocessing()


@torch.no_grad()
def process_and_predict(image, model):
    """Predict a binary mask for a single image.

    Args:
        image: PIL image or numpy array. A 2-D array is replicated to three
            channels; a 4-channel array has its last channel dropped.
        model: segmentation model, assumed to return a single-channel logit
            map per image (TODO confirm).

    Returns:
        ``PIL.Image`` of ``MODEL_INPUT_SIZE`` x ``MODEL_INPUT_SIZE`` pixels with
        255 where the sigmoid probability exceeds ``MASK_THRESHOLD``, else 0.
    """
    if isinstance(image, Image.Image):
        image = np.array(image)

    # The model expects exactly three channels.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    preprocessed = preprocess(image=image)["image"]
    # Add the batch dimension expected by the model.
    input_tensor = preprocessed.unsqueeze(0).to(DEVICE)

    mask = torch.sigmoid(model(input_tensor))
    mask = (mask > MASK_THRESHOLD).float()

    return Image.fromarray((mask.squeeze().cpu().numpy() * 255).astype(np.uint8))


def predict_image_file(image_path):
    """Example helper: open *image_path* and return its predicted mask.

    This was previously also named ``main`` and was shadowed by the Streamlit
    entry point below, so it could never be called.
    """
    with Image.open(image_path) as image:
        return process_and_predict(image, best_model)


def _tile_origins(height, width, step):
    """Yield the ``(row, col)`` pixel offset of every tile, row-major."""
    for row in range(0, height, step):
        for col in range(0, width, step):
            yield row, col


def _read_tile(src, window):
    """Read *window* from *src* as an ``(h, w, 3)`` uint8 array.

    Single-band rasters are replicated to three channels. Any band count other
    than 1 or 3 raises ``ValueError``.
    """
    tile = src.read(window=window)
    if tile.shape[0] == 1:
        tile = np.repeat(tile, 3, axis=0)
    elif tile.shape[0] != 3:
        raise ValueError("The number of channels in the image is not supported")
    # NOTE(review): astype(uint8) wraps values above 255 (e.g. 16-bit rasters);
    # this assumes 8-bit input - confirm.
    return np.transpose(tile, (1, 2, 0)).astype(np.uint8)


def extract_tiles(map_file, model, tile_size=512, overlap=0, batch_size=4):
    """Run *model* over a raster in tiles and keep tiles with any foreground.

    The raster is scanned row-major in windows of ``tile_size`` pixels whose
    origins are ``tile_size - overlap`` pixels apart. Every window is visited
    exactly once, and windows are sent to the model ``batch_size`` at a time.

    Args:
        map_file: raster path, or anything ``rasterio.open`` accepts, with 1 or
            3 bands.
        model: segmentation model, assumed to return ``(N, 1, H, W)`` logits
            (TODO confirm).
        tile_size: window side length in source pixels.
        overlap: pixels shared by neighbouring windows, with
            ``0 <= overlap < tile_size``.
        batch_size: number of windows per forward pass; must be at least 1.

    Returns:
        List of ``[mask, meta]`` pairs, one per window containing foreground.
        ``mask`` is an ``(h, w)`` uint8 array of 0/1 whose shape equals the
        pixels actually read for that window (edge windows may be smaller than
        ``tile_size``). ``meta`` is a rasterio profile for a single-band uint8
        GeoTIFF georeferenced to that window.

    Raises:
        ValueError: on invalid ``tile_size``, ``overlap`` or ``batch_size``, or
            when the raster has neither 1 nor 3 bands.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if not 0 <= overlap < tile_size:
        raise ValueError(f"overlap must be in [0, tile_size), got {overlap}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    step = tile_size - overlap
    tiles = []
    with rasterio.open(map_file) as src:
        # Output masks are single-band 0/1 uint8. A source nodata value may not
        # be representable in uint8, so it is dropped.
        base_meta = src.meta.copy()
        base_meta.update(driver="GTiff", count=1, dtype="uint8", nodata=None)

        origins = list(_tile_origins(src.height, src.width, step))
        for start in tqdm(range(0, len(origins), batch_size)):
            batch_images, batch_shapes, batch_metas = [], [], []
            for row, col in origins[start:start + batch_size]:
                tile_image = _read_tile(src, Window(col, row, tile_size, tile_size))
                height, width = tile_image.shape[:2]
                # Georeference against the pixels actually read, so edge tiles
                # are not stretched over area outside the raster.
                actual_window = Window(col, row, width, height)
                meta = base_meta.copy()
                meta.update(
                    height=height,
                    width=width,
                    transform=rasterio.windows.transform(actual_window, src.transform),
                )
                batch_images.append(preprocess(image=tile_image)["image"])
                batch_shapes.append((height, width))
                batch_metas.append(meta)

            batch_tensor = torch.stack(batch_images).to(DEVICE)
            with torch.no_grad():
                probabilities = torch.sigmoid(model(batch_tensor))

            for prob, (height, width), meta in zip(probabilities, batch_shapes, batch_metas):
                # Resize probabilities (not the thresholded mask) back to the
                # tile's real size, then threshold, so the mask stays binary.
                resized = torch.nn.functional.interpolate(
                    prob.unsqueeze(0),
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                )[0, 0]
                mask_array = (resized > MASK_THRESHOLD).cpu().numpy().astype(np.uint8)
                if mask_array.any():
                    tiles.append([mask_array, meta])
    return tiles


def main():
    """Streamlit entry point: upload a TIF, segment it, download masks as a zip."""
    st.title("TIF File Processor")
    uploaded_file = st.file_uploader("Choose a TIF file", type=["tif", "tiff"])
    if uploaded_file is None:
        return

    st.write("File uploaded successfully!")
    if not st.button("Process File"):
        return

    st.write("Processing...")
    zip_buffer = io.BytesIO()
    # Per-run scratch directory: nothing is left in the CWD, and concurrent
    # sessions do not overwrite each other's files.
    with tempfile.TemporaryDirectory() as workdir:
        workdir = Path(workdir)
        input_path = workdir / "temp.tif"
        input_path.write_bytes(uploaded_file.getbuffer())

        tiles = extract_tiles(input_path, best_model, tile_size=512, overlap=15, batch_size=4)
        st.write("Processing complete!")

        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
            for i, (mask_array, meta) in enumerate(tiles):
                tile_name = f"tile_{i}.tif"
                tile_path = workdir / tile_name
                with rasterio.open(tile_path, "w", **meta) as dst:
                    dst.write(mask_array, 1)
                zip_file.write(tile_path, arcname=tile_name)

    if not tiles:
        st.warning("No foreground was detected; the archive is empty.")

    st.download_button(
        label="Download processed tiles",
        data=zip_buffer.getvalue(),
        file_name="processed_tiles.zip",
        mime="application/zip",
    )


if __name__ == "__main__":
    main()