Roaoch commited on
Commit
ddde8f2
·
1 Parent(s): 59b4efe
Files changed (1) hide show
  1. src/cyberclaasic.py +2 -1
src/cyberclaasic.py CHANGED
@@ -39,7 +39,8 @@ class CyberClassic(torch.nn.Module):
39
 
40
  decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
41
 
42
- score = self.discriminator(decoded)
 
43
  index = int(torch.argmax(score))
44
 
45
  return decoded[index]
 
39
 
40
  decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
41
 
42
+ decoded_tokens = self.discriminator_tokenizer(decoded, return_tensors='pt', padding=True, truncation=True)
43
+ score = self.discriminator(decoded_tokens)
44
  index = int(torch.argmax(score))
45
 
46
  return decoded[index]