tcm03 committed
Commit 8d4eb6b · 1 Parent(s): 8cda892

Update custom handler

Files changed (2):
  1. handler.py +51 -27
  2. test.py +2 -2
handler.py CHANGED
@@ -14,6 +14,39 @@ from clip.clip import _transform, tokenize
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+def preprocess_image(image_base64, transformer):
+    """Convert a base64-encoded sketch to a tensor."""
+    image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
+    image = transformer(image).unsqueeze(0).to(device)
+    return image
+
+def preprocess_text(text):
+    """Tokenize a text query."""
+    return tokenize([str(text)])[0].unsqueeze(0).to(device)
+
+def get_fused_embedding(sketch_base64, text, model, transformer):
+    """Fuse sketch and text features into a single embedding."""
+    with torch.no_grad():
+        sketch_tensor = preprocess_image(sketch_base64, transformer)
+        text_tensor = preprocess_text(text)
+
+        sketch_feature = model.encode_sketch(sketch_tensor)
+        text_feature = model.encode_text(text_tensor)
+
+        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
+        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
+
+        fused_embedding = model.feature_fuse(sketch_feature, text_feature)
+    return fused_embedding.cpu().numpy().tolist()
+
+def get_image_embedding(image_base64, model, transformer):
+    """Encode a base64-encoded image into a normalized embedding."""
+    image_tensor = preprocess_image(image_base64, transformer)
+    with torch.no_grad():
+        image_feature = model.encode_image(image_tensor)
+        image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
+    return image_feature.cpu().numpy().tolist()
+
 class EndpointHandler:
     def __init__(self, path: str = ""):
         """
@@ -43,35 +76,26 @@ class EndpointHandler:
         """
         Process the request and return the fused embedding.
         Args:
-            data (dict): Includes 'image' (base64) and 'text' (str) inputs.
+            data (dict): Includes 'sketch' (base64) and 'text' (str) inputs, or 'image' (base64).
        Returns:
-            dict: {"fused_embedding": [float, float, ...]}
+            dict: {"embedding": [float, float, ...]}
         """
         # Parse inputs
         inputs = data.pop("inputs", data)
-        image_base64 = inputs.get("image", "")
-        text_query = inputs.get("text", "")
-
-        if not image_base64 or not text_query:
-            return {"error": "Both 'image' (base64) and 'text' are required inputs."}
-
-        # Preprocess the image
-        image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
-        image_tensor = self.transform(image).unsqueeze(0).to(device)
-
-        # Preprocess the text
-        text_tensor = tokenize([str(text_query)])[0].unsqueeze(0).to(device)
-
-        # Generate features
-        with torch.no_grad():
-            sketch_feature = self.model.encode_sketch(image_tensor)
-            text_feature = self.model.encode_text(text_tensor)
-
-        # Normalize features
-        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
-        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
-
-        # Fuse features
-        fused_embedding = self.model.feature_fuse(sketch_feature, text_feature)
+        if "sketch" in inputs:
+            sketch_base64 = inputs.get("sketch", "")
+            text_query = inputs.get("text", "")
+            if not sketch_base64 or not text_query:
+                return {"error": "Both 'sketch' (base64) and 'text' are required inputs."}
 
-        return {"fused_embedding": fused_embedding.cpu().numpy().tolist()}
+            # Generate the fused embedding
+            fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
+            return {"embedding": fused_embedding}
+        elif "image" in inputs:
+            image_base64 = inputs.get("image", "")
+            if not image_base64:
+                return {"error": "Input 'image' (base64) is required."}
+            embedding = get_image_embedding(image_base64, self.model, self.transform)
+            return {"embedding": embedding}
+        else:
+            return {"error": "Input 'sketch' or 'image' is required."}
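For context, a minimal sketch of the two request shapes the updated handler now accepts. It assumes the handler is invoked as a callable (the standard custom-handler convention, matching test.py) and that the model weights referenced in __init__ are available locally; the generated white image is a placeholder, not real data.

# Sketch + text -> fused embedding; image only -> normalized image embedding.
import base64
from io import BytesIO
from PIL import Image
from handler import EndpointHandler

# Placeholder 224x224 white image standing in for a real sketch or photo.
buf = BytesIO()
Image.new("RGB", (224, 224), "white").save(buf, format="JPEG")
placeholder_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

handler = EndpointHandler(path=".")

# Sketch + text branch, returns {"embedding": [[...]]}
print(handler({"inputs": {"sketch": placeholder_b64, "text": "A pink flower"}}))

# Image-only branch, also returns {"embedding": [[...]]}
print(handler({"inputs": {"image": placeholder_b64}}))

The "inputs" wrapper is optional, since __call__ falls back to the raw dict via data.pop("inputs", data).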
test.py CHANGED
@@ -10,13 +10,13 @@ def encode_image_to_base64(image_path):
 handler = EndpointHandler(path=".")
 
 # Prepare sample inputs
-image_path = "path_to_your_sketch_image.jpg"  # Replace with your image path
+image_path = "sketches/COCO_val2014_000000163852.jpg"
 base64_image = encode_image_to_base64(image_path)
 text_query = "A pink flower"
 
 # Create payload
 payload = {
-    "image": base64_image,
+    "sketch": base64_image,
     "text": text_query
 }
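The updated test still exercises only the sketch-plus-text path. A hedged example of the new image-only branch, using the encode_image_to_base64 helper and handler already defined in test.py (the file path here is a hypothetical placeholder), could be appended:

# Hypothetical image-only request against the same handler; the path is a placeholder.
photo_b64 = encode_image_to_base64("images/example_photo.jpg")
image_payload = {"image": photo_b64}
result = handler(image_payload)
print(len(result["embedding"][0]))  # dimensionality of the normalized image embedding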