Tijmen2 committed · verified
Commit 7efaceb · 1 Parent(s): b39668d

Update app.py

Files changed (1)
  1. app.py +23 -4
app.py CHANGED
@@ -36,13 +36,31 @@ GREETING_MESSAGES = [
     "The universe awaits! I'm AstroSage. What astronomical wonders shall we discuss?",
 ]
 
+def format_message(role: str, content: str) -> str:
+    """Format a single message according to Llama-3 chat template."""
+    return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
+
 def generate_text(prompt: str, history: list, max_new_tokens=512, temperature=0.7, top_p=0.95, top_k=50):
     """
-    Generate a response using the transformer model.
+    Generate a response using the transformer model with proper Llama-3 chat formatting.
     """
-    # Combine history into the prompt
-    formatted_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in history])
-    prompt_with_history = f"{formatted_history}\nUser: {prompt}\nAssistant:"
+    # Start with begin_of_text token
+    formatted_messages = ["<|begin_of_text|>"]
+
+    # Add formatted history
+    for msg in history:
+        formatted_message = format_message(msg['role'], msg['content'])
+        formatted_messages.append(formatted_message)
+
+    # Add the current prompt
+    formatted_message = format_message('user', prompt)
+    formatted_messages.append(formatted_message)
+
+    # Add the start of assistant's response
+    formatted_messages.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
+
+    # Combine all messages
+    prompt_with_history = "\n".join(formatted_messages)
 
     # Encode the prompt
     inputs = tokenizer([prompt_with_history], return_tensors="pt", truncation=True).to(DEVICE)
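The helper added above hand-builds the Llama-3 prompt string. For comparison, a minimal sketch of the same result via transformers' built-in `tokenizer.apply_chat_template` (assuming the tokenizer loaded in app.py bundles the Llama-3 chat template; the messages are illustrative):

```python
# Sketch only: `tokenizer` is assumed to be the AutoTokenizer already loaded
# in app.py, with the Llama-3 chat template bundled in its tokenizer config.
messages = [
    {"role": "user", "content": "What is a pulsar?"},             # prior history
    {"role": "assistant", "content": "A rotating neutron star."},
    {"role": "user", "content": "How fast do they spin?"},        # current prompt
]

# Produces the same <|begin_of_text|><|start_header_id|>... string as the
# code above; add_generation_prompt=True appends the trailing assistant header.
prompt_with_history = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
```

One difference worth noting: `apply_chat_template` concatenates the blocks directly, while the `"\n".join(...)` in the commit inserts an extra newline between them, a slight deviation from the canonical template.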
@@ -56,6 +74,7 @@ def generate_text(prompt: str, history: list, max_new_tokens=512, temperature=0.
         skip_prompt=True,
         skip_special_tokens=True
     )
+
     generation_kwargs = dict(
         **inputs,
         streamer=streamer,
 
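The `generation_kwargs` built here feed the usual `TextIteratorStreamer` pattern: `model.generate` has to run in a worker thread while the caller iterates the streamer. A minimal sketch of that consumption side (the `stream_response` wrapper is an assumption for illustration, not part of this commit):

```python
from threading import Thread

def stream_response(model, streamer, generation_kwargs):
    """Yield the growing response while model.generate runs in a worker thread."""
    # model.generate blocks until generation finishes, so it must run
    # off-thread for the streamer to be consumable incrementally.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # TextIteratorStreamer yields decoded text chunks as tokens are produced
    # (skip_prompt=True above keeps the input prompt out of the stream).
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # e.g. hand incremental text to the Gradio chat UI

    thread.join()
```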