Simplify usage
Browse files
README.md
CHANGED
@@ -1,5 +1,9 @@
 ---
 pipeline_tag: text-generation
+base_model:
+- google/gemma-3-1b-it
+library_name: transformers.js
+license: gemma
 ---
 
 ## Usage
|
@@ -41,8 +45,7 @@ past_key_values = {
     for kv in ('key', 'value')
 }
 input_ids = inputs['input_ids']
-
-position_ids = np.cumsum(inputs['attention_mask'], axis=-1)
+position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
 
 # 3. Generation loop
 max_new_tokens = 1024
@@ -50,14 +53,12 @@ generated_tokens = np.array([[]], dtype=np.int64)
 for i in range(max_new_tokens):
     logits, *present_key_values = decoder_session.run(None, dict(
         input_ids=input_ids,
-        attention_mask=attention_mask,
         position_ids=position_ids,
         **past_key_values,
     ))
 
     ## Update values for next generation loop
     input_ids = logits[:, -1].argmax(-1, keepdims=True)
-    attention_mask = np.ones_like(input_ids)
     position_ids = position_ids[:, -1:] + 1
     for j, key in enumerate(past_key_values):
         past_key_values[key] = present_key_values[j]
@@ -145,5 +146,4 @@ const messages = [
 // Generate a response
 const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
 console.log(output[0].generated_text.at(-1).content);
-```
-
+```
New file (right column), lines 1-9:

  1    ---
  2    pipeline_tag: text-generation
  3  + base_model:
  4  + - google/gemma-3-1b-it
  5  + library_name: transformers.js
  6  + license: gemma
  7    ---
  8
  9    ## Usage
|
|
New file (right column), lines 45-51:

 45        for kv in ('key', 'value')
 46    }
 47    input_ids = inputs['input_ids']
 48  + position_ids = np.tile(np.arange(1, input_ids.shape[-1] + 1), (batch_size, 1))
 49
 50    # 3. Generation loop
 51    max_new_tokens = 1024
|
|
New file (right column), lines 53-64:

 53    for i in range(max_new_tokens):
 54        logits, *present_key_values = decoder_session.run(None, dict(
 55            input_ids=input_ids,
 56            position_ids=position_ids,
 57            **past_key_values,
 58        ))
 59
 60        ## Update values for next generation loop
 61        input_ids = logits[:, -1].argmax(-1, keepdims=True)
 62        position_ids = position_ids[:, -1:] + 1
 63        for j, key in enumerate(past_key_values):
 64            past_key_values[key] = present_key_values[j]
|
|
|
New file (right column), lines 146-149:

 146    // Generate a response
 147    const output = await generator(messages, { max_new_tokens: 512, do_sample: false });
 148    console.log(output[0].generated_text.at(-1).content);
 149  + ```
|
|