dschandra commited on
Commit
54dcee8
·
verified ·
1 Parent(s): e743240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -8
app.py CHANGED
@@ -15,17 +15,23 @@ st.set_page_config(page_title="Solar Panel Fault Detection", layout="wide")
15
 
16
  # Title and description
17
  st.title("Solar Panel Fault Detection PoC")
18
- st.write("Upload a thermal video (MP4) of a solar panel to detect thermal, dust, and power generation faults.")
 
 
 
 
 
 
 
19
 
20
  # Load model and processor
21
  @st.cache_resource
22
  def load_model():
23
- # Suppress warning about unused weights
24
  warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
25
  logging.set_verbosity_error()
26
 
27
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
28
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  model.to(device)
31
  model.eval()
@@ -33,9 +39,19 @@ def load_model():
33
 
34
  processor, model, device = load_model()
35
 
 
 
 
 
 
 
 
 
36
  # Function to process a batch of frames
37
  async def detect_faults_batch(frames, processor, model, device):
38
  try:
 
 
39
  inputs = processor(images=frames, return_tensors="pt").to(device)
40
  with torch.no_grad():
41
  outputs = model(**inputs)
@@ -71,13 +87,17 @@ async def detect_faults_batch(frames, processor, model, device):
71
  annotated_frames.append(annotated_frame)
72
  all_faults.append(faults)
73
 
 
 
 
 
74
  return annotated_frames, all_faults
75
  except Exception as e:
76
  st.error(f"Error during fault detection: {str(e)}")
77
  return [], []
78
 
79
  # Function to process video
80
- async def process_video(video_path, frame_skip=5, batch_size=4):
81
  try:
82
  cap = cv2.VideoCapture(video_path)
83
  if not cap.isOpened():
@@ -89,9 +109,13 @@ async def process_video(video_path, frame_skip=5, batch_size=4):
89
  fps = int(cap.get(cv2.CAP_PROP_FPS))
90
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
91
 
 
 
 
 
92
  output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
93
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
94
- out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
95
 
96
  video_faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
97
  frame_count = 0
@@ -104,9 +128,12 @@ async def process_video(video_path, frame_skip=5, batch_size=4):
104
 
105
  while cap.isOpened():
106
  ret, frame = cap.read()
107
- if not, break
 
108
 
109
  if frame_count % frame_skip != 0:
 
 
110
  out.write(frame)
111
  frame_count += 1
112
  processed_frames += 1
@@ -166,7 +193,7 @@ if uploaded_file is not None:
166
  st.video(tfile.name, format="video/mp4")
167
 
168
  loop = asyncio.get_event_loop()
169
- output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip=5, batch_size=4))
170
 
171
  if output_path and video_faults:
172
  st.subheader("Fault Detection Results")
 
15
 
16
  # Title and description
17
  st.title("Solar Panel Fault Detection PoC")
18
+ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power generation faults.")
19
+
20
+ # UI controls for optimization parameters
21
+ st.sidebar.header("Analysis Settings")
22
+ frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=20, value=10)
23
+ batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=8)
24
+ resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
25
+ resize_width = 640 if resize_enabled else None
26
 
27
  # Load model and processor
28
  @st.cache_resource
29
  def load_model():
 
30
  warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
31
  logging.set_verbosity_error()
32
 
33
+ processor = DetrImageProcessor.from_pretrained("hustvl/yolos-tiny")
34
+ model = DetrForObjectDetection.from_pretrained("hustvl/yolos-tiny")
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  model.to(device)
37
  model.eval()
 
39
 
40
  processor, model, device = load_model()
41
 
42
+ # Function to resize frame
43
+ def resize_frame(frame, width=None):
44
+ if width is None:
45
+ return frame
46
+ aspect_ratio = frame.shape[1] / frame.shape[0]
47
+ height = int(width / aspect_ratio)
48
+ return cv2.resize(frame, (width, height), interpolation=cv2.INTER_AREA)
49
+
50
  # Function to process a batch of frames
51
  async def detect_faults_batch(frames, processor, model, device):
52
  try:
53
+ # Resize frames if enabled
54
+ frames = [resize_frame(frame, resize_width) for frame in frames]
55
  inputs = processor(images=frames, return_tensors="pt").to(device)
56
  with torch.no_grad():
57
  outputs = model(**inputs)
 
87
  annotated_frames.append(annotated_frame)
88
  all_faults.append(faults)
89
 
90
+ # Clear GPU memory
91
+ if torch.cuda.is_available():
92
+ torch.cuda.empty_cache()
93
+
94
  return annotated_frames, all_faults
95
  except Exception as e:
96
  st.error(f"Error during fault detection: {str(e)}")
97
  return [], []
98
 
99
  # Function to process video
100
+ async def process_video(video_path, frame_skip, batch_size):
101
  try:
102
  cap = cv2.VideoCapture(video_path)
103
  if not cap.isOpened():
 
109
  fps = int(cap.get(cv2.CAP_PROP_FPS))
110
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
111
 
112
+ # Adjust output size if resizing
113
+ out_width = resize_width if resize_width else frame_width
114
+ out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
115
+
116
  output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
117
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
118
+ out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
119
 
120
  video_faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
121
  frame_count = 0
 
128
 
129
  while cap.isOpened():
130
  ret, frame = cap.read()
131
+ if not ret:
132
+ break
133
 
134
  if frame_count % frame_skip != 0:
135
+ # Resize frame for output if needed
136
+ frame = resize_frame(frame, resize_width)
137
  out.write(frame)
138
  frame_count += 1
139
  processed_frames += 1
 
193
  st.video(tfile.name, format="video/mp4")
194
 
195
  loop = asyncio.get_event_loop()
196
+ output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip, batch_size))
197
 
198
  if output_path and video_faults:
199
  st.subheader("Fault Detection Results")