ddriscoll committed
Commit cfd6339 · verified · 1 Parent(s): f8c180b

Delete EurybiaMini3.0Gradio.app

Files changed (1)
  1. EurybiaMini3.0Gradio.app +0 -536
EurybiaMini3.0Gradio.app DELETED
@@ -1,536 +0,0 @@
- # Enable the PyTorch MPS CPU fallback; this must be set before torch is imported
- import os
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-
- import glob
- import time
- import threading
- import base64
- import requests
- import wikipedia
- import torch
- import cv2
- import numpy as np
- from io import BytesIO
- from PIL import Image
-
- import gradio as gr
- from ultralytics import YOLO
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
- from diffusers import MarigoldDepthPipeline
- from realesrgan import RealESRGANer
- from basicsr.archs.rrdbnet_arch import RRDBNet
-
- # Initialize Models
- def initialize_models():
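-     """Load every model the app uses and return them in a dict keyed by name.
-
-     Models that fail to initialize are stored as None so the UI can degrade
-     gracefully instead of crashing at startup.
-     """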
-     models = {}
-
-     # Device detection
-     if torch.cuda.is_available():
-         device = 'cuda'
-     elif torch.backends.mps.is_available():
-         device = 'mps'
-     else:
-         device = 'cpu'
-     models['device'] = device
-
-     print(f"Using device: {device}")
-
-     # Initialize the RoBERTa model for question answering
-     try:
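-         # transformers pipelines take an integer device index (0 = first CUDA GPU,
-         # -1 = CPU), so on MPS this QA pipeline intentionally falls back to CPU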
-         models['qa_pipeline'] = pipeline(
-             "question-answering", model="deepset/roberta-base-squad2", device=0 if device == 'cuda' else -1)
-         print("RoBERTa QA pipeline initialized.")
-     except Exception as e:
-         print(f"Error initializing the RoBERTa model: {e}")
-         models['qa_pipeline'] = None
-
-     # Initialize the Gemma model (loaded here but not referenced elsewhere in this file)
-     try:
-         models['gemma_tokenizer'] = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
-         models['gemma_model'] = AutoModelForCausalLM.from_pretrained(
-             "google/gemma-2-2b-it",
-             device_map="auto",
-             torch_dtype=torch.bfloat16 if device == 'cuda' else torch.float32
-         )
-         print("Gemma model initialized.")
-     except Exception as e:
-         print(f"Error initializing the Gemma model: {e}")
-         models['gemma_model'] = None
-
-     # Initialize the depth estimation model using MarigoldDepthPipeline
-     try:
-         if device == 'cuda':
-             models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
-                 "prs-eth/marigold-depth-lcm-v1-0",
-                 variant="fp16",
-                 torch_dtype=torch.float16
-             ).to('cuda')
-         else:
-             # For CPU or MPS devices, keep the pipeline on 'cpu' to avoid unsupported operators
-             models['depth_pipe'] = MarigoldDepthPipeline.from_pretrained(
-                 "prs-eth/marigold-depth-lcm-v1-0",
-                 torch_dtype=torch.float32
-             ).to('cpu')
-         print("Depth estimation model initialized.")
-     except Exception as e:
-         error_message = f"Error initializing the depth estimation model: {e}"
-         print(error_message)
-         models['depth_pipe'] = None
-         models['depth_init_error'] = error_message  # Surfaced in the UI later
-
-     # Initialize the upscaling model
-     try:
-         upscaler_model_path = 'weights/RealESRGAN_x4plus.pth'  # Ensure this path is correct
-         if not os.path.exists(upscaler_model_path):
-             print(f"Upscaling model weights not found at {upscaler_model_path}. Please download them.")
-             models['upscaler'] = None
-         else:
-             # Define the model architecture
-             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
-                             num_block=23, num_grow_ch=32, scale=4)
-
-             # Initialize RealESRGANer
-             models['upscaler'] = RealESRGANer(
-                 scale=4,
-                 model_path=upscaler_model_path,
-                 model=model,
-                 pre_pad=0,
-                 half=(device == 'cuda'),  # fp16 inference only on CUDA
-                 device=device
-             )
-             print("Real-ESRGAN upscaler initialized.")
-     except Exception as e:
-         print(f"Error initializing the upscaling model: {e}")
-         models['upscaler'] = None
-
-     # Initialize YOLO model
-     try:
-         # NOTE: hardcoded absolute path; adjust for your environment
-         source_weights_path = "/Users/David/Downloads/WheelOfFortuneLab-DavidDriscoll/Eurybia1.3/mbari_315k_yolov8.pt"
-         if not os.path.exists(source_weights_path):
-             print(f"YOLO weights not found at {source_weights_path}. Please download them.")
-             models['yolo_model'] = None
-         else:
-             models['yolo_model'] = YOLO(source_weights_path)
-             print("YOLO model initialized.")
-     except Exception as e:
-         print(f"Error initializing YOLO model: {e}")
-         models['yolo_model'] = None
-
-     return models
-
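- # Models are loaded once at startup; the resulting dict is shared by all sessions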
- models = initialize_models()
-
- # Utility Functions
- def search_class_description(class_name):
-     wikipedia.set_lang("en")
-     wikipedia.set_rate_limiting(True)
-     description = ""
-
-     try:
-         page = wikipedia.page(class_name)
-         if page:
-             description = page.content[:5000]  # First 5000 characters of the article
-     except Exception as e:
-         print(f"Error fetching description for {class_name}: {e}")
-
-     return description
-
- def search_class_image(class_name):
-     wikipedia.set_lang("en")
-     wikipedia.set_rate_limiting(True)
-     img_url = ""
-
-     try:
-         page = wikipedia.page(class_name)
-         if page:
-             # Use the first raster image linked from the article
-             for img in page.images:
-                 if img.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
-                     img_url = img
-                     break
-     except Exception as e:
-         print(f"Error fetching image for {class_name}: {e}")
-
-     return img_url
-
- def process_image(image):
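-     """Run YOLO detection on an uploaded image.
-
-     Returns a 5-tuple: (annotated image, detection summary text, detection
-     details as Markdown, bounding boxes, original unannotated image), with
-     None/empty placeholders on failure.
-     """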
-     if models['yolo_model'] is None:
-         return None, "YOLO model is not initialized.", "YOLO model is not initialized.", [], None
-
-     try:
-         if image is None:
-             return None, "No image uploaded.", "No image uploaded.", [], None
-
-         # Convert the Gradio image to OpenCV format
-         image_np = np.array(image)
-         if image_np.dtype != np.uint8:
-             image_np = image_np.astype(np.uint8)
-
-         if len(image_np.shape) != 3 or image_np.shape[2] != 3:
-             return None, "Invalid image format. Please upload an RGB image.", "Invalid image format. Please upload an RGB image.", [], None
-
-         image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
-
-         # Store the original image before drawing bounding boxes
-         original_image_cv = image_cv.copy()
-         original_image_pil = Image.fromarray(cv2.cvtColor(original_image_cv, cv2.COLOR_BGR2RGB))
-
-         # Perform YOLO prediction with a deliberately low confidence threshold
-         results = models['yolo_model'].predict(
-             source=image_cv, conf=0.075)[0]
-
-         bounding_boxes = []
-         image_processed = image_cv.copy()
-
-         if results.boxes is not None:
-             for box in results.boxes:
-                 x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
-                 class_name = models['yolo_model'].names[int(box.cls)]
-                 confidence = box.conf.item() * 100  # Convert to percentage
-
-                 bounding_boxes.append({
-                     "coords": (x1, y1, x2, y2),
-                     "class_name": class_name,
-                     "confidence": confidence
-                 })
-
-                 cv2.rectangle(image_processed, (x1, y1), (x2, y2), (0, 0, 255), 2)
-                 cv2.putText(image_processed, f'{class_name} {confidence:.2f}%',
-                             (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX,
-                             0.9, (0, 0, 255), 2)
-
-         # Convert back to a PIL image
-         processed_image = Image.fromarray(cv2.cvtColor(image_processed, cv2.COLOR_BGR2RGB))
-
-         # Prepare detection info
-         if bounding_boxes:
-             detection_info = "\n".join(
-                 [f'{box["class_name"]}: {box["confidence"]:.2f}%' for box in bounding_boxes]
-             )
-         else:
-             detection_info = "No detections found."
-
-         # Prepare detection details as Markdown
-         if bounding_boxes:
-             details = ""
-             for idx, box in enumerate(bounding_boxes):
-                 class_name = box['class_name']
-                 confidence = box['confidence']
-                 description = search_class_description(class_name)
-                 img_url = search_class_image(class_name)
-                 img_md = ""
-                 if img_url:
-                     try:
-                         headers = {
-                             'User-Agent': 'MyApp/1.0 (https://example.com/contact; [email protected])'
-                         }
-                         response = requests.get(img_url, headers=headers, timeout=10)
-                         img_data = response.content
-                         img = Image.open(BytesIO(img_data)).convert("RGB")
-                         img.thumbnail((400, 400))  # Resize for faster loading
-                         buffered = BytesIO()
-                         img.save(buffered, format="PNG")
-                         img_str = base64.b64encode(buffered.getvalue()).decode()
-                         img_md = f"![{class_name}](data:image/png;base64,{img_str})\n\n"
-                     except Exception as e:
-                         print(f"Error fetching image for {class_name}: {e}")
-                 details += f"### {idx+1}. {class_name} ({confidence:.2f}%)\n\n"
-                 if description:
-                     details += f"{description}\n\n"
-                 if img_md:
-                     details += f"{img_md}\n\n"
-             detection_details_md = details
-         else:
-             detection_details_md = "No detections to show."
-
-         return processed_image, detection_info, detection_details_md, bounding_boxes, original_image_pil
-     except Exception as e:
-         print(f"Error processing image: {e}")
-         return None, f"Error processing image: {e}", f"Error processing image: {e}", [], None
-
- def ask_eurybia(question, state):
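-     """Answer a question with the extractive QA pipeline, using the Wikipedia
-     descriptions of the currently detected objects as context."""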
-     if not question.strip():
-         return "Please enter a valid question.", state
-
-     if not state['bounding_boxes']:
-         return "No detected objects to ask about.", state
-
-     # Combine descriptions of all detected objects as context
-     context = ""
-     for box in state['bounding_boxes']:
-         description = search_class_description(box['class_name'])
-         if description:
-             context += description + "\n"
-
-     if not context.strip():
-         return "No sufficient context available to answer the question.", state
-
-     try:
-         if models['qa_pipeline'] is None:
-             return "QA pipeline is not initialized.", state
-
-         answer = models['qa_pipeline'](question=question, context=context)
-         answer_text = answer['answer'].strip()
-         if not answer_text:
-             return "I couldn't find an answer to that question based on the detected objects.", state
-         return answer_text, state
-     except Exception as e:
-         print(f"Error during question answering: {e}")
-         return f"Error during question answering: {e}", state
-
- def enhance_image(cropped_image_pil):
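-     """Upscale a cropped detection 4x with Real-ESRGAN; returns (image, status message)."""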
-     if models['upscaler'] is None:
-         return None, "Upscaling model is not initialized."
-
-     try:
-         input_image = cropped_image_pil.convert("RGB")
-         img = np.array(input_image)
-
-         # Run the model to enhance the image
-         output, _ = models['upscaler'].enhance(img, outscale=4)
-
-         enhanced_image = Image.fromarray(output)
-
-         return enhanced_image, "Image enhanced successfully."
-     except Exception as e:
-         print(f"Error during image enhancement: {e}")
-         return None, f"Error during image enhancement: {e}"
-
- def run_depth_prediction(original_image):
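-     """Estimate a depth map for the original (unannotated) image with Marigold."""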
-     if models['depth_pipe'] is None:
-         error_msg = models.get('depth_init_error', "Depth estimation model is not initialized.")
-         return None, error_msg
-
-     try:
-         if original_image is None:
-             return None, "No image uploaded for depth prediction."
-
-         # Prepare the image
-         input_image = original_image.convert("RGB")
-
-         # Run the depth pipeline
-         result = models['depth_pipe'](input_image)
-
-         # The pipeline output exposes the depth map on its `prediction` attribute
-         depth_prediction = result.prediction
-
-         # Visualize the depth map
-         vis_depth = models['depth_pipe'].image_processor.visualize_depth(depth_prediction)
-
-         # visualize_depth returns a list of images; extract the first one
-         if isinstance(vis_depth, list) and len(vis_depth) > 0:
-             vis_depth_image = vis_depth[0]
-         else:
-             vis_depth_image = vis_depth  # Fallback if not a list
-
-         return vis_depth_image, "Depth prediction completed."
-     except Exception as e:
-         print(f"Error during depth prediction: {e}")
-         return None, f"Error during depth prediction: {e}"
-
- # Gradio Interface Components
- with gr.Blocks() as demo:
-     gr.Markdown("# Eurybia Mini - Object Detection and Analysis Tool")
-
-     with gr.Tab("Upload & Process"):
-         with gr.Row():
-             with gr.Column():
-                 image_input = gr.Image(type="pil", label="Upload Image")
-                 process_button = gr.Button("Process Image")
-                 clear_button = gr.Button("Clear")
-             with gr.Column():
-                 processed_image = gr.Image(type="pil", label="Processed Image")
-                 detection_info = gr.Textbox(label="Detection Information", lines=10)
-
-     with gr.Tab("Detection Details"):
-         with gr.Accordion("Click to see detection details", open=False):
-             detection_details_md = gr.Markdown("No detections to show.")
-
-     with gr.Tab("Ask Eurybia"):
-         with gr.Row():
-             with gr.Column():
-                 question_input = gr.Textbox(label="Ask a question about the detected objects")
-                 ask_button = gr.Button("Ask Eurybia")
-             with gr.Column():
-                 answer_output = gr.Markdown(label="Eurybia's Answer")
-
-     with gr.Tab("Depth Estimation"):
-         with gr.Row():
-             with gr.Column():
-                 depth_button = gr.Button("Run Depth Prediction")
-             with gr.Column():
-                 depth_output = gr.Image(type="pil", label="Depth Map")
-                 depth_status = gr.Textbox(label="Status", lines=2)
-
-         # Display an error message if the depth estimation model failed to initialize
-         if models.get('depth_init_error'):
-             gr.Markdown(f"**Depth Estimation Initialization Error:** {models['depth_init_error']}")
-
-     with gr.Tab("Enhance Detected Objects"):
-         if models['yolo_model'] is None or models['upscaler'] is None:
-             gr.Markdown("**Warning:** YOLO model or Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
-         # Create the components unconditionally: the event handlers below reference
-         # detected_objects even when the models failed to load
-         with gr.Row():
-             detected_objects = gr.Dropdown(choices=[], label="Select Detected Object", interactive=True)
-             enhance_btn = gr.Button("Enhance Image")
-         with gr.Column():
-             enhanced_image = gr.Image(type="pil", label="Enhanced Image")
-             enhance_status = gr.Textbox(label="Status", lines=2)
-
-     with gr.Tab("Credits"):
-         gr.Markdown("""
- # Credits and Licensing Information
-
- This project utilizes various open-source libraries, tools, pretrained models, and datasets. Below is the list of components used and their respective credits/licenses:
-
- ## Libraries
- - **Python** - Python Software Foundation License (PSF License)
- - **Gradio** - Licensed under the Apache License 2.0
- - **Torch (PyTorch)** - Licensed under the BSD 3-Clause License
- - **OpenCV (cv2)** - Licensed under the Apache License 2.0
- - **NumPy** - Licensed under the BSD License
- - **Pillow (PIL)** - Licensed under the HPND License
- - **Requests** - Licensed under the Apache License 2.0
- - **wikipedia (Python wrapper)** - Licensed under the MIT License
- - **Transformers** - Licensed under the Apache License 2.0
- - **Diffusers** - Licensed under the Apache License 2.0
- - **Real-ESRGAN** - Licensed under the BSD 3-Clause License
- - **BasicSR** - Licensed under the Apache License 2.0
- - **Ultralytics YOLO** - Licensed under the AGPL-3.0 License
-
- ## Pretrained Models
- - **deepset/roberta-base-squad2 (RoBERTa)** - Model by deepset, hosted on Hugging Face under the CC BY 4.0 License.
- - **google/gemma-2-2b-it** - Model by Google, distributed under the Gemma Terms of Use.
- - **prs-eth/marigold-depth-lcm-v1-0** - Licensed under the Apache License 2.0.
- - **Real-ESRGAN model weights (RealESRGAN_x4plus.pth)** - Distributed under the BSD 3-Clause License.
- - **FathomNet MBARI 315K YOLOv8 Model**:
- - - **Dataset**: Sourced from [FathomNet](https://fathomnet.org).
- - - **Model**: Derived from MBARI’s curated dataset of 315,000 marine annotations.
- - - **License**: Dataset and models adhere to MBARI’s Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0).
-
- ## Datasets
- - **FathomNet MBARI Dataset**:
- - - A large-scale dataset for marine biodiversity image annotations.
- - - All content adheres to the [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/).
-
- ## Acknowledgments
- - **Ultralytics YOLO**: For the YOLOv8 architecture used for object detection.
- - **FathomNet and MBARI**: For providing the marine dataset and annotations that support object detection in underwater imagery.
- - **Gradio**: For providing an intuitive interface for machine learning applications.
- - **Hugging Face**: For pretrained models and pipelines (e.g., Transformers, Diffusers).
- - **Real-ESRGAN**: For image enhancement and upscaling models.
- - **Wikipedia**: For object descriptions and images, fetched via the wikipedia package.
- """)
-
-     # Hidden state to store bounding boxes, original and processed images
-     state = gr.State({"bounding_boxes": [], "last_image": None, "original_image": None})
-
-     # Event Handlers
-     def on_process_image(image, state):
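-         """Run detection, cache the results in session state, and refresh the
-         detected-object dropdown choices."""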
-         processed_img, info, details, bounding_boxes, original_image_pil = process_image(image)
-         if processed_img is not None:
-             # Update the state with new bounding boxes and images
-             state['bounding_boxes'] = bounding_boxes
-             state['last_image'] = processed_img
-             state['original_image'] = original_image_pil
-             # Update the dropdown choices for detected objects
-             choices = [f"{idx+1}. {box['class_name']} ({box['confidence']:.2f}%)" for idx, box in enumerate(bounding_boxes)]
-         else:
-             choices = []
-         return processed_img, info, details, gr.update(choices=choices), state
-
-     process_button.click(
-         on_process_image,
-         inputs=[image_input, state],
-         outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
-     )
-
-     def on_clear(state):
-         state = {"bounding_boxes": [], "last_image": None, "original_image": None}
-         return None, "No detections found.", "No detections to show.", gr.update(choices=[]), state
-
-     clear_button.click(
-         on_clear,
-         inputs=state,
-         outputs=[processed_image, detection_info, detection_details_md, detected_objects, state]
-     )
-
-     def on_ask_eurybia(question, state):
-         answer, state = ask_eurybia(question, state)
-         return answer, state
-
-     ask_button.click(
-         on_ask_eurybia,
-         inputs=[question_input, state],
-         outputs=[answer_output, state]
-     )
-
-     def on_depth_prediction(state):
-         original_image = state.get('original_image')
-         depth_img, status = run_depth_prediction(original_image)
-         return depth_img, status
-
-     depth_button.click(
-         on_depth_prediction,
-         inputs=state,
-         outputs=[depth_output, depth_status]
-     )
-
-     def on_enhance_image(selected_object, state):
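-         """Crop the selected detection out of the processed image and upscale it."""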
-         if not selected_object:
-             return None, "No object selected.", state
-
-         try:
-             # The dropdown label begins with the 1-based detection index, e.g. "2. Fish (81.52%)"
-             idx = int(selected_object.split('.')[0]) - 1
-             box = state['bounding_boxes'][idx]
-             class_name = box['class_name']
-             x1, y1, x2, y2 = box['coords']
-
-             if state.get('last_image') is None:
-                 return None, "Processed image is not available.", state
-
-             # Ensure the processed image stored in state is a PIL image
-             processed_img_pil = state['last_image']
-             if not isinstance(processed_img_pil, Image.Image):
-                 return None, "Processed image is in an unsupported format.", state
-
-             # Convert the processed image to OpenCV format with checks
-             processed_img_cv = np.array(processed_img_pil)
-             if processed_img_cv.dtype != np.uint8:
-                 processed_img_cv = processed_img_cv.astype(np.uint8)
-
-             if len(processed_img_cv.shape) != 3 or processed_img_cv.shape[2] != 3:
-                 return None, "Invalid processed image format.", state
-
-             processed_img_cv = cv2.cvtColor(processed_img_cv, cv2.COLOR_RGB2BGR)
-
-             # Crop the detected object from the processed image
-             cropped_img_cv = processed_img_cv[y1:y2, x1:x2]
-             if cropped_img_cv.size == 0:
-                 return None, "Cropped image is empty.", state
-
-             cropped_img_pil = Image.fromarray(cv2.cvtColor(cropped_img_cv, cv2.COLOR_BGR2RGB))
-
-             # Enhance the cropped image
-             enhanced_img, status = enhance_image(cropped_img_pil)
-             return enhanced_img, status, state
-         except Exception as e:
-             return None, f"Error: {e}", state
-
-     if models['yolo_model'] is not None and models['upscaler'] is not None:
-         enhance_btn.click(
-             on_enhance_image,
-             inputs=[detected_objects, state],
-             outputs=[enhanced_image, enhance_status, state]
-         )
-
-     # Add a note if the depth model isn't initialized (and no specific error was recorded above)
-     if models['depth_pipe'] is None and not models.get('depth_init_error'):
-         gr.Markdown("**Warning:** Depth estimation model is not initialized. Depth prediction functionality will be unavailable.")
-
-     # Add a note if the upscaler isn't initialized
-     if models['upscaler'] is None:
-         gr.Markdown("**Warning:** Upscaling model is not initialized. Image enhancement functionality will be unavailable.")
-
- # Launch the Gradio app
- demo.launch()
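- # launch() serves the app locally by default; options such as share=True (public
- # link) or server_name/server_port can be passed here if needed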