Sanjayraju30 commited on
Commit
efb1638
·
verified ·
1 Parent(s): 0d71e7b

Rename src/streamlit_app.py to app.py

Browse files
Files changed (2) hide show
  1. app.py +37 -0
  2. src/streamlit_app.py +0 -24
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ # Load your pre-trained model
8
+ model = torch.load('model/your_model_file.pt')
9
+ model.eval()
10
+
11
+ # Define image transformations
12
+ transform = transforms.Compose([
13
+ transforms.Resize((224, 224)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
16
+ std=[0.229, 0.224, 0.225])
17
+ ])
18
+
19
+ st.title("VIEP: Utility Pole Fault Detection")
20
+
21
+ uploaded_file = st.file_uploader("Upload an image of a utility pole", type=["jpg", "jpeg", "png"])
22
+
23
+ if uploaded_file is not None:
24
+ image = Image.open(uploaded_file).convert('RGB')
25
+ st.image(image, caption='Uploaded Image', use_column_width=True)
26
+
27
+ # Preprocess the image
28
+ input_tensor = transform(image).unsqueeze(0)
29
+
30
+ # Perform inference
31
+ with torch.no_grad():
32
+ output = model(input_tensor)
33
+ _, predicted = torch.max(output, 1)
34
+
35
+ # Map the prediction to class names
36
+ classes = ['No Fault', 'Fault Detected']
37
+ st.write(f"Prediction: {classes[predicted.item()]}")
src/streamlit_app.py DELETED
@@ -1,24 +0,0 @@
1
- import os
2
- import streamlit as st
3
- import torch
4
- import numpy as np
5
- from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
6
-
7
- # Set a writable cache directory for PyTorch
8
- torch_cache_dir = os.path.join(os.getcwd(), 'torch_cache')
9
- os.makedirs(torch_cache_dir, exist_ok=True)
10
- os.environ['TORCH_HOME'] = torch_cache_dir
11
-
12
- # Load the YOLOv5 model
13
- model = torch.hub.load('ultralytics/yolov5', 'custom', path='model/best.pt', force_reload=True)
14
-
15
- st.title("Utility Pole Fault Detection")
16
-
17
- class VideoTransformer(VideoTransformerBase):
18
- def transform(self, frame):
19
- img = frame.to_ndarray(format="bgr24")
20
- results = model(img)
21
- annotated_frame = np.squeeze(results.render())
22
- return annotated_frame
23
-
24
- webrtc_streamer(key="live", video_transformer_factory=VideoTransformer)