samidh committed on
Commit
17d9a5e
·
verified ·
1 Parent(s): c39cdfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -179,38 +179,39 @@ def predict_batch(contents, policies):
179
  with torch.inference_mode():
180
  outputs = model(input_ids)
181
 
182
- # Get logits for the last tokens
183
- logits = outputs.logits[:, -1, :]
184
 
185
- # Get token IDs for "0" and "1"
186
- token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
187
- token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]
188
 
189
- # Extract logits for "0" and "1"
190
- binary_logits = logits[:, [token_id_0, token_id_1]]
191
 
192
- # Apply softmax to get probabilities for these two tokens
193
- probabilities = F.softmax(binary_logits, dim=-1)
194
 
195
- probs_0 = probabilities[:, 0].cpu().numpy()
196
- probs_1 = probabilities[:, 1].cpu().numpy()
197
 
198
- results = []
199
- for prob_0, prob_1 in zip(probs_0, probs_1):
200
- if prob_1 > prob_0:
201
- output = f'VIOLATING\n(P: {prob_1:.2f})'
202
- else:
203
- output = f'NON-Violating\n(P: {prob_0:.2f})'
204
- results.append(output)
205
- print(results)
206
- return results
 
207
 
208
  # Create Gradio interface
209
  iface = gr.Interface(
210
  fn=predict_batch,
211
  inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT),
212
  gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)],
213
- outputs=[gr.Textbox(label="Results")],
214
  batch=True,
215
  max_batch_size=8,
216
  title="CoPE Dev (Unstable)",
 
179
  with torch.inference_mode():
180
  outputs = model(input_ids)
181
 
182
+ # Get logits for the last tokens
183
+ logits = outputs.logits[:, -1, :]
184
 
185
+ # Get token IDs for "0" and "1"
186
+ token_id_0 = tokenizer.encode("0", add_special_tokens=False)[0]
187
+ token_id_1 = tokenizer.encode("1", add_special_tokens=False)[0]
188
 
189
+ # Extract logits for "0" and "1"
190
+ binary_logits = logits[:, [token_id_0, token_id_1]]
191
 
192
+ # Apply softmax to get probabilities for these two tokens
193
+ probabilities = F.softmax(binary_logits, dim=-1)
194
 
195
+ probs_0 = probabilities[:, 0].cpu().numpy()
196
+ probs_1 = probabilities[:, 1].cpu().numpy()
197
 
198
+ results = []
199
+ for prob_0, prob_1 in zip(probs_0, probs_1):
200
+ if prob_1 > prob_0:
201
+ output = f'VIOLATING\n(P: {prob_1:.2f})'
202
+ else:
203
+ output = f'NON-Violating\n(P: {prob_0:.2f})'
204
+ results.append(output)
205
+ print(results)
206
+ #return results
207
+ return "\n\n".join(results)
208
 
209
  # Create Gradio interface
210
  iface = gr.Interface(
211
  fn=predict_batch,
212
  inputs=[gr.Textbox(label="Content", lines=2, value=DEFAULT_CONTENT),
213
  gr.Textbox(label="Policy", lines=10, value=DEFAULT_POLICY)],
214
+ outputs=gr.Textbox(label="Results"),
215
  batch=True,
216
  max_batch_size=8,
217
  title="CoPE Dev (Unstable)",