Xenova HF staff committed on
Commit
65b5c46
·
verified ·
1 Parent(s): 2a4eaa1

Simplify usage

Browse files
Files changed (1) hide show
  1. README.md +6 -6
README.md CHANGED
@@ -1,5 +1,9 @@
1
  ---
2
  pipeline_tag: text-generation
 
 
 
 
3
  ---
4
 
5
  ## Usage
@@ -41,8 +45,7 @@ past_key_values = {
41
  for kv in ('key', 'value')
42
  }
43
  input_ids = inputs['input_ids']
44
- attention_mask = inputs['attention_mask']
45
- position_ids = np.cumsum(inputs['attention_mask'], axis=-1)
46
 
47
  # 3. Generation loop
48
  max_new_tokens = 1024
@@ -50,14 +53,12 @@ generated_tokens = np.array([[]], dtype=np.int64)
50
  for i in range(max_new_tokens):
51
  logits, *present_key_values = decoder_session.run(None, dict(
52
  input_ids=input_ids,
53
- attention_mask=attention_mask,
54
  position_ids=position_ids,
55
  **past_key_values,
56
  ))
57
 
58
  ## Update values for next generation loop
59
  input_ids = logits[:, -1].argmax(-1, keepdims=True)
60
- attention_mask = np.ones_like(input_ids)
61
  position_ids = position_ids[:, -1:] + 1
62
  for j, key in enumerate(past_key_values):
63
  past_key_values[key] = present_key_values[j]
@@ -145,5 +146,4 @@ const messages = [
145
  // Generate a response
146
  const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
147
  console.log(output[0].generated_text.at(-1).content);
148
- ```
149
-
 
1
  ---
2
  pipeline_tag: text-generation
3
+ base_model:
4
+ - google/gemma-3-1b-it
5
+ library_name: transformers.js
6
+ license: gemma
7
  ---
8
 
9
  ## Usage
 
45
  for kv in ('key', 'value')
46
  }
47
  input_ids = inputs['input_ids']
48
+ position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
 
49
 
50
  # 3. Generation loop
51
  max_new_tokens = 1024
 
53
  for i in range(max_new_tokens):
54
  logits, *present_key_values = decoder_session.run(None, dict(
55
  input_ids=input_ids,
 
56
  position_ids=position_ids,
57
  **past_key_values,
58
  ))
59
 
60
  ## Update values for next generation loop
61
  input_ids = logits[:, -1].argmax(-1, keepdims=True)
 
62
  position_ids = position_ids[:, -1:] + 1
63
  for j, key in enumerate(past_key_values):
64
  past_key_values[key] = present_key_values[j]
 
146
  // Generate a response
147
  const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
148
  console.log(output[0].generated_text.at(-1).content);
149
+ ```