Spaces:
Runtime error
Runtime error
Commit
·
b9bf03b
1
Parent(s):
8c28418
Update app.py
Browse files
app.py
CHANGED
@@ -39,7 +39,7 @@ def generate(
|
|
39 |
ctx,
|
40 |
image_features,
|
41 |
token_count=200,
|
42 |
-
temperature=
|
43 |
top_p=0.3,
|
44 |
presencePenalty = 0.1,
|
45 |
countPenalty = 0.1,
|
@@ -56,11 +56,10 @@ def generate(
|
|
56 |
occurrence = {}
|
57 |
for i in range(int(token_count)):
|
58 |
if i == 0:
|
59 |
-
prefix_ids = pipeline.encode("User: ")
|
60 |
-
prefix_embs = model.w['emb.weight'][prefix_ids]
|
61 |
input_ids = pipeline.encode(ctx)
|
|
|
62 |
text_embs = model.w['emb.weight'][input_ids]
|
63 |
-
input_embs = torch.cat((
|
64 |
out, state = model.forward(embs=input_embs, state=None)
|
65 |
else:
|
66 |
input_ids = [token]
|
|
|
39 |
ctx,
|
40 |
image_features,
|
41 |
token_count=200,
|
42 |
+
temperature=0.2,
|
43 |
top_p=0.3,
|
44 |
presencePenalty = 0.1,
|
45 |
countPenalty = 0.1,
|
|
|
56 |
occurrence = {}
|
57 |
for i in range(int(token_count)):
|
58 |
if i == 0:
|
|
|
|
|
59 |
input_ids = pipeline.encode(ctx)
|
60 |
+
print(input_ids)
|
61 |
text_embs = model.w['emb.weight'][input_ids]
|
62 |
+
input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
|
63 |
out, state = model.forward(embs=input_embs, state=None)
|
64 |
else:
|
65 |
input_ids = [token]
|