dixisouls committed on
Commit 89ee5b3 · 1 Parent(s): 2bbab64

Initial API deployment

Files changed (5)
  1. Dockerfile +33 -0
  2. api.py +134 -0
  3. app/scene_graph_service.py +885 -0
  4. download_model.py +37 -0
  5. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,33 @@
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app/ ./app/
COPY api.py .
COPY download_model.py .

# Create necessary directories
RUN mkdir -p uploads outputs app/models

# Download model on build
RUN python download_model.py

# Expose port for the API
EXPOSE 7860

# Command to run the application
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
api.py ADDED
@@ -0,0 +1,134 @@
import os
import base64
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import shutil
import uuid
import logging
from typing import Dict, List, Any
import json

# Import scene graph service
from app.scene_graph_service import process_image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create necessary directories
os.makedirs("uploads", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
os.makedirs("app/models", exist_ok=True)

# Initialize FastAPI app
app = FastAPI(title="Scene Graph Generation API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def read_root():
    return {
        "message": "Scene Graph Generation API is running",
        "usage": "POST /generate with an image file to generate a scene graph",
        "docs": "Visit /docs for API documentation",
    }


@app.post("/generate")
async def generate_scene_graph(
    image: UploadFile = File(...),
    confidence_threshold: float = Form(0.5),
    use_fixed_boxes: bool = Form(False),
) -> Dict[str, Any]:
    try:
        # Input validation
        if not image.content_type.startswith("image/"):
            raise HTTPException(
                status_code=400, detail="Uploaded file must be an image"
            )

        if not (0 <= confidence_threshold <= 1):
            raise HTTPException(
                status_code=400, detail="Confidence threshold must be between 0 and 1"
            )

        # Generate unique ID for this job
        job_id = str(uuid.uuid4())
        short_id = job_id.split("-")[0]

        # Create directories for this job
        upload_dir = os.path.join("uploads", job_id)
        output_dir = os.path.join("outputs", job_id)
        os.makedirs(upload_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # Save the uploaded image
        original_filename = image.filename
        _, ext = os.path.splitext(original_filename)
        image_filename = f"{short_id}{ext}"
        image_path = os.path.join(upload_dir, image_filename)

        # Save the file
        with open(image_path, "wb") as buffer:
            shutil.copyfileobj(image.file, buffer)

        logger.info(f"Image saved to {image_path}")

        # Define model paths
        model_path = "app/models/model.pth"
        vocabulary_path = "app/models/vocabulary.json"

        # Process the image
        objects, relationships, annotated_image_path, graph_path = process_image(
            image_path=image_path,
            model_path=model_path,
            vocabulary_path=vocabulary_path,
            confidence_threshold=confidence_threshold,
            use_fixed_boxes=use_fixed_boxes,
            output_dir=output_dir,
            base_filename=short_id,
        )

        # Read the generated images as base64
        with open(annotated_image_path, "rb") as img_file:
            annotated_image_base64 = base64.b64encode(img_file.read()).decode("utf-8")

        with open(graph_path, "rb") as img_file:
            graph_image_base64 = base64.b64encode(img_file.read()).decode("utf-8")

        # Prepare response with base64 encoded images
        response = {
            "objects": objects,
            "relationships": relationships,
            "annotated_image": annotated_image_base64,
            "graph_image": graph_image_base64,
        }

        # Clean up
        try:
            shutil.rmtree(upload_dir)
            shutil.rmtree(output_dir)
        except Exception as e:
            logger.warning(f"Error cleaning up temporary files: {str(e)}")

        return response

    except HTTPException:
        # Re-raise validation errors as-is instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")


@app.get("/health")
def health_check():
    return {"status": "healthy"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
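For reference, a minimal client-side sketch of calling the /generate endpoint above and decoding its base64-encoded result images. This is not part of the commit: it assumes the service is running locally on port 7860 and that the `requests` package is installed (it is not in requirements.txt); the image file name and output paths are placeholders.

# Hypothetical client sketch (not in this commit): post an image to /generate.
import base64
import requests

with open("example.jpg", "rb") as f:  # placeholder test image
    resp = requests.post(
        "http://localhost:7860/generate",
        files={"image": ("example.jpg", f, "image/jpeg")},
        data={"confidence_threshold": "0.5", "use_fixed_boxes": "false"},
        timeout=300,
    )
resp.raise_for_status()
result = resp.json()

print(f"{len(result['objects'])} objects, {len(result['relationships'])} relationships")

# The annotated image and the scene graph come back as base64 strings.
with open("annotated.png", "wb") as out:
    out.write(base64.b64decode(result["annotated_image"]))
with open("graph.png", "wb") as out:
    out.write(base64.b64decode(result["graph_image"]))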
app/scene_graph_service.py ADDED
@@ -0,0 +1,885 @@
import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import torchvision.transforms as T
from typing import Dict, List, Tuple, Any, Union, Optional
import logging

# Object detection backend
from ultralytics import YOLO
from math import isclose

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Set random seeds for reproducibility
def set_seeds(seed=42):
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Call this at the start
set_seeds(42)

# Configuration
CONFIG = {
    "img_size": 512,
    "model": {
        "backbone": "resnet50",
        "embedding_dim": 512,
        "hidden_dim": 256,
    },
    "yolo": {
        "model": "yolov8n.pt",  # Using the smallest YOLOv8 model for speed
        "conf": 0.25,  # Default confidence threshold
        "iou": 0.45,  # Default IoU threshold for NMS
    },
}


# Vocabulary class
class Vocabulary:
    """Vocabulary for objects, attributes, and relationships in scene graphs."""

    def __init__(self):
        # Initialize dictionaries for mapping between terms and IDs
        self.object2id = {"<unk>": 0}
        self.id2object = {0: "<unk>"}
        self.relationship2id = {"<unk>": 0}
        self.id2relationship = {0: "<unk>"}
        self.attribute2id = {"<unk>": 0}
        self.id2attribute = {0: "<unk>"}

    def get_object_id(self, obj_name: str) -> int:
        return self.object2id.get(obj_name, 0)  # Return <unk> ID if not found

    def get_relationship_id(self, rel_name: str) -> int:
        return self.relationship2id.get(rel_name, 0)  # Return <unk> ID if not found

    def get_attribute_id(self, attr_name: str) -> int:
        return self.attribute2id.get(attr_name, 0)  # Return <unk> ID if not found

    def get_object_name(self, obj_id: int) -> str:
        return self.id2object.get(obj_id, "<unk>")

    def get_relationship_name(self, rel_id: int) -> str:
        return self.id2relationship.get(rel_id, "<unk>")

    def get_attribute_name(self, attr_id: int) -> str:
        return self.id2attribute.get(attr_id, "<unk>")

    @classmethod
    def load(cls, path: str) -> "Vocabulary":
        """Load vocabulary from a JSON file."""
        vocab = cls()

        with open(path, "r") as f:
            data = json.load(f)

        # Load objects
        vocab.object2id = data["objects"]
        vocab.id2object = {
            int(k): v for k, v in {v: k for k, v in vocab.object2id.items()}.items()
        }

        # Load relationships
        vocab.relationship2id = data["relationships"]
        vocab.id2relationship = {
            int(k): v
            for k, v in {v: k for k, v in vocab.relationship2id.items()}.items()
        }

        # Load attributes
        vocab.attribute2id = data["attributes"]
        vocab.id2attribute = {
            int(k): v for k, v in {v: k for k, v in vocab.attribute2id.items()}.items()
        }

        return vocab


# Model Architecture
class VisualFeatureEncoder(torch.nn.Module):
    """Visual feature encoder for scene graph generation."""

    def __init__(
        self,
        backbone_name: str = "resnet50",
        pretrained: bool = False,
    ):
        super().__init__()

        self.backbone_name = backbone_name
        self.backbone, self.out_channels = self._get_backbone(backbone_name, pretrained)

    def _get_backbone(
        self, backbone_name: str, pretrained: bool
    ) -> Tuple[torch.nn.Module, int]:
        """Get backbone network and output channels."""
        if backbone_name == "resnet50":
            from torchvision.models import resnet50

            backbone = resnet50(pretrained=pretrained)
            # Remove the last FC layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-2])
            out_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")

        return backbone, out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features from images."""
        return self.backbone(x)


class RelationshipPredictor(torch.nn.Module):
    """Predicts relationships between object pairs."""

    def __init__(
        self,
        num_obj_classes: int,
        num_rel_classes: int,
        obj_embed_dim: int = 256,
        rel_embed_dim: int = 256,
        hidden_dim: int = 512,
        dropout: float = 0.2,
    ):
        super().__init__()

        # Object embeddings
        self.obj_embedding = torch.nn.Embedding(num_obj_classes, obj_embed_dim)

        # Spatial feature extractor
        self.spatial_fc = torch.nn.Sequential(
            torch.nn.Linear(10, 64),  # 10 = 4 (subject box) + 4 (object box) + 2 (center offsets)
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
        )

        # Visual feature fusion
        self.visual_fusion = torch.nn.Sequential(
            torch.nn.Linear(obj_embed_dim * 2 + 128, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
        )

        # Relationship classifier
        self.rel_classifier = torch.nn.Linear(hidden_dim, num_rel_classes)

    def forward(
        self,
        obj_features: List[torch.Tensor],
        obj_boxes: List[torch.Tensor],
        obj_pairs: List[torch.Tensor],
    ) -> Dict[str, List[torch.Tensor]]:
        """Forward pass for relationship prediction."""
        results = {}
        all_rel_logits = []

        # Process each example in the batch
        for i, (feats, boxes, pairs) in enumerate(
            zip(obj_features, obj_boxes, obj_pairs)
        ):
            if len(pairs) == 0 or boxes.size(0) == 0:
                # No relationships to predict
                all_rel_logits.append(None)
                continue

            # Extract object classes from boxes
            obj_classes = boxes[:, 4].long()
            obj_embeds = self.obj_embedding(obj_classes)

            # Create pairs of object features
            subj_idx = pairs[:, 0].long()
            obj_idx = pairs[:, 1].long()

            subj_feats = obj_embeds[subj_idx]
            obj_feats = obj_embeds[obj_idx]

            # Spatial features
            subj_boxes = boxes[subj_idx, :4]  # [x_c, y_c, w, h]
            obj_boxes = boxes[obj_idx, :4]  # [x_c, y_c, w, h]

            # Compute relative spatial features
            delta_x = subj_boxes[:, 0] - obj_boxes[:, 0]
            delta_y = subj_boxes[:, 1] - obj_boxes[:, 1]

            # Concatenate spatial features
            spatial_feats = torch.cat(
                [subj_boxes, obj_boxes, delta_x.unsqueeze(1), delta_y.unsqueeze(1)],
                dim=1,
            )

            spatial_feats = self.spatial_fc(spatial_feats)

            # Concatenate subject and object features
            subj_obj_feats = torch.cat([subj_feats, obj_feats, spatial_feats], dim=1)

            # Visual fusion
            fused_feats = self.visual_fusion(subj_obj_feats)

            # Predict relationships
            rel_logits = self.rel_classifier(fused_feats)
            all_rel_logits.append(rel_logits)

        results["rel_logits"] = all_rel_logits
        return results


class SceneGraphGenerationModel(torch.nn.Module):
    """Complete scene graph generation model."""

    def __init__(
        self,
        backbone: torch.nn.Module,
        num_obj_classes: int,
        num_rel_classes: int,
        num_attr_classes: int,
        roi_size: int = 7,
        embedding_dim: int = 512,
        hidden_dim: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.backbone = backbone
        self.num_obj_classes = num_obj_classes
        self.num_rel_classes = num_rel_classes
        self.num_attr_classes = num_attr_classes  # needed by the empty-detection fallback in forward()

        # RoI pooling for object features
        self.roi_size = roi_size
        self.roi_pool = torch.nn.AdaptiveAvgPool2d((roi_size, roi_size))

        # Object feature embedding
        self.obj_feature_embedding = torch.nn.Sequential(
            torch.nn.Linear(backbone.out_channels * roi_size * roi_size, embedding_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
        )

        # Object classifier
        self.obj_classifier = torch.nn.Linear(embedding_dim, num_obj_classes)

        # Attribute classifier
        self.attr_classifier = torch.nn.Linear(embedding_dim, num_attr_classes)

        # Bounding box regressor
        self.bbox_regressor = torch.nn.Linear(embedding_dim, 4)  # [x_c, y_c, w, h]

        # Relationship predictor
        self.relationship_predictor = RelationshipPredictor(
            num_obj_classes=num_obj_classes,
            num_rel_classes=num_rel_classes,
            obj_embed_dim=embedding_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

    def extract_roi_features(
        self,
        features: torch.Tensor,  # [batch_size, channels, height, width]
        boxes: List[torch.Tensor],  # List of [num_boxes, 4] tensors with normalized boxes
    ) -> List[torch.Tensor]:
        """Extract RoI features for objects."""
        batch_size = features.shape[0]
        roi_features = []

        for i in range(batch_size):
            if len(boxes[i]) == 0:
                # No objects in this image
                roi_features.append(
                    torch.empty(
                        0,
                        self.backbone.out_channels * self.roi_size**2,
                        device=features.device,
                    )
                )
                continue

            # Convert normalized [x_c, y_c, w, h] to [x1, y1, x2, y2]
            bbox = boxes[i][:, :4]
            x_c, y_c, w, h = bbox[:, 0], bbox[:, 1], bbox[:, 2], bbox[:, 3]
            x1 = (x_c - w / 2) * features.shape[3]
            y1 = (y_c - h / 2) * features.shape[2]
            x2 = (x_c + w / 2) * features.shape[3]
            y2 = (y_c + h / 2) * features.shape[2]

            # Ensure boxes are within image
            x1 = torch.clamp(x1, 0, features.shape[3] - 1)
            y1 = torch.clamp(y1, 0, features.shape[2] - 1)
            x2 = torch.clamp(x2, 0, features.shape[3] - 1)
            y2 = torch.clamp(y2, 0, features.shape[2] - 1)

            # Create RoI boxes for torchvision's RoIPool
            rois = torch.stack([x1, y1, x2, y2], dim=1)

            # Extract features for each RoI
            obj_features = []
            for roi in rois:
                x1, y1, x2, y2 = map(int, roi.cpu().numpy())
                # Ensure valid box dimensions
                if x2 <= x1 or y2 <= y1:
                    roi_feat = torch.zeros(
                        self.backbone.out_channels,
                        self.roi_size,
                        self.roi_size,
                        device=features.device,
                    )
                else:
                    # Extract feature for this ROI
                    roi_feat = self.roi_pool(
                        features[i, :, y1:y2, x1:x2].unsqueeze(0)
                    ).squeeze(0)

                # Flatten the feature
                roi_feat = roi_feat.view(-1)
                obj_features.append(roi_feat)

            if obj_features:
                obj_features = torch.stack(obj_features)
            else:
                obj_features = torch.empty(
                    0,
                    self.backbone.out_channels * self.roi_size**2,
                    device=features.device,
                )

            roi_features.append(obj_features)

        return roi_features

    def forward(
        self, images: torch.Tensor, boxes: List[torch.Tensor]
    ) -> Dict[str, Any]:
        """Forward pass for scene graph generation."""
        batch_size = images.shape[0]

        # Extract features from backbone
        features = self.backbone(images)

        # Extract RoI features
        roi_features = self.extract_roi_features(features, boxes)

        # Process each example in the batch
        obj_logits_list = []
        attr_logits_list = []
        bbox_pred_list = []
        obj_features_list = []

        for i in range(batch_size):
            if roi_features[i].shape[0] == 0:
                # No objects in this image
                obj_logits_list.append(
                    torch.empty(0, self.num_obj_classes, device=images.device)
                )
                attr_logits_list.append(
                    torch.empty(0, self.num_attr_classes, device=images.device)
                )
                bbox_pred_list.append(torch.empty(0, 4, device=images.device))
                obj_features_list.append(
                    torch.empty(
                        0,
                        self.obj_feature_embedding[0].out_features,
                        device=images.device,
                    )
                )
                continue

            # Embed RoI features
            obj_feats = self.obj_feature_embedding(roi_features[i])
            obj_features_list.append(obj_feats)

            # Predict object classes
            obj_logits = self.obj_classifier(obj_feats)
            obj_logits_list.append(obj_logits)

            # Predict attributes
            attr_logits = self.attr_classifier(obj_feats)
            attr_logits_list.append(attr_logits)

            # Regress bounding box refinements
            bbox_pred = self.bbox_regressor(obj_feats)
            bbox_pred_list.append(bbox_pred)

        # Create object pairs for relationship prediction
        obj_pairs = []
        for i in range(batch_size):
            if boxes[i].shape[0] <= 1:
                # Need at least 2 objects for relationships
                obj_pairs.append(torch.empty(0, 2, device=images.device))
                continue

            # Create all possible object pairs
            num_objs = boxes[i].shape[0]
            subj_idx = torch.arange(num_objs, device=images.device).repeat_interleave(
                num_objs
            )
            obj_idx = torch.arange(num_objs, device=images.device).repeat(num_objs)

            # Exclude self-relationships
            mask = subj_idx != obj_idx
            pairs = torch.stack([subj_idx[mask], obj_idx[mask]], dim=1)
            obj_pairs.append(pairs)

        # Predict relationships
        rel_preds = self.relationship_predictor(obj_features_list, boxes, obj_pairs)

        return {
            "obj_logits": obj_logits_list,
            "attr_logits": attr_logits_list,
            "bbox_pred": bbox_pred_list,
            "rel_logits": rel_preds.get("rel_logits", []),
            "obj_pairs": obj_pairs,
        }


# YOLO-based object detection
def detect_objects_yolo(
    image_path: str,
    vocabulary: Vocabulary,
    device: torch.device,
    use_fixed_boxes: bool = False,
) -> torch.Tensor:
    """
    Detect objects in an image using YOLOv8.

    Args:
        image_path: Path to the input image
        vocabulary: Vocabulary for mapping class names
        device: PyTorch device
        use_fixed_boxes: Whether to use fixed boxes or YOLO detection

    Returns:
        Bounding boxes in format [x_c, y_c, w, h, class_id] (normalized)
    """
    # Load YOLOv8 model - will download if not present
    yolo_model = YOLO(CONFIG["yolo"]["model"])

    # Run inference
    results = yolo_model(image_path)
    detections = results[0]

    # No detections
    if len(detections.boxes) == 0:
        return torch.zeros((0, 5), device=device, dtype=torch.float32)

    # Process detections
    boxes = []

    # Get image dimensions
    img = Image.open(image_path)
    img_width, img_height = img.size

    # YOLO class names (COCO class names)
    yolo_class_names = yolo_model.names

    # Create class name mapping from YOLO to our vocabulary
    class_name_map = {}
    for yolo_id, yolo_name in yolo_class_names.items():
        # Try direct mapping first
        if yolo_name in vocabulary.object2id:
            class_name_map[yolo_id] = vocabulary.get_object_id(yolo_name)
        # Try lowercase
        elif yolo_name.lower() in vocabulary.object2id:
            class_name_map[yolo_id] = vocabulary.get_object_id(yolo_name.lower())
        # Fall back to <unk>
        else:
            class_name_map[yolo_id] = 0  # <unk>

    # Process each detection
    for i in range(len(detections.boxes)):
        box = detections.boxes[i]

        # Get class ID and confidence
        cls_id = int(box.cls.item())
        confidence = box.conf.item()

        # Skip low-confidence detections
        if confidence < CONFIG["yolo"]["conf"]:
            continue

        # Get bounding box in xyxy format (unnormalized)
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

        # Convert to xywh format and normalize
        x_c = ((x1 + x2) / 2) / img_width
        y_c = ((y1 + y2) / 2) / img_height
        w = (x2 - x1) / img_width
        h = (y2 - y1) / img_height

        # Map class ID to vocabulary
        vocab_cls_id = class_name_map.get(cls_id, 0)  # Default to <unk> if not found

        # Add to boxes
        boxes.append([x_c, y_c, w, h, vocab_cls_id])

    # Convert to tensor with explicit float32 dtype
    if boxes:
        return torch.tensor(boxes, device=device, dtype=torch.float32)
    else:
        return torch.zeros((0, 5), device=device, dtype=torch.float32)


# Visualization functions
def visualize_image_with_boxes(
    image: np.ndarray, objects: List[Dict[str, Any]], output_path: str
) -> None:
    """Visualize image with bounding boxes and labels."""
    # Create figure
    plt.figure(figsize=(10, 8))

    # Display image
    plt.imshow(image)

    # Get image dimensions
    img_height, img_width = image.shape[:2]

    # Generate colors for classes
    num_classes = len(objects)
    colors = plt.cm.hsv(np.linspace(0, 1, num_classes))

    # Draw bounding boxes and labels
    for i, obj in enumerate(objects):
        # Get bounding box
        x_c, y_c, w, h = obj["bbox"]

        # Scale to image size if normalized
        if max(x_c, y_c, w, h) <= 1.0:
            x_c *= img_width
            y_c *= img_height
            w *= img_width
            h *= img_height

        # Convert to (x1, y1, x2, y2) format
        x1 = x_c - w / 2
        y1 = y_c - h / 2
        x2 = x_c + w / 2
        y2 = y_c + h / 2

        # Draw bounding box
        rect = plt.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=2,
            edgecolor=colors[i % len(colors)],
            facecolor="none",
        )
        plt.gca().add_patch(rect)

        # Draw label
        plt.text(
            x1,
            y1 - 5,
            f"{obj['label']} ({obj['score']:.2f})",
            color=colors[i % len(colors)],
            fontsize=10,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=1),
        )

    # Add a title
    plt.title("Object Detection")
    plt.axis("off")

    # Save the figure
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    logger.info(f"Annotated image saved to {output_path}")


def visualize_graph(
    objects: List[Dict[str, Any]], relationships: List[Dict[str, Any]], output_path: str
) -> None:
    """Visualize relationship graph."""
    # Create figure
    plt.figure(figsize=(10, 8))

    # Create graph
    G = nx.DiGraph()

    # Add nodes
    for i, obj in enumerate(objects):
        G.add_node(i, label=obj["label"])

    # Add edges
    for rel in relationships:
        subj_idx = rel["subject_id"]
        obj_idx = rel["object_id"]
        G.add_edge(subj_idx, obj_idx, label=rel["predicate"])

    # Position nodes
    pos = nx.spring_layout(G, seed=42)

    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_size=700, node_color="skyblue", alpha=0.8)

    # Draw node labels
    nx.draw_networkx_labels(G, pos, font_size=10, font_weight="bold")

    # Draw edges
    nx.draw_networkx_edges(G, pos, width=2, alpha=0.7, arrows=True, arrowsize=15)

    # Draw edge labels
    nx.draw_networkx_edge_labels(
        G, pos, edge_labels=nx.get_edge_attributes(G, "label"), font_size=8
    )

    # Add a title
    plt.title("Scene Graph")
    plt.axis("off")

    # Save the figure
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    logger.info(f"Graph visualization saved to {output_path}")


def process_image(
    image_path: str,
    model_path: str,
    vocabulary_path: str,
    confidence_threshold: float = 0.5,
    use_fixed_boxes: bool = False,
    output_dir: str = "outputs",
    base_filename: Optional[str] = None,
) -> Tuple[List, List, str, str]:
    """
    Process an image to generate a scene graph.

    Args:
        image_path: Path to the input image
        model_path: Path to the model checkpoint
        vocabulary_path: Path to the vocabulary file
        confidence_threshold: Confidence threshold for relationships
        use_fixed_boxes: Whether to use fixed boxes or YOLO detection
        output_dir: Directory to save outputs
        base_filename: Optional base filename to use instead of the original image name

    Returns:
        Tuple of (objects, relationships, annotated_image_path, graph_path)
    """
    # Check if files exist
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}")

    if not os.path.exists(vocabulary_path):
        raise FileNotFoundError(f"Vocabulary not found at {vocabulary_path}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load vocabulary
    vocabulary = Vocabulary.load(vocabulary_path)
    logger.info(
        f"Loaded vocabulary with {len(vocabulary.object2id)} objects and {len(vocabulary.relationship2id)} relationships"
    )

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    img_width, img_height = image.size

    # Use YOLO for object detection
    logger.info("Detecting objects with YOLO...")
    boxes = detect_objects_yolo(image_path, vocabulary, device, use_fixed_boxes)
    logger.info(f"Detected {len(boxes)} objects")

    if len(boxes) == 0:
        raise ValueError("No objects detected. Cannot generate scene graph.")

    # Create encoder
    encoder = VisualFeatureEncoder(backbone_name=CONFIG["model"]["backbone"])

    # Create model
    model = SceneGraphGenerationModel(
        backbone=encoder,
        num_obj_classes=len(vocabulary.object2id),
        num_rel_classes=len(vocabulary.relationship2id),
        num_attr_classes=len(vocabulary.attribute2id),
        embedding_dim=CONFIG["model"]["embedding_dim"],
        hidden_dim=CONFIG["model"]["hidden_dim"],
    )

    # Load model weights
    logger.info(f"Loading model from {model_path}...")
    checkpoint = torch.load(model_path, map_location=device)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        logger.info("Loaded model state dict from checkpoint")
    else:
        model.load_state_dict(checkpoint)
        logger.info("Loaded direct model state from checkpoint")

    model.to(device)
    model.eval()

    # Preprocess image for scene graph model
    transform = T.Compose(
        [
            T.Resize((CONFIG["img_size"], CONFIG["img_size"])),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    img_tensor = transform(image).unsqueeze(0).to(device)

    # Run inference for scene graph generation
    logger.info("Generating scene graph...")
    with torch.no_grad():
        # Forward pass
        outputs = model(img_tensor, [boxes])

    # Process predictions
    obj_logits = outputs["obj_logits"][0]
    obj_probs = torch.softmax(obj_logits, dim=1)
    obj_scores, obj_labels = torch.max(obj_probs, dim=1)

    # Get bounding box predictions
    bbox_pred = outputs["bbox_pred"][0]

    # Create object list
    objects = []
    for i in range(len(obj_labels)):
        bbox = bbox_pred[i].cpu().numpy().tolist()
        label_id = obj_labels[i].item()
        score = obj_scores[i].item()

        objects.append(
            {
                "label": vocabulary.get_object_name(label_id),
                "label_id": label_id,
                "score": score,
                "bbox": bbox,
            }
        )

    # Process relationships
    relationships = []
    if "rel_logits" in outputs and outputs["rel_logits"]:
        rel_logits = outputs["rel_logits"][0]
        obj_pairs = outputs["obj_pairs"][0]

        if rel_logits is not None and len(rel_logits) > 0:
            rel_probs = torch.softmax(rel_logits, dim=1)
            rel_scores, rel_labels = torch.max(rel_probs, dim=1)

            # Filter by confidence
            rel_mask = rel_scores > confidence_threshold
            rel_labels = rel_labels[rel_mask]
            rel_scores = rel_scores[rel_mask]
            filtered_pairs = obj_pairs[rel_mask]

            # Create relationship list
            for i in range(len(rel_labels)):
                subj_idx = filtered_pairs[i, 0].item()
                obj_idx = filtered_pairs[i, 1].item()
                label_id = rel_labels[i].item()
                score = rel_scores[i].item()

                # Map to filtered object indices
                subj_new_idx = -1
                obj_new_idx = -1

                for j, obj in enumerate(objects):
                    if j == subj_idx:
                        subj_new_idx = j
                    if j == obj_idx:
                        obj_new_idx = j

                if subj_new_idx != -1 and obj_new_idx != -1:
                    relationships.append(
                        {
                            "subject_id": subj_new_idx,
                            "object_id": obj_new_idx,
                            "predicate": vocabulary.get_relationship_name(label_id),
                            "predicate_id": label_id,
                            "score": score,
                            "subject": objects[subj_new_idx]["label"],
                            "object": objects[obj_new_idx]["label"],
                        }
                    )

    # Determine base filename for output files
    if base_filename:
        # Use provided base filename if specified
        file_prefix = base_filename
    else:
        # Otherwise use the original image name
        file_prefix = os.path.splitext(os.path.basename(image_path))[0]

    # Generate output filenames with a consistent naming pattern
    annotated_image_path = os.path.join(output_dir, f"{file_prefix}_annotated.png")
    graph_path = os.path.join(output_dir, f"{file_prefix}_graph.png")

    # Log the paths for debugging
    logger.info(f"Using file prefix: {file_prefix}")
    logger.info(f"Saving annotated image to: {annotated_image_path}")
    logger.info(f"Saving graph to: {graph_path}")

    # Save visualizations
    visualize_image_with_boxes(np.array(image), objects, annotated_image_path)
    visualize_graph(objects, relationships, graph_path)

    logger.info("Visualization complete. Files saved to:")
    logger.info(f" - {annotated_image_path}")
    logger.info(f" - {graph_path}")

    # Convert objects for JSON serialization
    serializable_objects = []
    for obj in objects:
        serializable_objects.append(
            {
                "label": obj["label"],
                "label_id": int(obj["label_id"]),
                "score": float(obj["score"]),
                "bbox": [float(val) for val in obj["bbox"]],
            }
        )

    return serializable_objects, relationships, annotated_image_path, graph_path


if __name__ == "__main__":
    # This can be used for testing the service directly
    image_path = "test.jpg"
    model_path = "app/models/model.pth"
    vocabulary_path = "app/models/vocabulary.json"

    objects, relationships, annotated_path, graph_path = process_image(
        image_path=image_path,
        model_path=model_path,
        vocabulary_path=vocabulary_path,
        confidence_threshold=0.3,
        output_dir="outputs",
    )

    print(f"Processed {len(objects)} objects and {len(relationships)} relationships")
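For reference, an illustrative sketch of the vocabulary.json layout that Vocabulary.load expects: three name-to-id mappings ("objects", "relationships", "attributes"), each with "<unk>" at id 0 to mirror the class constructor. This file is not part of the commit, and the entries below are placeholders rather than the real class lists shipped in dixisouls/scene-graph-model.

# Hypothetical vocabulary.json sketch (placeholder entries only).
import json

example_vocabulary = {
    "objects": {"<unk>": 0, "person": 1, "dog": 2},
    "relationships": {"<unk>": 0, "holding": 1, "next to": 2},
    "attributes": {"<unk>": 0, "red": 1, "small": 2},
}

with open("vocabulary_example.json", "w") as f:
    json.dump(example_vocabulary, f, indent=2)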
download_model.py ADDED
@@ -0,0 +1,37 @@
import os
import sys
from huggingface_hub import hf_hub_download
import shutil


def download_models():
    print("Downloading model files...")

    # Create directories if they don't exist
    os.makedirs("app/models", exist_ok=True)

    try:
        # Download the model and vocabulary from Hugging Face
        model_path = hf_hub_download(
            repo_id="dixisouls/scene-graph-model",
            filename="model.pth",
            repo_type="model",
        )
        vocab_path = hf_hub_download(
            repo_id="dixisouls/scene-graph-model",
            filename="vocabulary.json",
            repo_type="model",
        )

        # Copy the downloaded files to the app/models directory
        shutil.copy(model_path, "app/models/model.pth")
        shutil.copy(vocab_path, "app/models/vocabulary.json")

        print("Model downloaded successfully to app/models/model.pth")
        print("Vocabulary downloaded successfully to app/models/vocabulary.json")

    except Exception as e:
        print(f"Error downloading model files: {e}")
        sys.exit(1)


if __name__ == "__main__":
    download_models()
requirements.txt ADDED
@@ -0,0 +1,11 @@
fastapi==0.104.1
uvicorn==0.24.0
torch==2.0.1
torchvision==0.15.2
numpy==1.24.3
Pillow==10.0.1
matplotlib==3.7.2
networkx==3.1
ultralytics==8.0.196
python-multipart==0.0.6
huggingface_hub==0.17.3