liujch1998 committed on
Commit b921db9 • 1 Parent(s): be2d0e3
Files changed (1):
  app.py +25 -25

app.py CHANGED
@@ -32,35 +32,35 @@ repo.git_pull()
 class Interactive:
     def __init__(self):
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
-        # self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
-        # self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1).to(device)
-        # self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
-        # self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
-        # self.model.eval()
-        # self.t = self.model.shared.weight[32097, 0].item()
+        self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
+        self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
+        self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
+        self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
+        self.model.eval()
+        self.t = self.model.shared.weight[32097, 0].item()
 
     def run(self, statement):
-        # input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
-        # with torch.no_grad():
-        #     output = self.model(input_ids)
-        # last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
-        # hidden = last_hidden_state[0, -1, :] # (D)
-        # logit = self.linear(hidden).squeeze(-1) # ()
-        # logit_calibrated = logit / self.t
-        # score = logit.sigmoid()
-        # score_calibrated = logit_calibrated.sigmoid()
-        # return {
-        #     'logit': logit.item(),
-        #     'logit_calibrated': logit_calibrated.item(),
-        #     'score': score.item(),
-        #     'score_calibrated': score_calibrated.item(),
-        # }
+        input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
+        with torch.no_grad():
+            output = self.model(input_ids)
+        last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
+        hidden = last_hidden_state[0, -1, :] # (D)
+        logit = self.linear(hidden).squeeze(-1) # ()
+        logit_calibrated = logit / self.t
+        score = logit.sigmoid()
+        score_calibrated = logit_calibrated.sigmoid()
         return {
-            'logit': 0.0,
-            'logit_calibrated': 0.0,
-            'score': 0.5,
-            'score_calibrated': 0.5,
+            'logit': logit.item(),
+            'logit_calibrated': logit_calibrated.item(),
+            'score': score.item(),
+            'score_calibrated': score_calibrated.item(),
         }
+        # return {
+        #     'logit': 0.0,
+        #     'logit_calibrated': 0.0,
+        #     'score': 0.5,
+        #     'score_calibrated': 0.5,
+        # }
 
 interactive = Interactive()
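This commit switches the Space from a hard-coded stub to real inference: it loads the T5 encoder, builds a 1-unit linear verification head whose weight, bias, and temperature are read out of reserved rows of the model's shared embedding (rows 32099, 32098, and 32097), and also adds dtype=self.model.dtype, presumably so the head's parameters match the encoder's auto-selected dtype. Below is a minimal self-contained sketch of the scoring math in run(); the dimension, hidden state, head parameters, and temperature are made-up placeholders, not the checkpoint's actual values:

    import torch

    # Placeholders for what __init__ pulls out of the checkpoint
    # (in the app these come from rows of model.shared.weight).
    D = 8
    hidden = torch.randn(D)          # last-token encoder state, shape (D)
    linear = torch.nn.Linear(D, 1)   # 1-unit verification head
    t = 2.0                          # calibration temperature (assumed > 1)

    with torch.no_grad():
        logit = linear(hidden).squeeze(-1)             # scalar logit, shape ()
        logit_calibrated = logit / t                   # temperature scaling
        score = logit.sigmoid()                        # raw probability
        score_calibrated = logit_calibrated.sigmoid()  # calibrated probability

    print(score.item(), score_calibrated.item())

Dividing the logit by the temperature before the sigmoid is standard temperature scaling: for t > 1 it pulls the calibrated score toward 0.5, which is also the neutral value the old stub returned.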