tcm03 committed · Commit 5c18c06 · Parent(s): 9e0347a

Modify inference.py and Add YAML metadata for Hugging Face Hub

Files changed:
- README.md +11 -0
- inference.py +30 -10
README.md CHANGED
@@ -1,3 +1,14 @@
+---
+tags:
+- image-retrieval
+- text-sketch
+- clip
+- pytorch
+- inference
+library_name: pytorch
+inference: true
+---
+
 # Image Retrieval with Text and Sketch
 This code is for our 2022 ECCV paper [A Sketch Is Worth a Thousand Words: Image Retrieval with Text and Sketch](https://patsorn.me/projects/tsbir/)
 
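The added front matter is what the Hub parses for the model page's tags, library badge, and inference widget. As a quick sanity check that the YAML round-trips, a minimal sketch using `huggingface_hub`'s `ModelCard` (the repo id `tcm03/tsbir` comes from the code this commit removes below; adjust if the repo lives elsewhere):

```python
# Minimal sketch: load the model card and confirm the new YAML metadata parses.
# Assumes `pip install huggingface_hub` and that the repo id is tcm03/tsbir.
from huggingface_hub import ModelCard

card = ModelCard.load("tcm03/tsbir")
print(card.data.tags)          # expect ['image-retrieval', 'text-sketch', 'clip', 'pytorch', 'inference']
print(card.data.library_name)  # expect 'pytorch'
```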
inference.py CHANGED
@@ -2,21 +2,40 @@ import torch
 from PIL import Image
 import base64
 from io import BytesIO
-
-
+import json
 import sys
+
 sys.path.append("code")
 from clip.model import CLIP
+from clip.clip import _transform, tokenize
 
-# Load Model and Utilities
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = CLIP.from_pretrained("tcm03/tsbir").to(device)
-model.eval()
 
-
-
-
+MODEL_PATH = "model/tsbir_model_final.pt"
+CONFIG_PATH = "code/training/model_configs/ViT-B-16.json"
+
+def load_model():
+    """Load the model only once."""
+    global model
+    if "model" not in globals():
+        with open(CONFIG_PATH, 'r') as f:
+            model_info = json.load(f)
 
+        model = CLIP(**model_info)
+        checkpoint = torch.load(MODEL_PATH, map_location=device)
+        sd = checkpoint["state_dict"]
+        if next(iter(sd.items()))[0].startswith('module'):
+            sd = {k[len('module.'):]: v for k, v in sd.items()}
+
+        model.load_state_dict(sd, strict=False)
+        model = model.to(device).eval()
+
+        # Initialize transformer
+        global transformer
+        transformer = _transform(model.visual.input_resolution, is_train=False)
+        print("Model loaded successfully.")
+
+# Preprocessing Functions
 def preprocess_image(image_base64):
     """Convert base64 encoded image to tensor."""
     image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
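The `"model" not in globals()` guard above makes `load_model` idempotent, so the entry point can call it on every request and only pay the checkpoint-load cost once per process. A sketch of the same load-once idea via `functools.lru_cache` instead of module globals (not part of this commit; it reuses `MODEL_PATH`, `CONFIG_PATH`, `CLIP`, `json`, `torch`, and `device` from the file above):

```python
# Alternative load-once sketch (hypothetical, not in this commit):
# lru_cache(maxsize=1) memoizes the loaded model without global state.
from functools import lru_cache

@lru_cache(maxsize=1)
def get_model():
    with open(CONFIG_PATH, "r") as f:
        model = CLIP(**json.load(f))
    sd = torch.load(MODEL_PATH, map_location=device)["state_dict"]
    if next(iter(sd)).startswith("module."):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=False)
    return model.to(device).eval()
```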
@@ -49,16 +68,17 @@ def get_fused_embedding(image_base64, text):
 # Hugging Face Inference API Entry Point
 def infer(inputs):
     """
-    Inference API entry point.
+    Inference API entry point.
     Inputs:
     - 'image': Base64 encoded sketch image.
     - 'text': Text query.
     """
+    load_model()  # Ensure the model is loaded once
     image_base64 = inputs.get("image", "")
     text_query = inputs.get("text", "")
     if not image_base64 or not text_query:
         return {"error": "Both 'image' (base64) and 'text' are required inputs."}
-
+
     # Generate Fused Embedding
     fused_embedding = get_fused_embedding(image_base64, text_query)
     return {"fused_embedding": fused_embedding}
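With `load_model()` now invoked inside `infer`, the module can be exercised end to end without any setup beyond the checkpoint and config files. A minimal local smoke test (hypothetical `sketch.png` path and text query; assumes `inference.py` is importable from the repo root):

```python
# Minimal smoke test for the new entry point (file name and query are examples).
import base64

from inference import infer

with open("sketch.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

result = infer({"image": image_b64, "text": "a red car parked by the beach"})
print(sorted(result.keys()))  # expect ['fused_embedding'] on success
```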