Update app.py
Browse files
app.py
CHANGED
@@ -102,26 +102,26 @@ def respond(
|
|
102 |
|
103 |
messages= json_obj
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
#
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
|
124 |
-
|
125 |
|
126 |
|
127 |
# messages = [
|
|
|
102 |
|
103 |
messages= json_obj
|
104 |
|
105 |
+
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda') # .to(accelerator.device)
|
106 |
+
input_ids2 = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, return_tensors="pt") #.to('cuda')
|
107 |
+
print(f"Converted input_ids dtype: {input_ids.dtype}")
|
108 |
+
input_str= str(input_ids2)
|
109 |
+
print('input str = ', input_str)
|
110 |
|
111 |
|
112 |
+
gen_tokens = model.generate(
|
113 |
+
input_ids,
|
114 |
+
max_new_tokens=max_tokens,
|
115 |
+
# do_sample=True,
|
116 |
+
temperature=temperature,
|
117 |
+
)
|
118 |
+
|
119 |
+
gen_text = tokenizer.decode(gen_tokens[0])
|
120 |
+
print(gen_text)
|
121 |
+
gen_text= gen_text.replace(input_str,'')
|
122 |
+
gen_text= gen_text.replace('<|END_OF_TURN_TOKEN|>','')
|
123 |
|
124 |
+
yield gen_text
|
125 |
|
126 |
|
127 |
# messages = [
|