Update app.py
app.py CHANGED
@@ -38,10 +38,10 @@ def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
-    max_new_tokens: int =
-    temperature: float = 0.
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
     top_p: float = 0.9,
-    top_k: int =
+    top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
@@ -49,17 +49,10 @@ def generate(
     conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-
-
-
-    ### Input:
-    {message}
-
-    ### Response:
-    """
-    conversation.append({"role": "user", "content": prompt})
+    conversation.append({"role": "user", "content": message})
+
     chat = tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer(chat, return_tensors="pt",
+    inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
     if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
         inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
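A note on the trimming check as committed: tokenizer(chat, ...) returns a BatchEncoding, so len(inputs) counts its keys (input_ids, attention_mask) rather than tokens, and the warning string is missing its f prefix, so {MAX_INPUT_TOKEN_LENGTH} is never interpolated. A minimal sketch of trimming on the token dimension instead, assuming the same MAX_INPUT_TOKEN_LENGTH constant and CUDA device used by the Space (the model id here is a placeholder, not the Space's checkpoint):

import gradio as gr
import torch
from transformers import AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; the constant is defined elsewhere in app.py

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model for illustration

def tokenize_and_trim(chat: str) -> torch.Tensor:
    # add_special_tokens=False because apply_chat_template has already
    # inserted the template's special tokens into the string.
    input_ids = tokenizer(chat, return_tensors="pt", add_special_tokens=False).input_ids
    # Trim on the token dimension (dim 1), keeping the most recent tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        # f prefix is required for the placeholder to be interpolated.
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    return input_ids.to("cuda")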
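For context on the new defaults (max_new_tokens=1024, temperature=0.6, top_k=50): given the Iterator[str] return type, parameters like these usually feed a threaded model.generate call through a TextIteratorStreamer. That pattern is not shown in this diff, so the sketch below is an assumption; the model loading is a placeholder for whatever checkpoint the Space actually uses.

from threading import Thread
from typing import Iterator

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder; the Space's real checkpoint is loaded elsewhere in app.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_generate(
    input_ids,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Stream decoded text chunks as they are generated, matching Iterator[str].
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sampling must be on for temperature/top_p/top_k to apply
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks, so run it in a thread while the streamer yields chunks.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)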