kendrickfff committed
Commit 603c4d7 · verified · 1 Parent(s): b05e484

Update app.py

Files changed (1):
  1. app.py +116 -49
app.py CHANGED
@@ -1,12 +1,35 @@
 import os
 import gradio as gr
-import torch
 from transformers import DetrImageProcessor, DetrForObjectDetection
 from PIL import Image
-import requests
 import json

-# Custom Object Labels
 COCO_CLASSES = [
     'airplane', 'apple', 'backpack', 'banana', 'baseball hat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle',
     'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair',
@@ -18,64 +41,108 @@ COCO_CLASSES = [
     'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass'
 ]

-# Load the DETR model and processor
-model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
-processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")

-# Initialize Gradio interface
 def analyze_image(image_path):
     try:
-        # Open the image
-        image = Image.open(image_path)
-
-        # Preprocess the image
         inputs = processor(images=image, return_tensors="pt")

-        # Perform object detection
-        outputs = model(**inputs)
-
-        # Get the logits (class predictions) and boxes (bounding boxes)
-        logits = outputs.logits
-        boxes = outputs.pred_boxes
-
-        # Get the predicted labels (class IDs)
-        class_ids = logits.argmax(-1)
-
-        # Filter out detections with low confidence and map to custom labels
-        results = []
-        for idx, class_id in enumerate(class_ids[0]):
-            confidence = logits[0, idx, class_id].item()
-            if confidence > 0.5: # Confidence threshold
-                label = COCO_CLASSES[class_id]
-                box = boxes[0, idx].tolist()
-                results.append({
-                    'label': label,
-                    'confidence': confidence,
-                    'box': box
-                })
-
-        if len(results) == 0:
-            return "No objects detected."
-
-        # Generate a response with the detected objects
-        detected_objects = "\n".join([f"{result['label']} (confidence: {result['confidence']:.2f})" for result in results])
-        return f"Detected Objects:\n{detected_objects}"

-    except Exception as e:
-        return f"Error processing the image: {str(e)}"


-# Gradio Interface Setup
 with gr.Blocks() as demo:
-    gr.Markdown("## Object Detection with Custom Labels")
-    gr.Markdown("Upload an image for analysis!")

     # User input components
     img_upload = gr.Image(type="filepath", label="Upload an image for analysis")
-    output_text = gr.Textbox(label="Detection Results", interactive=False)

-    # Define the interaction
-    img_upload.change(analyze_image, img_upload, output_text)

-# Launch the interface
 demo.launch()
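The removed code above thresholds DETR's raw, unnormalized logits and then indexes the alphabetically ordered COCO_CLASSES list with the predicted class id, which is not the id-to-name mapping facebook/detr-resnet-50 actually uses. The updated file, shown next with its added lines marked +, switches to the processor's post_process_object_detection helper instead. As a minimal standalone sketch of that decoding path (not part of the commit), using the model's built-in id2label map rather than a hand-maintained list, with "example.jpg" as a hypothetical test image:

```python
# Minimal sketch (not part of the commit): decode DETR outputs with the
# processor's post-processing helper and the model's own id2label map.
# "example.jpg" is a hypothetical test image.
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

image = Image.open("example.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Rescale boxes to the original image size and keep detections above 0.5 confidence.
target_sizes = torch.tensor([image.size[::-1]])  # (height, width)
detections = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=target_sizes
)[0]

for score, label, box in zip(detections["scores"], detections["labels"], detections["boxes"]):
    # id2label maps DETR's label ids directly to COCO category names.
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")
```

post_process_object_detection converts the logits to per-class probabilities, drops the "no object" class, and rescales the boxes to pixel coordinates, so only the score threshold is left to choose.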
 
 import os
 import gradio as gr
 from transformers import DetrImageProcessor, DetrForObjectDetection
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import Gemini
 from PIL import Image
+import torch
 import json
+import requests
+
+# Load credentials (stringified JSON) from environment variable for Gemini
+credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
+if not credentials_string:
+    raise ValueError("GOOGLE_APPLICATION_CREDENTIALS is not set in the environment!")
+
+# Parse the stringified JSON back to a Python dictionary
+credentials = json.loads(credentials_string)
+
+# Save the credentials to a temporary JSON file (required by Google SDKs)
+with open("service_account.json", "w") as f:
+    json.dump(credentials, f)

+# Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the temporary file
+os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json"
+
+# Initialize Gemini model (chatbot)
+llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
+
+# Initialize DETR model and processor for object detection
+processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
+model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
+
+# Load COCO class labels (from the official COCO dataset)
 COCO_CLASSES = [
     'airplane', 'apple', 'backpack', 'banana', 'baseball hat', 'baseball glove', 'bear', 'bed', 'bench', 'bicycle',
     'bird', 'boat', 'book', 'bottle', 'bowl', 'broccoli', 'bus', 'cake', 'car', 'carrot', 'cat', 'cell phone', 'chair',
     'traffic light', 'train', 'truck', 'tv', 'umbrella', 'vase', 'wine glass'
 ]

+# Global chat history variable
+chat_history = []
+
+# Function for chatting with Gemini
+def chat_with_gemini(message):
+    global chat_history
+    bot_response = llm.predict(message) # This will interact with the Gemini model
+    chat_history.append((message, bot_response))
+    return chat_history

+# Function for analyzing the uploaded image
 def analyze_image(image_path):
+    global chat_history
     try:
+        # Open and preprocess the image
+        image = Image.open(image_path).convert("RGB")
         inputs = processor(images=image, return_tensors="pt")

+        # Perform inference
+        with torch.no_grad():
+            outputs = model(**inputs)

+        # Set a target size for post-processing
+        target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
+        results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+        # Collect detected objects (with no minimum confidence filter)
+        detected_objects = []
+        for idx, label in enumerate(results["labels"]):
+            # Get the object label based on label index
+            object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
+            score = results["scores"][idx].item() # Confidence score for this detection
+
+            # Store only objects with a score higher than a threshold (e.g., 0.1)
+            if score > 0.1:
+                detected_objects.append(f"{object_name} (score: {score:.2f})")

+        if detected_objects:
+            bot_response = f"Objects detected: {', '.join(detected_objects)}."
+        else:
+            bot_response = "No objects detected."

+        chat_history.append(("Uploaded an image for analysis", bot_response))
+        return chat_history
+    except Exception as e:
+        error_msg = f"Error processing the image: {str(e)}"
+        chat_history.append(("Error during image analysis", error_msg))
+        return chat_history
+
+# Build the Gradio interface
 with gr.Blocks() as demo:
+    gr.Markdown("# Ken Chatbot")
+    gr.Markdown("Ask me anything or upload an image for analysis!")
+
+    # Chatbot display without "User" or "Bot" labels
+    chatbot = gr.Chatbot(elem_id="chatbot")

     # User input components
+    msg = gr.Textbox(label="Type your message here...", placeholder="Enter your message...", show_label=False)
+    send_btn = gr.Button("Send")
     img_upload = gr.Image(type="filepath", label="Upload an image for analysis")

+    # Define interactions
+    def handle_text_message(message):
+        return chat_with_gemini(message)
+
+    def handle_image_upload(image_path):
+        return analyze_image(image_path)
+
+    # Set up Gradio components with Enter key for sending
+    msg.submit(handle_text_message, msg, chatbot)
+    send_btn.click(handle_text_message, msg, chatbot)
+    send_btn.click(lambda: "", None, msg) # Clear input field
+    img_upload.change(handle_image_upload, img_upload, chatbot)
+
+    # Custom CSS for styling without usernames
+    gr.HTML("""
+    <style>
+    #chatbot .message-container {
+        display: flex;
+        flex-direction: column;
+        margin-bottom: 10px;
+        max-width: 70%;
+    }
+    #chatbot .message {
+        border-radius: 15px;
+        padding: 10px;
+        margin: 5px 0;
+        word-wrap: break-word;
+    }
+    #chatbot .message.user {
+        background-color: #DCF8C6;
+        margin-left: auto;
+        text-align: right;
+    }
+    #chatbot .message.bot {
+        background-color: #E1E1E1;
+        margin-right: auto;
+        text-align: left;
+    }
+    </style>
+    """)

+# Launch the Gradio interface
 demo.launch()
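Note that the updated script expects GOOGLE_APPLICATION_CREDENTIALS to hold the service-account JSON itself as a single string (as a Space secret would), not a file path; it then rewrites the variable to point at the service_account.json file it writes. A minimal sketch of reproducing that setup for a local run, with "key.json" as a hypothetical local key file:

```python
# Minimal local sketch (not part of the commit): expose a downloaded Google
# service-account key the way the Space secret does, i.e. as one JSON string.
# "key.json" is a hypothetical path to a local service-account key file.
import json
import os
import subprocess

with open("key.json") as f:
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = json.dumps(json.load(f))

# app.py inherits the variable, parses the string, and re-points it to the
# temporary service_account.json file it writes for the Google SDKs.
subprocess.run(["python", "app.py"], check=True)
```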