smishr-18 commited on
Commit
aef1b28
·
verified ·
1 Parent(s): 16fc36d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -66
app.py DELETED
@@ -1,66 +0,0 @@
1
- from unet import UNet
2
- import streamlit as st
3
- import torch
4
- from torchvision import transforms
5
- import albumentations as A
6
- from albumentations.pytorch import ToTensorV2
7
- from PIL import Image
8
- import numpy as np
9
-
10
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
- model = UNet(3, 1, [64, 128, 256, 512]).to(device)
12
- model.load_state_dict(torch.load("best_model.pth", map_location=torch.device(device)))
13
- # Set up transformations for the input image
14
-
15
-
16
- transform = A.Compose([
17
- A.Resize(224, 224, p=1.0),
18
- ToTensorV2(),
19
- ])
20
- # Streamlit app
21
- def main():
22
- st.title("MRI segmenation App")
23
-
24
- # Upload image through Streamlit
25
- uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
26
-
27
- if uploaded_image is not None:
28
- # Display the uploaded and processed images side by side
29
- col1, col2 = st.columns(2) # Using beta_columns for side-by-side layout
30
-
31
- # Display the uploaded image in the first column
32
- col1.header("Original Image")
33
- col1.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
34
-
35
- # Process the image (replace this with your processing logic)
36
- processed_image = generate_image(uploaded_image)
37
-
38
- # Display the processed image in the second column
39
- col2.header("Processed Image")
40
- col2.image(processed_image, caption="Processed Image", use_column_width=True)
41
-
42
- # Function to generate an image using the PyTorch model
43
- def generate_image(uploaded_image):
44
- # Load the uploaded image
45
- input_image = Image.open(uploaded_image)
46
-
47
- image = np.array(input_image).astype(np.float32) / 255.
48
- # Apply transformations
49
- input_tensor = transform(image=image)["image"].unsqueeze(0)
50
-
51
- # Generate an image using the PyTorch model
52
- with torch.no_grad():
53
- input_tensor = input_tensor.type(torch.FloatTensor).to(device)
54
- pred = model(input_tensor)
55
- pred = torch.sigmoid(pred)
56
- mask = (pred > 0.6).float()
57
- mask = mask[0].permute(1, 2, 0)
58
- image = input_tensor[0].permute(1, 2, 0)
59
-
60
- mask = image + mask*0.3
61
- mask = mask.permute(2, 0, 1)
62
- mask = transforms.ToPILImage()(mask)
63
- return mask
64
-
65
- if __name__ == "__main__":
66
- main()