wangzhang committed on
Commit 614e7ac · 1 Parent(s): 0fc60cb

Update app.py

Files changed (1)
  1. app.py +6 -13
app.py CHANGED
@@ -38,10 +38,10 @@ def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
-    max_new_tokens: int = 512,
-    temperature: float = 0.2,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
     top_p: float = 0.9,
-    top_k: int = 10,
+    top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
@@ -49,17 +49,10 @@ def generate(
         conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    prompt = f"""### Instruction:
-根据巨杉数据库SequoiaDB的相关问题进行回答。
-
-### Input:
-{message}
-
-### Response:
-"""
-    conversation.append({"role": "user", "content": prompt})
+    conversation.append({"role": "user", "content": message})
+
     chat = tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer(chat, return_tensors="pt", truncation=True).input_ids.cuda()
+    inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
     if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
         inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")