Walid Aissa committed
Commit 288a5de · 1 Parent(s): 17cc89d

tokeniser tweaking

Files changed (1)
  1. app.py +41 -24
app.py CHANGED
@@ -80,62 +80,79 @@ def answer_question(question):
    # Apply the tokenizer to the input text, treating them as a text-pair.
    input_ids = tokenizer.encode(question, context)

-    # Report how long the input sequence is. If longer than 512 tokens, make it shorter.
-    while(len(input_ids) > 512):
-        input_ids.pop()

-    print('Query has {:,} tokens.\n'.format(len(input_ids)))

    # ======== Set Segment IDs ========
    # Search the input_ids for the first instance of the `[SEP]` token.
-    sep_index = input_ids.index(tokenizer.sep_token_id)

    # The number of segment A tokens includes the [SEP] token itself.
-    num_seg_a = sep_index + 1

    # The remainder are segment B.
-    num_seg_b = len(input_ids) - num_seg_a

    # Construct the list of 0s and 1s.
-    segment_ids = [0]*num_seg_a + [1]*num_seg_b

    # There should be a segment_id for every input token.
-    assert len(segment_ids) == len(input_ids)

    # ======== Evaluate ========
    # Run our example through the model.
-    outputs = model(torch.tensor([input_ids]), # The tokens representing our input text.
                    token_type_ids=torch.tensor([segment_ids]), # The segment IDs to differentiate question from answer_text
                    return_dict=True)

-    start_scores = outputs.start_logits
-    end_scores = outputs.end_logits
-    print(start_scores)
-    print(end_scores)

    # ======== Reconstruct Answer ========
    # Find the tokens with the highest `start` and `end` scores.
-    answer_start = torch.argmax(start_scores)
-    answer_end = torch.argmax(end_scores)

    # Get the string versions of the input tokens.
-    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Start with the first token.
-    answer = tokens[answer_start]

    # Select the remaining answer tokens and join them with whitespace.
-    for i in range(answer_start + 1, answer_end + 1):

        # If it's a subword token, then recombine it with the previous token.
-        if tokens[i][0:2] == '##':
-            answer += tokens[i][2:]

        # Otherwise, add a space then the token.
-        else:
-            answer += ' ' + tokens[i]

-    return 'Answer: "' + answer + '"'

# =====[ DEFINE INTERFACE ]===== #'
title = "Azza Conversational Agent"
 
    # Apply the tokenizer to the input text, treating them as a text-pair.
    input_ids = tokenizer.encode(question, context)

+    # Report how long the input sequence is. If longer than 512 tokens, divide it into multiple sequences.

+    print(f"Query has {len(input_ids)} tokens, divided in {len(input_ids)//513 + 1}.\n")

+    input_ids_split = []
+    for group in range(len(input_ids)//513):
+        input_ids_split.append(input_ids[512*group:512*(group+1)-1])
+    input_ids_split.append(input_ids[512*(len(input_ids)//513):len(input_ids)-1])
+
+    scores = []
+    for input in input_ids_split:
    # ======== Set Segment IDs ========
    # Search the input_ids for the first instance of the `[SEP]` token.
+        sep_index = input.index(tokenizer.sep_token_id)

    # The number of segment A tokens includes the [SEP] token itself.
+        num_seg_a = sep_index + 1

    # The remainder are segment B.
+        num_seg_b = len(input) - num_seg_a

    # Construct the list of 0s and 1s.
+        segment_ids = [0]*num_seg_a + [1]*num_seg_b

    # There should be a segment_id for every input token.
+        assert len(segment_ids) == len(input)

    # ======== Evaluate ========
    # Run our example through the model.
+        outputs = model(torch.tensor([input]), # The tokens representing our input text.
                    token_type_ids=torch.tensor([segment_ids]), # The segment IDs to differentiate question from answer_text
                    return_dict=True)

+        start_scores = outputs.start_logits
+        end_scores = outputs.end_logits
+
+        max_start_score = torch.max(start_scores)
+        max_end_score = torch.max(end_scores)
+
+        print(max_start_score)
+        print(max_end_score)

    # ======== Reconstruct Answer ========
    # Find the tokens with the highest `start` and `end` scores.
+        answer_start = torch.argmax(start_scores)
+        answer_end = torch.argmax(end_scores)

    # Get the string versions of the input tokens.
+        tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Start with the first token.
+        answer = tokens[answer_start]

    # Select the remaining answer tokens and join them with whitespace.
+        for i in range(answer_start + 1, answer_end + 1):

        # If it's a subword token, then recombine it with the previous token.
+            if tokens[i][0:2] == '##':
+                answer += tokens[i][2:]

        # Otherwise, add a space then the token.
+            else:
+                answer += ' ' + tokens[i]
+
+        scores.append((max_start_score, max_end_score, answer))
+
+    # Compare the scores of the answers found for each sequence and pick the most relevant.

+    final_answer = max(scores, key=lambda x: x[0] + x[1])[2]

# =====[ DEFINE INTERFACE ]===== #'
title = "Azza Conversational Agent"
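
For reference, below is a minimal, self-contained sketch of the chunk-and-score idea this commit introduces: split the tokenized input into pieces of at most 512 tokens, run the QA model on each piece, and keep the answer whose combined start/end logits are highest. The checkpoint name and the answer_question(question, context) signature are assumptions for illustration, not necessarily what app.py loads.

# Hypothetical illustration of the chunking approach; not the committed app.py.
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# Assumed checkpoint; the Space may load a different model.
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

def answer_question(question, context, max_len=512):
    input_ids = tokenizer.encode(question, context)

    # Split the token ids into chunks of at most max_len tokens.
    chunks = [input_ids[i:i + max_len] for i in range(0, len(input_ids), max_len)]

    best_score, best_answer = None, ""
    for chunk in chunks:
        # Segment A runs up to and including the first [SEP]; the rest is
        # segment B. A chunk with no [SEP] (pure context) is all segment B.
        sep_index = chunk.index(tokenizer.sep_token_id) if tokenizer.sep_token_id in chunk else -1
        num_seg_a = sep_index + 1
        segment_ids = [0] * num_seg_a + [1] * (len(chunk) - num_seg_a)

        outputs = model(torch.tensor([chunk]),
                        token_type_ids=torch.tensor([segment_ids]),
                        return_dict=True)

        # Score the chunk by its best start logit plus its best end logit.
        score = (outputs.start_logits.max() + outputs.end_logits.max()).item()
        if best_score is None or score > best_score:
            start = int(torch.argmax(outputs.start_logits))
            end = int(torch.argmax(outputs.end_logits))
            tokens = tokenizer.convert_ids_to_tokens(chunk)
            best_answer = tokenizer.convert_tokens_to_string(tokens[start:end + 1])
            best_score = score

    return f'Answer: "{best_answer}"'

Note that chunks after the first no longer contain the question, which limits this simple split (as it does the committed version); the tokenizer's built-in overflow handling (truncation with return_overflowing_tokens=True and a stride) is another way to produce windows that keep the question in every chunk.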