Spaces:
Running
on
L4
Running
on
L4
Update app.py
Browse files
app.py
CHANGED
@@ -179,38 +179,39 @@ def predict_batch(contents, policies):
|
|
179 |
with torch.inference_mode():
|
180 |
outputs = model(input_ids)
|
181 |
|
182 |
-
|
183 |
-
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
|
189 |
-
|
190 |
-
|
191 |
|
192 |
-
|
193 |
-
|
194 |
|
195 |
-
|
196 |
-
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
207 |
|
208 |
# Create Gradio interface
|
209 |
iface = gr.Interface(
|
210 |
fn=predict_batch,
|
211 |
inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT),
|
212 |
gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)],
|
213 |
-
outputs=
|
214 |
batch=True,
|
215 |
max_batch_size=8,
|
216 |
title="CoPE Dev (Unstable)",
|
|
|
179 |
with torch.inference_mode():
|
180 |
outputs = model(input_ids)
|
181 |
|
182 |
+
# Get logits for the last tokens
|
183 |
+
logits = outputs.logits[:, -1, :]
|
184 |
|
185 |
+
# Get token IDs for "0" and "1"
|
186 |
+
token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
|
187 |
+
token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]
|
188 |
|
189 |
+
# Extract logits for "0" and "1"
|
190 |
+
binary_logits = logits[:, [token_id_0, token_id_1]]
|
191 |
|
192 |
+
# Apply softmax to get probabilities for these two tokens
|
193 |
+
probabilities = F.softmax(binary_logits, dim=-1)
|
194 |
|
195 |
+
probs_0 = probabilities[:, 0].cpu().numpy()
|
196 |
+
probs_1 = probabilities[:, 1].cpu().numpy()
|
197 |
|
198 |
+
results = []
|
199 |
+
for prob_0, prob_1 in zip(probs_0, probs_1):
|
200 |
+
if prob_1 > prob_0:
|
201 |
+
output = f'VIOLATING\n(P: {prob_1:.2f})'
|
202 |
+
else:
|
203 |
+
output = f'NON-Violating\n(P: {prob_0:.2f})'
|
204 |
+
results.append(output)
|
205 |
+
print(results)
|
206 |
+
#return results
|
207 |
+
return "\n\n".join(results)
|
208 |
|
209 |
# Create Gradio interface
|
210 |
iface = gr.Interface(
|
211 |
fn=predict_batch,
|
212 |
inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT),
|
213 |
gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)],
|
214 |
+
outputs=gr.Textbox(label="Results"),
|
215 |
batch=True,
|
216 |
max_batch_size=8,
|
217 |
title="CoPE Dev (Unstable)",
|