Spaces:
Runtime error
Runtime error
Commit
·
4c1f576
1
Parent(s):
8ab1b3b
Update app.py
Browse files
app.py
CHANGED
@@ -53,15 +53,21 @@ def _launch_demo(args, model, tokenizer, config):
|
|
53 |
def predict(_query, _chatbot, _task_history):
|
54 |
print(f"User: {_parse_text(_query)}")
|
55 |
_chatbot.append((_parse_text(_query), ""))
|
56 |
-
|
57 |
-
attention_mask = torch.ones(input_ids.shape).to('cuda')
|
58 |
-
pad_token_id = tokenizer.eos_token_id
|
59 |
# Tokenize the input
|
60 |
input_ids = tokenizer.encode(_query, return_tensors='pt')
|
61 |
print("Input IDs:", input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
# Generate a response using the model
|
63 |
generated_ids = model.generate(input_ids, max_length=300)
|
64 |
print("Generated IDs:", generated_ids)
|
|
|
65 |
# Decode the generated IDs to text
|
66 |
full_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
67 |
|
@@ -73,6 +79,7 @@ def _launch_demo(args, model, tokenizer, config):
|
|
73 |
_task_history.append((_query, full_response))
|
74 |
print(f"OpenHermes: {_parse_text(full_response)}")
|
75 |
|
|
|
76 |
def regenerate(_chatbot, _task_history):
|
77 |
if not _task_history:
|
78 |
yield _chatbot
|
|
|
53 |
def predict(_query, _chatbot, _task_history):
|
54 |
print(f"User: {_parse_text(_query)}")
|
55 |
_chatbot.append((_parse_text(_query), ""))
|
56 |
+
|
|
|
|
|
57 |
# Tokenize the input
|
58 |
input_ids = tokenizer.encode(_query, return_tensors='pt')
|
59 |
print("Input IDs:", input_ids)
|
60 |
+
|
61 |
+
# Move input_ids to CUDA if available
|
62 |
+
input_ids = input_ids.to('cuda')
|
63 |
+
|
64 |
+
# Generate attention_mask
|
65 |
+
attention_mask = torch.ones(input_ids.shape).to('cuda')
|
66 |
+
|
67 |
# Generate a response using the model
|
68 |
generated_ids = model.generate(input_ids, max_length=300)
|
69 |
print("Generated IDs:", generated_ids)
|
70 |
+
|
71 |
# Decode the generated IDs to text
|
72 |
full_response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
73 |
|
|
|
79 |
_task_history.append((_query, full_response))
|
80 |
print(f"OpenHermes: {_parse_text(full_response)}")
|
81 |
|
82 |
+
|
83 |
def regenerate(_chatbot, _task_history):
|
84 |
if not _task_history:
|
85 |
yield _chatbot
|