simonhermansson commited on
Commit
1420df1
·
1 Parent(s): e1e20e4

Added linear reuse and price layers

Browse files
Files changed (3) hide show
  1. app.py +48 -2
  2. files/price_linear.pth +3 -0
  3. files/reuse_linear.pth +3 -0
app.py CHANGED
@@ -67,7 +67,53 @@ def predict_brand(img):
67
 
68
 
69
  def estimate_price_and_usage(img):
70
- return "Estimated price: 50-100 SEK - Usage: Reuse - Saved C02: 4 kg"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  def retrieve(query):
@@ -141,7 +187,7 @@ with gr.Blocks(
141
  predicted_brand = gr.Textbox(label="Brand", show_label=False)
142
 
143
  with gr.Column(variant="compact"):
144
- btn_estimate = gr.Button("Estimate Price, Reuse, and Saved C02").style(size="sm")
145
  text_box = gr.Textbox(label="Estimates:", show_label=False)
146
  with gr.Tab("Image Retrieval"):
147
  with gr.Row(variant="compact"):
 
67
 
68
 
69
  def estimate_price_and_usage(img):
70
+ query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
71
+
72
+ # Estimate usage
73
+ num_classes = 2
74
+ probe = torch.nn.Linear(
75
+ query_features.shape[-1],
76
+ num_classes,
77
+ dtype=torch.float16,
78
+ bias=False
79
+ )
80
+ # Load weights for the linear layer as a tensor
81
+ linear_data = torch.load("files/reuse_linear.pth")
82
+ probe.weight.data = linear_data["weight"]
83
+
84
+ # Do inference
85
+ probe.eval()
86
+ probe = probe.to(device)
87
+ output = probe(query_features)
88
+ print(output)
89
+ output = torch.softmax(output, dim=-1)
90
+ output = output.cpu().detach().numpy().astype("float32")
91
+ reuse = output.argmax(axis=-1)[0]
92
+ reuse_classes = ["Reuse", "Export"]
93
+
94
+ # Estimate price
95
+ num_classes = 4
96
+ probe = torch.nn.Linear(
97
+ query_features.shape[-1],
98
+ num_classes,
99
+ dtype=torch.float16,
100
+ bias=False
101
+ )
102
+ # Print output shape for the linear layer
103
+ # Load weights for the linear layer as a tensor
104
+ linear_data = torch.load("files/price_linear.pth")
105
+ probe.weight.data = linear_data["weight"]
106
+
107
+ # Do inference
108
+ probe.eval()
109
+ probe = probe.to(device)
110
+ output = probe(query_features)
111
+ output = torch.softmax(output, dim=-1)
112
+ output = output.cpu().detach().numpy().astype("float32")
113
+ price = output.argmax(axis=-1)[0]
114
+ price_classes = ["<50", "50-100", "100-150", ">150"]
115
+
116
+ return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
117
 
118
 
119
  def retrieve(query):
 
187
  predicted_brand = gr.Textbox(label="Brand", show_label=False)
188
 
189
  with gr.Column(variant="compact"):
190
+ btn_estimate = gr.Button("Estimate Price and Reuse").style(size="sm")
191
  text_box = gr.Textbox(label="Estimates:", show_label=False)
192
  with gr.Tab("Image Retrieval"):
193
  with gr.Row(variant="compact"):
files/price_linear.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a00e3c8fa9f78af43ab1208e14b818bea5028443e4c5260c6743e05f14b378f
3
+ size 5115
files/reuse_linear.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8998d9093128cab4269649e9dc70541940f0d8a9a92c6e02fe774a6153f5e29b
3
+ size 3067