AhmedSSabir committed
Commit 86d638a
1 Parent(s): 0edc282

Update app.py

Files changed (1)
app.py +62 -37
app.py CHANGED
@@ -55,46 +55,70 @@ model = GPT2LMHeadModel.from_pretrained('distilgpt2', output_hidden_states = True)
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 #tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
 
-
-def cloze_prob(text):
-
-    whole_text_encoding = tokenizer.encode(text)
-    # Parse out the stem of the whole sentence (i.e., the part leading up to but not including the critical word)
-    text_list = text.split()
-    stem = ' '.join(text_list[:-1])
-    stem_encoding = tokenizer.encode(stem)
-    # cw_encoding is just the difference between whole_text_encoding and stem_encoding
-    # note: this might not correspond exactly to the word itself
-    cw_encoding = whole_text_encoding[len(stem_encoding):]
-    # Run the entire sentence through the model. Then go "back in time" to look at what the model predicted for each token, starting at the stem.
-    # Put the whole text encoding into a tensor, and get the model's comprehensive output
-    tokens_tensor = torch.tensor([whole_text_encoding])
-
-    with torch.no_grad():
-        outputs = model(tokens_tensor)
-        predictions = outputs[0]
-
-    logprobs = []
-    # start at the stem and get downstream probabilities incrementally from the model (see above)
-    start = -1 - len(cw_encoding)
-    for j in range(start, -1, 1):
-        raw_output = []
-        for i in predictions[-1][j]:
-            raw_output.append(i.item())
-
-        logprobs.append(np.log(softmax(raw_output)))
-
-    # if the critical word is three tokens long, the raw_probabilities should look something like this:
-    # [ [0.412, 0.001, ... ], [0.213, 0.004, ...], [0.002, 0.001, 0.93, ...] ]
-    # Then for the i'th token we want to find its associated probability
-    # this is just: raw_probabilities[i][token_index]
-    conditional_probs = []
-    for cw, prob in zip(cw_encoding, logprobs):
-        conditional_probs.append(prob[cw])
-    # now that you have all the relevant probabilities, return their product.
-    # This is the probability of the critical word given the context before it.
-
-    return np.exp(np.sum(conditional_probs))
-
-
-
+tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+model = GPT2LMHeadModel.from_pretrained('gpt2')
+
+# def cloze_prob(text):
+
+#     whole_text_encoding = tokenizer.encode(text)
+#     # Parse out the stem of the whole sentence (i.e., the part leading up to but not including the critical word)
+#     text_list = text.split()
+#     stem = ' '.join(text_list[:-1])
+#     stem_encoding = tokenizer.encode(stem)
+#     # cw_encoding is just the difference between whole_text_encoding and stem_encoding
+#     # note: this might not correspond exactly to the word itself
+#     cw_encoding = whole_text_encoding[len(stem_encoding):]
+#     # Run the entire sentence through the model. Then go "back in time" to look at what the model predicted for each token, starting at the stem.
+#     # Put the whole text encoding into a tensor, and get the model's comprehensive output
+#     tokens_tensor = torch.tensor([whole_text_encoding])
+
+#     with torch.no_grad():
+#         outputs = model(tokens_tensor)
+#         predictions = outputs[0]
+
+#     logprobs = []
+#     # start at the stem and get downstream probabilities incrementally from the model (see above)
+#     start = -1 - len(cw_encoding)
+#     for j in range(start, -1, 1):
+#         raw_output = []
+#         for i in predictions[-1][j]:
+#             raw_output.append(i.item())
+
+#         logprobs.append(np.log(softmax(raw_output)))
+
+#     # if the critical word is three tokens long, the raw_probabilities should look something like this:
+#     # [ [0.412, 0.001, ... ], [0.213, 0.004, ...], [0.002, 0.001, 0.93, ...] ]
+#     # Then for the i'th token we want to find its associated probability
+#     # this is just: raw_probabilities[i][token_index]
+#     conditional_probs = []
+#     for cw, prob in zip(cw_encoding, logprobs):
+#         conditional_probs.append(prob[cw])
+#     # now that you have all the relevant probabilities, return their product.
+#     # This is the probability of the critical word given the context before it.
+
+#     return np.exp(np.sum(conditional_probs))
+
+def sentence_prob_mean(text):
+    # Tokenize the input text and add special tokens
+    input_ids = tokenizer.encode(text, return_tensors='pt')
+
+    # Obtain model outputs
+    with torch.no_grad():
+        outputs = model(input_ids, labels=input_ids)
+        logits = outputs.logits  # logits are the model outputs before applying softmax
+
+    # Shift logits and labels so that tokens are aligned:
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = input_ids[..., 1:].contiguous()
+
+    # Calculate the softmax probabilities
+    probs = softmax(shift_logits, dim=-1)
+
+    # Gather the probabilities of the actual token IDs
+    gathered_probs = torch.gather(probs, 2, shift_labels.unsqueeze(-1)).squeeze(-1)
+
+    # Compute the mean probability across the tokens
+    mean_prob = torch.mean(gathered_probs).item()
+
+    return mean_prob
+
@@ -118,7 +142,8 @@ def Visual_re_ranker(caption, visual_context_label, visual_context_prob):
     sim = str(sim)[1:-1]
     sim = str(sim)[1:-1]
 
-    LM = cloze_prob(caption)
+    # LM = cloze_prob(caption)
+    LM = sentence_prob_mean(caption)
     #LM = scorer.sentence_score(caption, reduce="mean")
     score = pow(float(LM),pow((1-float(sim))/(1+ float(sim)),1-float(visual_context_prob)))
 
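Note on this change: the removed `cloze_prob` scored only the sentence's final word, multiplying the conditional probabilities of its sub-tokens, whereas the new `sentence_prob_mean` scores the whole caption as the arithmetic mean of every token's next-token probability under GPT-2, which keeps scores on a comparable scale across caption lengths. Below is a minimal self-contained sketch of the new scoring path; the checkpoint names match the diff, `softmax` is taken from `torch.nn.functional` (app.py's own `softmax` import is outside this hunk), and the example captions are hypothetical.

import torch
from torch.nn.functional import softmax
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

def sentence_prob_mean(text):
    input_ids = tokenizer.encode(text, return_tensors='pt')  # shape (1, seq_len)
    with torch.no_grad():
        logits = model(input_ids).logits                     # shape (1, seq_len, vocab)
    shift_logits = logits[..., :-1, :]                       # logits at position i predict token i+1
    shift_labels = input_ids[..., 1:]
    probs = softmax(shift_logits, dim=-1)
    token_probs = torch.gather(probs, 2, shift_labels.unsqueeze(-1)).squeeze(-1)
    return token_probs.mean().item()

# A fluent caption should receive a higher mean token probability than a scrambled one.
print(sentence_prob_mean("a man is riding a horse on the beach"))
print(sentence_prob_mean("beach horse a riding man the"))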
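For reference, `Visual_re_ranker` then fuses this LM score with the visual signal as score = LM ** (((1 - sim) / (1 + sim)) ** (1 - visual_context_prob)). A small worked example with hypothetical numbers (the variable names follow the diff):

# LM from sentence_prob_mean(caption); sim is the caption/visual-context
# similarity; visual_context_prob is the visual classifier's confidence.
LM = 0.25
sim = 0.6
visual_context_prob = 0.9

exponent = pow((1 - sim) / (1 + sim), 1 - visual_context_prob)  # 0.25 ** 0.1 ≈ 0.87
score = pow(LM, exponent)                                       # 0.25 ** 0.87 ≈ 0.30

# Since 0 < LM < 1, a smaller exponent raises the score: captions that agree with a
# confident visual context are penalized less than their raw LM probability alone.
print(round(exponent, 2), round(score, 2))  # 0.87 0.3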