tcm03 committed on
Commit
5c18c06
·
1 Parent(s): 9e0347a

Modify inference.py and Add YAML metadata for Hugging Face Hub

Browse files
Files changed (2) hide show
  1. README.md +11 -0
  2. inference.py +30 -10
README.md CHANGED
@@ -1,3 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
1
  # Image Retrieval with Text and Sketch
2
  This code is for our 2022 ECCV paper [A Sketch Is Worth a Thousand Words: Image Retrieval with Text and Sketch](https://patsorn.me/projects/tsbir/)
3
 
 
1
+ ---
2
+ tags:
3
+ - image-retrieval
4
+ - text-sketch
5
+ - clip
6
+ - pytorch
7
+ - inference
8
+ library_name: pytorch
9
+ inference: true
10
+ ---
11
+
12
  # Image Retrieval with Text and Sketch
13
  This code is for our 2022 ECCV paper [A Sketch Is Worth a Thousand Words: Image Retrieval with Text and Sketch](https://patsorn.me/projects/tsbir/)
14
 
inference.py CHANGED
@@ -2,21 +2,40 @@ import torch
2
  from PIL import Image
3
  import base64
4
  from io import BytesIO
5
- from transformers import AutoTokenizer
6
-
7
  import sys
 
8
  sys.path.append("code")
9
  from clip.model import CLIP
 
10
 
11
- # Load Model and Utilities
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
- model = CLIP.from_pretrained("tcm03/tsbir").to(device)
14
- model.eval()
15
 
16
- # Preprocessing Functions
17
- from clip.clip import _transform, tokenize
18
- transformer = _transform(model.visual.input_resolution, is_train=False)
 
 
 
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def preprocess_image(image_base64):
21
  """Convert base64 encoded image to tensor."""
22
  image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
@@ -49,16 +68,17 @@ def get_fused_embedding(image_base64, text):
49
  # Hugging Face Inference API Entry Point
50
  def infer(inputs):
51
  """
52
- Inference API entry point.
53
  Inputs:
54
  - 'image': Base64 encoded sketch image.
55
  - 'text': Text query.
56
  """
 
57
  image_base64 = inputs.get("image", "")
58
  text_query = inputs.get("text", "")
59
  if not image_base64 or not text_query:
60
  return {"error": "Both 'image' (base64) and 'text' are required inputs."}
61
-
62
  # Generate Fused Embedding
63
  fused_embedding = get_fused_embedding(image_base64, text_query)
64
  return {"fused_embedding": fused_embedding}
 
2
  from PIL import Image
3
  import base64
4
  from io import BytesIO
5
+ import json
 
6
  import sys
7
+
8
  sys.path.append("code")
9
  from clip.model import CLIP
10
+ from clip.clip import _transform, tokenize
11
 
 
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Relative paths to the architecture config (JSON) and the trained checkpoint.
CONFIG_PATH = "code/training/model_configs/ViT-B-16.json"
MODEL_PATH = "model/tsbir_model_final.pt"
16
+
17
def load_model():
    """Load the CLIP model and its image transform into module globals, once.

    Reads the constructor keyword arguments from ``CONFIG_PATH``, builds a
    ``CLIP`` instance, restores the weights stored under the ``"state_dict"``
    key of the checkpoint at ``MODEL_PATH``, moves the model to ``device`` in
    eval mode, and builds the inference-time image transform from the model's
    visual input resolution.

    The globals ``model`` and ``transformer`` are assigned together and only
    after every step has succeeded. A failed attempt (missing file, malformed
    checkpoint) therefore leaves neither defined, so the next call retries
    instead of silently reusing a model whose weights were never restored.

    Raises:
        OSError: if the config file cannot be opened.
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
    """
    global model, transformer
    if "model" in globals() and "transformer" in globals():
        return

    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        model_info = json.load(f)

    # Build into a local first; publishing the global early would make a
    # later failure look like a successful load on the next call.
    net = CLIP(**model_info)

    # NOTE(review): torch.load unpickles the file, so MODEL_PATH must point at
    # a trusted checkpoint.
    checkpoint = torch.load(MODEL_PATH, map_location=device)

    # Strip a leading "module." from each key that has one (presumably left by
    # a DataParallel-style wrapper at training time -- confirm). Doing this per
    # key also copes with an empty state dict and with keys lacking the prefix.
    prefix = "module."
    sd = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }

    # strict=False tolerates key mismatches; report them so that a checkpoint
    # that does not fit the architecture is visible in the logs.
    incompatible = net.load_state_dict(sd, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "Warning: checkpoint/model key mismatch. "
            f"Missing: {list(incompatible.missing_keys)}; "
            f"unexpected: {list(incompatible.unexpected_keys)}"
        )

    net = net.to(device).eval()
    image_transform = _transform(net.visual.input_resolution, is_train=False)

    # Publish both globals at once, now that everything has succeeded.
    model, transformer = net, image_transform
    print("Model loaded successfully.")
37
+
38
+ # Preprocessing Functions
39
  def preprocess_image(image_base64):
40
  """Convert base64 encoded image to tensor."""
41
  image = Image.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
 
68
  # Hugging Face Inference API Entry Point
def infer(inputs):
    """
    Inference API entry point.

    Inputs:
    - 'image': Base64 encoded sketch image.
    - 'text': Text query.

    Returns:
    - {"fused_embedding": ...} with the result of get_fused_embedding, or
    - {"error": ...} when either input is missing or empty.
    """
    image_base64 = inputs.get("image", "")
    text_query = inputs.get("text", "")
    if not image_base64 or not text_query:
        return {"error": "Both 'image' (base64) and 'text' are required inputs."}

    # Load the model only after validation, so a malformed request gets the
    # error response without triggering (or failing in) the model load.
    load_model()

    # Generate Fused Embedding
    fused_embedding = get_fused_embedding(image_base64, text_query)
    return {"fused_embedding": fused_embedding}