kendrickfff commited on
Commit
05b01da
·
verified ·
1 Parent(s): 9ba42c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
  import gradio as gr
3
- from langchain_google_genai.chat_models import ChatGoogleGenerativeAI # Import for Gemini
4
  from PIL import Image
5
- import json
6
- from transformers import DetrImageProcessor, DetrForObjectDetection
7
  import torch
8
  import requests
 
9
 
10
  # Load credentials (stringified JSON) from environment variable
11
  credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
@@ -25,10 +24,9 @@ os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json"
25
  # Initialize Gemini model
26
  llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
27
 
28
- # Initialize DETR model and processor
29
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
30
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
31
- model.eval()
32
 
33
  # Global chat history variable
34
  chat_history = []
@@ -42,37 +40,28 @@ def chat_with_gemini(message):
42
 
43
  def analyze_image(image_path):
44
  global chat_history
45
-
46
- # Load and preprocess image
47
- image = Image.open(image_path).convert("RGB")
48
- inputs = processor(images=image, return_tensors="pt")
49
-
50
- # Inference
51
- with torch.no_grad():
52
- outputs = model(**inputs)
53
-
54
- # Extract predictions
55
- logits = outputs.logits
56
- boxes = outputs.pred_boxes
57
-
58
- # Filter predictions by high confidence scores
59
- scores = logits.softmax(-1)[0, :, :-1].max(-1).values
60
- high_scores_indices = scores > 0.9 # Adjust the threshold as needed
61
- predicted_classes = logits.softmax(-1)[0, high_scores_indices, :-1].argmax(-1)
62
- predicted_boxes = boxes[0, high_scores_indices].tolist()
63
-
64
- # Map class IDs to labels
65
- labels = [processor.config.id2label[idx.item()] for idx in predicted_classes]
66
-
67
- # Combine predictions
68
- predictions = [{"label": label, "box": box} for label, box in zip(labels, predicted_boxes)]
69
-
70
- # Create response
71
- if predictions:
72
- detected_objects = ', '.join([p["label"] for p in predictions])
73
- bot_response = f"The image contains: {detected_objects}."
74
- else:
75
- bot_response = "No objects with high confidence were detected."
76
 
77
  chat_history.append(("Uploaded an image for analysis", bot_response))
78
  return chat_history
 
1
  import os
2
  import gradio as gr
3
+ from transformers import DetrForObjectDetection, DetrImageProcessor
4
  from PIL import Image
 
 
5
  import torch
6
  import requests
7
+ import json
8
 
9
  # Load credentials (stringified JSON) from environment variable
10
  credentials_string = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
 
24
  # Initialize Gemini model
25
  llm = ChatGoogleGenerativeAI(model='gemini-1.5-pro')
26
 
27
+ # Load the model and processor for DETR
28
  processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
29
  model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
30
 
31
  # Global chat history variable
32
  chat_history = []
 
40
 
41
  def analyze_image(image_path):
42
  global chat_history
43
+ try:
44
+ # Open the image file
45
+ image = Image.open(image_path).convert("RGB")
46
+ # Preprocess the image for DETR
47
+ inputs = processor(images=image, return_tensors="pt")
48
+
49
+ # Perform inference
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ # Get predictions
54
+ target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
55
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
56
+
57
+ # If any objects detected, display labels
58
+ if len(results["labels"]) > 0:
59
+ bot_response = f"Objects detected: {', '.join(map(str, results['labels'].tolist()))}."
60
+ else:
61
+ bot_response = "No objects detected."
62
+
63
+ except Exception as e:
64
+ bot_response = f"Error processing the image: {str(e)}"
 
 
 
 
 
 
 
 
 
65
 
66
  chat_history.append(("Uploaded an image for analysis", bot_response))
67
  return chat_history