Pclanglais commited on
Commit
459a15e
·
verified ·
1 Parent(s): d37ed38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -33,6 +33,8 @@ repetition_penalty=1.7
33
 
34
  #llm = LLM(model_name, max_model_len=4096)
35
 
 
 
36
  #Vector search over the database
37
  def vector_search(sentence_query):
38
 
@@ -64,24 +66,19 @@ class StopOnTokens(StoppingCriteria):
64
  return True
65
  return False
66
 
 
67
  def predict(message, history):
68
  text = vector_search(message)
69
  message = message + "\n\n### Source ###\n" + text
70
  history_transformer_format = history + [[message, ""]]
71
-
72
- messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
73
- for item in history_transformer_format])
74
-
75
- return messages
76
-
77
- def predict_alt(message, history):
78
- history_transformer_format = history + [[message, ""]]
79
  stop = StopOnTokens()
80
 
81
- messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
82
  for item in history_transformer_format])
83
 
84
- model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
 
 
85
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
86
  generate_kwargs = dict(
87
  model_inputs,
@@ -101,7 +98,8 @@ def predict_alt(message, history):
101
  for new_token in streamer:
102
  if new_token != '<':
103
  partial_message += new_token
104
- yield partial_message
 
105
 
106
  # Define the Gradio interface
107
  title = "Tchap"
 
33
 
34
  #llm = LLM(model_name, max_model_len=4096)
35
 
36
+ system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
37
+
38
  #Vector search over the database
39
  def vector_search(sentence_query):
40
 
 
66
  return True
67
  return False
68
 
69
+
70
  def predict(message, history):
71
  text = vector_search(message)
72
  message = message + "\n\n### Source ###\n" + text
73
  history_transformer_format = history + [[message, ""]]
 
 
 
 
 
 
 
 
74
  stop = StopOnTokens()
75
 
76
+ messages = "".join(["".join(["<|start_header_id|>user<|end_header_id|>\n\n"+item[0], "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"+item[1]])
77
  for item in history_transformer_format])
78
 
79
+ messages = system_prompt + messages
80
+
81
+ """"model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
82
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
83
  generate_kwargs = dict(
84
  model_inputs,
 
98
  for new_token in streamer:
99
  if new_token != '<':
100
  partial_message += new_token
101
+ yield partial_message"""
102
+ return messages
103
 
104
  # Define the Gradio interface
105
  title = "Tchap"