pedrocas15 committed on
Commit ac2a08a · verified · 1 Parent(s): f8ab6ea

Update rpc.py

Files changed (1)
  1. rpc.py +61 -54
rpc.py CHANGED
@@ -53,17 +53,18 @@ class SharedEmbedding(tf.keras.layers.Layer):
         if mode == 'embedding':
             return tf.nn.embedding_lookup(self.shared_weights, inputs)
         elif mode == 'classify':
-            sw = tf.nn.l2_normalize(self.shared_weights, axis=-1)
-            return tf.nn.softmax(tf.matmul(inputs, sw, transpose_b=True)/temp, axis=-1)
+            return tf.nn.softmax(tf.matmul(inputs, self.shared_weights, transpose_b=True), axis=-1)


 # Attention Layer
-class Attention(keras.layers.Layer):
-    def __init__(self, **kwargs):
-        super(Attention, self).__init__(**kwargs)
+class DiffAttention(keras.layers.Layer):
+    def __init__(self, depth, **kwargs):
+        super(DiffAttention, self).__init__(**kwargs)
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

     def build(self, input_shape):
         self.embed_dim = input_shape[-1]
+        self.input_size = input_shape[-2]
         self.mask = tf.where(tf.linalg.band_part(tf.ones((input_shape[-2], input_shape[-2])), -1, 0) == 1.0, 0.0, float("-inf"))
         self.range_do = -tf.range(input_shape[-2])-1
         self.range_undo = tf.range(input_shape[-2])+1
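Note: the new 'classify' branch scores hidden states directly against the shared (tied) embedding matrix and drops the old L2 normalization and temperature. A minimal sketch of that tied-embedding head, using hypothetical tensor names and shapes for illustration:

import tensorflow as tf

def tied_embedding_probs(hidden, shared_weights):
    # hidden: (batch, time, embed_dim); shared_weights: (vocab_size, embed_dim)
    logits = tf.matmul(hidden, shared_weights, transpose_b=True)  # (batch, time, vocab_size)
    return tf.nn.softmax(logits, axis=-1)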
@@ -79,7 +80,22 @@ class Attention(keras.layers.Layer):
             shape=(input_shape[-1], input_shape[-1]),
             initializer='uniform',
             trainable=True)
-        super(Attention, self).build(input_shape)
+
+        initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.1)
+        self.lambda_q1 = self.add_weight(
+            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_q1"
+        )
+        self.lambda_k1 = self.add_weight(
+            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_k1"
+        )
+        self.lambda_q2 = self.add_weight(
+            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_q2"
+        )
+        self.lambda_k2 = self.add_weight(
+            shape=(input_shape[-1],), initializer=initializer, trainable=True, name="lambda_k2"
+        )
+
+        super(DiffAttention, self).build(input_shape)

     def roll_embeddings(self, tensor, shift_values):
         batch_size, time_size, embed_dim = tensor.shape
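Note: the four lambda_* vectors re-parameterize the differential-attention scalar λ, and lambda_init from __init__ follows a depth-dependent schedule. A quick check of that schedule (my arithmetic, not part of the commit):

import math

for depth in (0, 3, 6, 11):
    print(depth, round(0.8 - 0.6 * math.exp(-0.3 * depth), 3))
# 0 -> 0.2, 3 -> 0.556, 6 -> 0.701, 11 -> 0.778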
@@ -93,55 +109,43 @@ class Attention(keras.layers.Layer):
         rolled_tensor = tf.gather(tensor, new_indices, batch_dims=2)
         return rolled_tensor

-    def call(self, x, pos):
-        q = x @ self.Q
-        k = x @ self.K
+    def call(self, x, pos, pos_src):
         v = x @ self.V
+        q = tf.transpose(tf.reshape(x @ self.Q, (-1, self.input_size, 2, self.embed_dim//2)), perm=[0, 2, 1, 3])
+        k = tf.transpose(tf.reshape(x @ self.K, (-1, self.input_size, 2, self.embed_dim//2)), perm=[0, 2, 1, 3])
         atti = tf.matmul(q, k, transpose_b=True)
         attp = tf.matmul(q, pos, transpose_b=True)
-        attp = self.roll_embeddings(attp, self.range_do)
+        attp = self.roll_embeddings(tf.reshape(attp, (-1, self.input_size, self.input_size)), self.range_do)
+        attp = tf.reshape(attp, (-1, 2, self.input_size, self.input_size))
         att = atti + attp
         att = tf.nn.softmax((att / math.sqrt(self.embed_dim)) + self.mask, axis=-1)
+        att1 = att[:, 0]
+        att2 = att[:, 1]
+
+        # Differential attention
+        lambda_1 = tf.math.exp(tf.reduce_sum(self.lambda_q1 * self.lambda_k1, axis=-1))
+        lambda_2 = tf.math.exp(tf.reduce_sum(self.lambda_q2 * self.lambda_k2, axis=-1))
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+        att = att1 - lambda_full * att2
+
         outi = att @ v
         attp = self.roll_embeddings(att, self.range_undo)
-        outp = attp @ pos
+        outp = attp @ pos_src
         out = outi + outp
+        out = out * (1 - self.lambda_init)
         return out


-# Encoder
-inputs = Input(shape=(input_size, ), dtype=tf.int32)
-emb_layer = SharedEmbedding(vocab_size, embed_dim)
-pos_layer = keras_nlp.layers.PositionEmbedding(input_size)
-
-x = LayerNormalization()(emb_layer(inputs, mode="embedding"))
-pos = pos_layer(x)
-
-b = 6
-for _ in range(b):
-    x += (2*b)**-0.5 * LayerNormalization()(Attention()(x, pos))
-    x += (2*b)**-0.5 * LayerNormalization()(Dense(embed_dim, activation="gelu")(x))
-    x = tf.nn.l2_normalize(x, axis=-1)
-
-for _ in range(b):
-    x1 = Dense(embed_dim, activation="gelu")(x)
-    x1 = Dense(embed_dim, activation="gelu")(x1)
-    x += b**-0.5 * LayerNormalization()(x1)
-    x = tf.nn.l2_normalize(x, axis=-1)
-
-x = emb_layer(x, mode="classify", temp=0.1)
-
-model = keras.Model(inputs=inputs, outputs=x)
-model.compile(
-    loss=keras.losses.SparseCategoricalCrossentropy(ignore_class=tokenizer.pad_token_id),
-    optimizer=keras.optimizers.AdamW(learning_rate=0.001),
-    metrics=[masked_accuracy, keras_nlp.metrics.Perplexity(mask_token_id=tokenizer.pad_token_id)],
-)
-
-
 # Import Model
-model.load_weights("rpc.keras")
-encoder = keras.Model(inputs=model.layers[0].input, outputs=model.layers[52].output)
+model = keras.models.load_model(
+    "rpc_diff_12b_320inp_ct4_01w10.keras",
+    custom_objects={
+        "DiffAttention" : DiffAttention,
+        "SharedEmbedding" : SharedEmbedding,
+        "masked_accuracy" : masked_accuracy
+    }
+)
+encoder = keras.Model(inputs=model.layers[0].input, outputs=model.layers[-1].output)
 encoder.summary()
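Note: call() now computes two attention maps per layer and subtracts one from the other, in the style of differential attention. A self-contained sketch of just that combination step, assuming a pre-stacked pair of softmax maps att_pair of shape (batch, 2, time, time):

import tensorflow as tf

def differential_attention(att_pair, lambda_q1, lambda_k1, lambda_q2, lambda_k2, lambda_init):
    att1, att2 = att_pair[:, 0], att_pair[:, 1]   # the two softmax attention maps
    lambda_1 = tf.math.exp(tf.reduce_sum(lambda_q1 * lambda_k1, axis=-1))
    lambda_2 = tf.math.exp(tf.reduce_sum(lambda_q2 * lambda_k2, axis=-1))
    lambda_full = lambda_1 - lambda_2 + lambda_init
    # Subtracting the second map cancels attention noise common to both maps.
    return att1 - lambda_full * att2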
 
@@ -166,10 +170,10 @@ all_toks = None
 def load_index(index_path="/dev/shm/rpc-vecdb/index"):
     global index
     global all_toks
-    #import ngtpy
-    #index = ngtpy.Index(index_path, read_only=True)
-    import faiss
-    index = faiss.read_index(index_path + "/index.faiss")
+    import ngtpy
+    index = ngtpy.Index(index_path, read_only=True)
+    #import faiss
+    #index = faiss.read_index(index_path + "/index.faiss")
     with open(index_path + "/all_toks.json", "r") as f:
         all_toks = json.loads(f.read())
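Note: retrieval switches back from FAISS to NGT; load_index() expects an NGT index directory to already exist at index_path. A rough sketch of how such an index could be built offline with the ngtpy binding (create/insert/build_index/save as I recall the NGT Python API; verify against the NGT docs):

import ngtpy

def build_index(vectors, index_path="/dev/shm/rpc-vecdb/index"):
    ngtpy.create(index_path, len(vectors[0]))  # new index directory for vectors of this dimension
    index = ngtpy.Index(index_path)
    for v in vectors:
        index.insert(v)
    index.build_index()
    index.save()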
 
@@ -184,14 +188,15 @@ def generate(text, use_rpc=True, max_tokens=128):
         enc_text = enc_text[-input_size:]
         if use_rpc:
             xq = vectorize_texts([enc_text])[-1]
-            #_id, _ = index.search(xq, size=1, epsilon=2)[0]
-            D, I = index.search(xq.reshape((1, -1)), 1)
-            _id = I[0][0]
+            _id = index.search(xq, size=1, epsilon=2)[0][0]
+            #_id = index.search(xq.reshape((1, -1)), 1)[1][0][0]
             if all_toks[_id] in carry_toks:
                 tmp = tf.argmax(tf.matmul(xq.reshape((1, -1)), encoder.layers[1].shared_weights, transpose_b=True), axis=-1).numpy()[0]
-                if all_toks[tmp] in enc_text: tok = tmp
+                if tmp in enc_text:
+                    tok = tmp
                 else: tok = all_toks[_id]
-            else: tok = all_toks[_id]
+            else:
+                tok = all_toks[_id]
         else:
             ins = enc_text + [tokenizer.pad_token_id] * (input_size - len(enc_text))
             ins = tf.constant(ins, shape=(1, input_size))
@@ -199,6 +204,8 @@ def generate(text, use_rpc=True, max_tokens=128):
             tok = tf.argmax(res, axis=-1).numpy().tolist()

         enc_text += [tok]
-        response = tokenizer.decode(enc_text)
+        new_text = tokenizer.decode(enc_text)
+        res = new_text[len(text):]
+        text = new_text

-        yield response
+        yield res
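Note: generate() now yields only the newly decoded suffix (res) at each step instead of re-yielding the whole running response, so callers can stream output incrementally. A usage sketch, assuming the module-level load_index() and generate() defined in this file:

load_index()
for chunk in generate("Retrieval-based prediction is", use_rpc=True, max_tokens=32):
    print(chunk, end="", flush=True)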