methodw commited on
Commit
373a0b2
·
1 Parent(s): 391a1fb
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -8,15 +8,16 @@ import faiss
8
 
9
 
10
  # Init similarity search AI model and processor
11
- torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- dino_v2_model = AutoModel.from_pretrained("./dinov2-large").to(torch_device)
13
- dino_v2_image_processor = AutoImageProcessor.from_pretrained("./dinov2-large")
 
 
14
 
15
- # Provide a sample input for tracing
16
- sample_input = dino_v2_image_processor(
17
- images=Image.new("RGB", (224, 224)), return_tensors="pt"
18
- ).to(torch_device)
19
- traced_dino_v2_model = torch.jit.trace(dino_v2_model, sample_input["pixel_values"])
20
 
21
 
22
  def process_image(image):
@@ -49,12 +50,10 @@ def process_image(image):
49
 
50
  # Extract the features from the uploaded image
51
  with torch.no_grad():
52
- inputs = dino_v2_image_processor(images=image, return_tensors="pt").to(
53
- torch_device
54
- )
55
 
56
  # Use the traced model for inference
57
- outputs = traced_dino_v2_model(**inputs)
58
 
59
  # Normalize the features before search, whatever that means
60
  embeddings = outputs.last_hidden_state
 
8
 
9
 
10
  # Init similarity search AI model and processor
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ processor = AutoImageProcessor.from_pretrained("./dinov2-large")
13
+ model = AutoModel.from_pretrained("./dinov2-large")
14
+ model.config.return_dict = False # Set return_dict to False for JIT tracing
15
+ model.to(device)
16
 
17
+ # Prepare an example input for tracing
18
+ example_input = torch.rand(1, 3, 224, 224).to(device) # Adjust size if needed
19
+ traced_model = torch.jit.trace(model, example_input)
20
+ traced_model = traced_model.to(device)
 
21
 
22
 
23
  def process_image(image):
 
50
 
51
  # Extract the features from the uploaded image
52
  with torch.no_grad():
53
+ inputs = processor(images=image, return_tensors="pt").to(device)
 
 
54
 
55
  # Use the traced model for inference
56
+ outputs = traced_model(**inputs)
57
 
58
  # Normalize the features before search, whatever that means
59
  embeddings = outputs.last_hidden_state