tonyassi commited on
Commit
b92b7b0
·
verified ·
1 Parent(s): 155f1ad

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -0
README.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from transformers import ViTModel, ViTConfig
6
+ from safetensors.torch import load_file as safetensors_load_file
7
+
8
+ # Define a transform to convert PIL images to tensors
9
+ transform = transforms.Compose([
10
+ transforms.Resize((224, 224)),
11
+ transforms.ToTensor(),
12
+ ])
13
+
14
+ class ViTSalesModel(nn.Module):
15
+ def __init__(self):
16
+ super(ViTSalesModel, self).__init__()
17
+ self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
18
+ self.classifier = nn.Linear(self.vit.config.hidden_size, 1)
19
+
20
+ def forward(self, pixel_values, labels=None):
21
+ outputs = self.vit(pixel_values=pixel_values)
22
+ cls_output = outputs.last_hidden_state[:, 0, :] # Take the [CLS] token
23
+ sales = self.classifier(cls_output)
24
+ loss = None
25
+ if labels is not None:
26
+ loss_fct = nn.MSELoss()
27
+ loss = loss_fct(sales.view(-1), labels.view(-1))
28
+ return (loss, sales) if loss is not None else sales
29
+
30
+ model = ViTSalesModel()
31
+
32
+ # Load the saved model checkpoint
33
+ checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
34
+ state_dict = safetensors_load_file(checkpoint_path)
35
+ model.load_state_dict(state_dict)
36
+ model.eval()
37
+
38
+ # Maximum sales value for de-normalization (from training)
39
+ max_sales_value = 100000 # Replace with the actual max sales value used during training
40
+
41
+ def predict_sales(image_path):
42
+ # Load and preprocess the image
43
+ image = Image.open(image_path).convert('RGB')
44
+ image = transform(image).unsqueeze(0) # Add batch dimension
45
+
46
+ with torch.no_grad():
47
+ # Run the model
48
+ prediction = model(image)
49
+
50
+ print(prediction)
51
+ # De-normalize the prediction
52
+ sales_prediction = prediction.item() * max_sales_value
53
+ return sales_prediction
54
+
55
+ # Example usage
56
+ image_path = "/content/0000.png"
57
+ predicted_sales = predict_sales(image_path)
58
+ print(f"Predicted sales: {predicted_sales}")
59
+
60
+ ```