tcm03 committed on
Commit
af896ec
·
1 Parent(s): e9e7244

Add custom inference script for text and sketch

Browse files
Files changed (1) hide show
  1. inference.py +64 -0
inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import base64
4
+ from io import BytesIO
5
+ from transformers import AutoTokenizer
6
+
7
+ import sys
8
+ sys.path.append("code")
9
+ from clip.model import CLIP
10
+
11
# Load Model and Utilities
# Executed once at import time so every request reuses the same model object.
# Prefer the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the "tcm03/tsbir" checkpoint through the project's CLIP class and move
# it to the selected device.
model = CLIP.from_pretrained("tcm03/tsbir").to(device)
# Inference only: presumably the usual torch.nn.Module eval-mode switch
# (dropout / batch-norm behaviour) -- confirm against clip.model.CLIP.
model.eval()
15
+
16
# Preprocessing Functions
# Imported here rather than in the top-of-file import block; the `clip`
# package is resolved through the "code" entry appended to sys.path earlier
# in this file.
from clip.clip import _transform, tokenize

# Image transform built with is_train=False at the resolution the visual
# encoder reports; shared by every call to preprocess_image.
transformer = _transform(model.visual.input_resolution, is_train=False)
19
+
20
def preprocess_image(image_base64):
    """Convert a base64-encoded image into a batched input on ``device``.

    Args:
        image_base64: Base64 payload of an image file (``str`` or ``bytes``).
            A ``str`` may also be a full data URL
            (``data:image/png;base64,...``); everything up to and including
            the first comma is dropped before decoding.

    Returns:
        The result of the module-level ``transformer`` applied to the
        RGB-converted image, with a leading dimension of size 1 added via
        ``unsqueeze(0)`` and moved to ``device``.

    Raises:
        binascii.Error: If the payload is not decodable base64.
        Exception: Whatever ``PIL.Image.open`` raises when the decoded bytes
            are not a readable image.
    """
    # Clients frequently send a data URL. ``b64decode`` (non-validating by
    # default) would silently discard the header's ':', ';' and ',' and then
    # decode the remaining header letters as data, yielding garbage bytes.
    # Valid base64 can never start with "data:" (':' is outside the
    # alphabet), so stripping the header cannot affect plain payloads.
    if isinstance(image_base64, str) and image_base64.startswith("data:"):
        image_base64 = image_base64.partition(",")[2]

    raw_bytes = base64.b64decode(image_base64)
    # Force three channels so greyscale / palette / RGBA inputs are all
    # handed to the transform in the same mode.
    image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    return transformer(image).unsqueeze(0).to(device)
25
+
26
def preprocess_text(text):
    """Tokenize a text query into a batched input on ``device``.

    Args:
        text: The query. It is passed through ``str()`` first, so non-string
            values are accepted.

    Returns:
        The first entry of ``tokenize([str(text)])`` with a leading dimension
        of size 1 added via ``unsqueeze(0)``, moved to ``device``.
    """
    tokens = tokenize([str(text)])
    # Select the single query's row and re-add the batch axis explicitly.
    # NOTE(review): assumes tokenize returns one row per input string --
    # confirm against clip.clip.tokenize.
    return tokens[0].unsqueeze(0).to(device)
29
+
30
def get_fused_embedding(image_base64, text):
    """Fuse sketch and text features into a single embedding.

    Args:
        image_base64: Base64-encoded sketch image, in the form accepted by
            ``preprocess_image``.
        text: Text query, in the form accepted by ``preprocess_text``.

    Returns:
        The output of ``model.feature_fuse`` converted to plain nested Python
        lists (``.cpu().numpy().tolist()``), so it can be serialised directly.
    """
    # Pure inference: skip autograd bookkeeping for everything below.
    with torch.no_grad():
        # Preprocess Inputs
        image_tensor = preprocess_image(image_base64)
        text_tensor = preprocess_text(text)

        # Extract Features
        sketch_feature = model.encode_sketch(image_tensor)
        text_feature = model.encode_text(text_tensor)

        # Normalize Features: scale each to unit norm along the last axis
        # before they are fused.
        sketch_feature = sketch_feature / sketch_feature.norm(dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(dim=-1, keepdim=True)

        # Fuse Features
        fused_embedding = model.feature_fuse(sketch_feature, text_feature)
        # Move to host memory before converting to Python lists.
        return fused_embedding.cpu().numpy().tolist()
48
+
49
+ # Hugging Face Inference API Entry Point
50
def infer(inputs):
    """Inference API entry point.

    Args:
        inputs: Mapping with the request fields:
            - 'image': Base64 encoded sketch image.
            - 'text': Text query.
            ``None`` (or any other falsy value) is treated as an empty
            request.

    Returns:
        ``{"fused_embedding": <nested list>}`` on success, or
        ``{"error": <message>}`` when either field is missing or empty.
    """
    # A missing payload gets the same validation error as an empty one
    # instead of failing with AttributeError on ``None.get``.
    payload = inputs or {}
    image_base64 = payload.get("image", "")
    text_query = payload.get("text", "")

    # Both modalities are mandatory: the model fuses a sketch with a text.
    if not image_base64 or not text_query:
        return {"error": "Both 'image' (base64) and 'text' are required inputs."}

    # Generate Fused Embedding
    return {"fused_embedding": get_fused_embedding(image_base64, text_query)}