import streamlit as st
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
import cv2
import numpy as np
import tempfile
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import warnings
from transformers.utils import logging

# Set page configuration
st.set_page_config(page_title="Solar Panel Fault Detection", layout="wide")

# Title and description
st.title("Solar Panel Fault Detection PoC")
st.write("Upload a thermal video (MP4) to detect thermal, dust, and power generation faults.")

# UI controls for optimization parameters
st.sidebar.header("Analysis Settings")
frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=50, value=30)
batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=32, value=16 if torch.cuda.is_available() else 8)
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
resize_width = 512 if resize_enabled else None
quantize_model = st.sidebar.checkbox("Quantize Model (faster, esp. on CPU)", value=True)
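
# With the defaults above, the detector sees roughly one frame per second of a
# typical 30 FPS thermal video (frame skip of 30), in batches of 16 on GPU or 8 on CPU.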

# Load model and processor
def load_model(quantize=quantize_model):
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
    logging.set_verbosity_error()
    try:
        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # Apply dynamic quantization if enabled
        if quantize and device.type == "cpu":
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        model.eval()
        return processor, model, device
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}. Check internet connection or cache (~/.cache/huggingface/hub).")
        raise

processor, model, device = load_model()

# Function to resize frame
def resize_frame(frame, width=None):
    if width is None:
        return frame
    aspect_ratio = frame.shape[1] / frame.shape[0]
    height = int(width / aspect_ratio)
    return cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
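
# Fault classification heuristic: the COCO-pretrained DETR model is used only to
# propose regions of interest; the detected class labels are ignored. Each region is
# classified by its mean pixel intensity (bright > 200 -> thermal fault, dark < 100 ->
# dust fault), and any thermal or dust fault also flags a power generation fault.
# These thresholds are PoC heuristics, not calibrated radiometric values.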

# Function to process a batch of frames
async def detect_faults_batch(frames, processor, model, device):
    try:
        frames = [resize_frame(frame, resize_width) for frame in frames]
        inputs = processor(images=frames, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        target_sizes = torch.tensor([frame.shape[:2] for frame in frames]).to(device)
        results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)
        annotated_frames = []
        all_faults = []
        for frame, result in zip(frames, results):
            faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
            annotated_frame = frame.copy()
            for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
                box = [int(i) for i in box.tolist()]
                roi = frame[box[1]:box[3], box[0]:box[2]]
                if roi.size == 0:
                    # Skip degenerate boxes so np.mean is never taken over an empty slice
                    continue
                mean_intensity = np.mean(roi)
                if mean_intensity > 200:
                    faults["Thermal Fault"] = True
                    cv2.rectangle(annotated_frame, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
                    cv2.putText(annotated_frame, "Thermal Fault", (box[0], box[1]-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
                elif mean_intensity < 100:
                    faults["Dust Fault"] = True
                    cv2.rectangle(annotated_frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
                    cv2.putText(annotated_frame, "Dust Fault", (box[0], box[1]-10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            if faults["Thermal Fault"] or faults["Dust Fault"]:
                faults["Power Generation Fault"] = True
            annotated_frames.append(annotated_frame)
            all_faults.append(faults)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return annotated_frames, all_faults
    except Exception as e:
        st.error(f"Error during fault detection: {str(e)}")
        return [], []
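
# Note on the per-video verdict: detect_faults_batch returns per-frame flags and
# process_video ORs them together, so a single flagged frame marks the whole video.
# Only every frame_skip-th frame is analyzed; the remaining frames are resized and
# written to the annotated output unchanged.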

# Function to process video
async def process_video(video_path, frame_skip, batch_size):
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            st.error("Error: Could not open video file.")
            return None, None
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 FPS if metadata is missing
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 1)  # avoid division by zero in the progress bar
        out_width = resize_width if resize_width else frame_width
        out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
        output_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
        video_faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
        frame_count = 0
        frames_batch = []
        processed_frames = 0
        with st.spinner("Analyzing video..."):
            progress = st.progress(0)
            executor = ThreadPoolExecutor(max_workers=2)
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_skip != 0:
                    frame = resize_frame(frame, resize_width)
                    out.write(frame)
                    frame_count += 1
                    processed_frames += 1
                    progress.progress(min(processed_frames / total_frames, 1.0))
                    continue
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames_batch.append(frame_rgb)
                if len(frames_batch) >= batch_size:
                    annotated_frames, batch_faults = await detect_faults_batch(frames_batch, processor, model, device)
                    for annotated_frame, faults in zip(annotated_frames, batch_faults):
                        for fault in video_faults:
                            video_faults[fault] |= faults[fault]
                        annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
                        out.write(annotated_frame_bgr)
                    frames_batch = []
                    processed_frames += batch_size
                    progress.progress(min(processed_frames / total_frames, 1.0))
                frame_count += 1
            if frames_batch:
                annotated_frames, batch_faults = await detect_faults_batch(frames_batch, processor, model, device)
                for annotated_frame, faults in zip(annotated_frames, batch_faults):
                    for fault in video_faults:
                        video_faults[fault] |= faults[fault]
                    annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
                    out.write(annotated_frame_bgr)
                processed_frames += len(frames_batch)
                progress.progress(min(processed_frames / total_frames, 1.0))
        cap.release()
        out.release()
        return output_path, video_faults
    except Exception as e:
        st.error(f"Error processing video: {str(e)}")
        return None, None
    finally:
        if 'cap' in locals() and cap.isOpened():
            cap.release()
        if 'out' in locals():
            out.release()

# File uploader
uploaded_file = st.file_uploader("Upload a thermal video", type=["mp4"])

if uploaded_file is not None:
    try:
        tfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        tfile.write(uploaded_file.read())
        tfile.close()
        st.video(tfile.name, format="video/mp4")
        # Create a new event loop for Streamlit's ScriptRunner thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            output_path, video_faults = loop.run_until_complete(process_video(tfile.name, frame_skip, batch_size))
        finally:
            loop.close()
        if output_path and video_faults:
            st.subheader("Fault Detection Results")
            st.video(output_path, format="video/mp4")
            st.write("**Detected Faults in Video:**")
            for fault, detected in video_faults.items():
                status = "Detected" if detected else "Not Detected"
                color = "red" if detected else "green"
                st.markdown(f"- **{fault}**: <span style='color:{color}'>{status}</span>", unsafe_allow_html=True)
            if any(video_faults.values()):
                st.subheader("Recommendations")
                if video_faults["Thermal Fault"]:
                    st.write("- **Thermal Fault**: Inspect for damaged components or overheating issues.")
                if video_faults["Dust Fault"]:
                    st.write("- **Dust Fault**: Schedule cleaning to remove dust accumulation.")
                if video_faults["Power Generation Fault"]:
                    st.write("- **Power Generation Fault**: Investigate efficiency issues due to detected faults.")
            else:
                st.write("No faults detected. The solar panel appears to be functioning normally.")
            if os.path.exists(output_path):
                os.unlink(output_path)
        if os.path.exists(tfile.name):
            os.unlink(tfile.name)
    except Exception as e:
        st.error(f"Error handling uploaded file: {str(e)}")
    finally:
        if 'tfile' in locals() and os.path.exists(tfile.name):
            os.unlink(tfile.name)

# Footer
st.markdown("---")
st.write("Built with Streamlit, Hugging Face Transformers, and OpenCV for the Solar Panel Fault Detection PoC.")