tcm03 commited on
Commit
aef4077
·
1 Parent(s): 1060235

Enable text-only embedding request

Browse files
Files changed (1) hide show
  1. handler.py +21 -4
handler.py CHANGED
@@ -47,6 +47,14 @@ def get_image_embedding(image_base64, model, transformer):
47
  image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
48
  return image_feature.cpu().numpy().tolist()
49
 
 
 
 
 
 
 
 
 
50
  class EndpointHandler:
51
  def __init__(self, path: str = ""):
52
  """
@@ -80,9 +88,10 @@ class EndpointHandler:
80
  Returns:
81
  dict: {"embedding": [float, float, ...]}
82
  """
83
- # Parse inputs
84
  inputs = data.pop("inputs", data)
85
- if "sketch" in inputs:
 
86
  sketch_base64 = inputs.get("sketch", "")
87
  text_query = inputs.get("text", "")
88
  if not sketch_base64 or not text_query:
@@ -91,11 +100,19 @@ class EndpointHandler:
91
  # Generate Fused Embedding
92
  fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
93
  return {"embedding": fused_embedding}
94
- elif "image" in inputs:
 
95
  image_base64 = inputs.get("image", "")
96
  if not image_base64:
97
  return {"error": "Image 'image' (base64) is required input."}
98
  embedding = get_image_embedding(image_base64, self.model, self.transform)
99
  return {"embedding": embedding}
 
 
 
 
 
 
 
100
  else:
101
- return {"error": "Input 'sketch' or 'image' is required."}
 
47
  image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)
48
  return image_feature.cpu().numpy().tolist()
49
 
50
+ def get_text_embedding(text, model):
51
+ """Convert text query to tensor."""
52
+ text_tensor = preprocess_text(text)
53
+ with torch.no_grad():
54
+ text_feature = model.encode_text(text_tensor)
55
+ text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)
56
+ return text_feature.cpu().numpy().tolist()
57
+
58
  class EndpointHandler:
59
  def __init__(self, path: str = ""):
60
  """
 
88
  Returns:
89
  dict: {"embedding": [float, float, ...]}
90
  """
91
+
92
  inputs = data.pop("inputs", data)
93
+ # text-sketch embedding
94
+ if len(inputs) == 2 and "sketch" in inputs and "text" in inputs:
95
  sketch_base64 = inputs.get("sketch", "")
96
  text_query = inputs.get("text", "")
97
  if not sketch_base64 or not text_query:
 
100
  # Generate Fused Embedding
101
  fused_embedding = get_fused_embedding(sketch_base64, text_query, self.model, self.transform)
102
  return {"embedding": fused_embedding}
103
+ # image-only embedding
104
+ elif len(inputs) == 1 and "image" in inputs:
105
  image_base64 = inputs.get("image", "")
106
  if not image_base64:
107
  return {"error": "Image 'image' (base64) is required input."}
108
  embedding = get_image_embedding(image_base64, self.model, self.transform)
109
  return {"embedding": embedding}
110
+ # text-only embedding
111
+ elif len(inputs) == 1 and "text" in inputs:
112
+ text_query = inputs.get("text", "")
113
+ if not text_query:
114
+ return {"error": "Text 'text' is required input."}
115
+ embedding = get_text_embedding(text_query, self.model)
116
+ return {"embedding": embedding}
117
  else:
118
+ return {"error": "Invalid request."}