import torch
from PIL import Image
from safetensors.torch import load_file as safetensors_load_file
from torch import nn
from torchvision import transforms
from transformers import ViTConfig, ViTModel
# Preprocessing applied to every input image before it reaches the model:
# resize to a fixed 224x224, then convert the PIL image to a float tensor
# (ToTensor scales pixel values into [0, 1]).
# NOTE(review): there is no Normalize step here. Hugging Face's ViT image
# processor normalizes by default; confirm this matches the preprocessing
# that was used when the checkpoint was trained.
_IMAGE_SIZE = (224, 224)
transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)
class ViTSalesModel(nn.Module):
    """Regress a single scalar (sales) from an image with a ViT backbone.

    The ViT encoder's hidden state for the first token is passed through a
    linear head with one output unit.
    """

    def __init__(self, pretrained_name="google/vit-base-patch16-224"):
        """Load the ViT backbone and create the one-output regression head.

        Args:
            pretrained_name: Model id or local path handed to
                ``ViTModel.from_pretrained``. The default is the backbone the
                original model was built on, so existing checkpoints still
                load with matching state-dict keys.
        """
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_name)
        self.classifier = nn.Linear(self.vit.config.hidden_size, 1)

    def forward(self, pixel_values, labels=None):
        """Run the backbone and regression head.

        Args:
            pixel_values: Image batch for the ViT backbone; presumably
                (batch, 3, 224, 224) as produced by ``transform`` — confirm.
            labels: Optional regression targets, one per batch element.

        Returns:
            ``sales`` with shape (batch, 1) when ``labels`` is None, otherwise
            a ``(loss, sales)`` tuple where ``loss`` is the mean squared error
            between the flattened predictions and labels.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # Take the hidden state at token index 0 (ViT's [CLS] position).
        cls_output = outputs.last_hidden_state[:, 0, :]
        sales = self.classifier(cls_output)
        if labels is None:
            return sales
        # MSE needs matching dtype/device; labels often arrive as float64 or
        # integer tensors (e.g. from a data collator), which would otherwise
        # raise a dtype-mismatch error.
        targets = labels.view(-1).to(device=sales.device, dtype=sales.dtype)
        loss = nn.functional.mse_loss(sales.view(-1), targets)
        return loss, sales
# Location of the fine-tuned weights; this path is specific to the
# environment the checkpoint was written in.
checkpoint_path = "/content/results/checkpoint-940/model.safetensors"
# Factor applied to the raw model output in predict_sales.
# NOTE(review): presumably the value training labels were divided by — confirm.
max_sales_value = 100_000

# Build the architecture, then replace its weights with the checkpoint's.
model = ViTSalesModel()
state_dict = safetensors_load_file(checkpoint_path)
model.load_state_dict(state_dict)
model.train(False)  # evaluation mode; identical to model.eval()
def predict_sales(image_path):
    """Predict a sales figure for a single image.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The model's single scalar output multiplied by ``max_sales_value``.
    """
    # Use the image as a context manager so the underlying file handle is
    # closed deterministically; PIL can otherwise keep it open (notably for
    # multi-frame formats). convert() returns a new, fully loaded image, so
    # it stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    # Add a leading batch dimension of size 1.
    pixel_values = transform(rgb_image).unsqueeze(0)
    with torch.no_grad():
        prediction = model(pixel_values)
    # NOTE(review): debug print of the raw model output; consider logging instead.
    print(prediction)
    # .item() requires exactly one element: batch of 1 and a one-unit head.
    return prediction.item() * max_sales_value
# Script entry: score one sample image and report the scaled prediction.
image_path = "/content/0000.png"
predicted_sales = predict_sales(image_path=image_path)
report = f"Predicted sales: {predicted_sales}"
print(report)