Spaces:
Paused
Paused
updated temperature and fixed filter func
Browse files
app.py
CHANGED
@@ -41,13 +41,13 @@ def predict (inp_text):
|
|
41 |
return_tensors = "pt",
|
42 |
).to("cuda")
|
43 |
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
44 |
-
outputs = model.generate(input_ids = inputs, use_cache = True ,temperature = 0.
|
45 |
result = tokenizer.batch_decode(outputs)
|
46 |
# print(result)
|
47 |
return filter_user_assistant_msgs(result[0])
|
48 |
|
49 |
def filter_user_assistant_msgs(text):
|
50 |
-
msg_pattern = r".*Response:\n(.*?)<\|
|
51 |
match = re.match(msg_pattern, text, re.DOTALL)
|
52 |
if match:
|
53 |
message = match.group(1).strip()
|
|
|
41 |
return_tensors = "pt",
|
42 |
).to("cuda")
|
43 |
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
44 |
+
outputs = model.generate(input_ids = inputs, use_cache = True ,temperature = 0.01,max_new_tokens = 1024)
|
45 |
result = tokenizer.batch_decode(outputs)
|
46 |
# print(result)
|
47 |
return filter_user_assistant_msgs(result[0])
|
48 |
|
49 |
def filter_user_assistant_msgs(text):
|
50 |
+
msg_pattern = r".*Response:\n(.*?)<\|eot_id\|>"
|
51 |
match = re.match(msg_pattern, text, re.DOTALL)
|
52 |
if match:
|
53 |
message = match.group(1).strip()
|