Spaces:
Runtime error
Runtime error
howard-hou
commited on
Commit
•
586d6a1
1
Parent(s):
1d2fc64
Update modeling_rwkv.py
Browse files- modeling_rwkv.py +1 -0
modeling_rwkv.py
CHANGED
@@ -1043,6 +1043,7 @@ class RWKV(MyModule):
|
|
1043 |
elif embs is not None and tokens is not None:
|
1044 |
seq_mode = len(tokens) > 1
|
1045 |
x = w['emb.weight'][tokens if seq_mode else tokens[0]]
|
|
|
1046 |
x = torch.cat([x, embs], dim=0)
|
1047 |
else:
|
1048 |
raise ValueError('Either tokens or embs must be provided')
|
|
|
1043 |
elif embs is not None and tokens is not None:
|
1044 |
seq_mode = len(tokens) > 1
|
1045 |
x = w['emb.weight'][tokens if seq_mode else tokens[0]]
|
1046 |
+
x = x.to(device=embs.device, dtype=embs.dtype, non_blocking=True)
|
1047 |
x = torch.cat([x, embs], dim=0)
|
1048 |
else:
|
1049 |
raise ValueError('Either tokens or embs must be provided')
|