dofbi commited on
Commit
f8e85bf
1 Parent(s): 0827295
Files changed (1) hide show
  1. app.py +22 -6
app.py CHANGED
@@ -46,14 +46,30 @@ def generate_response(user_input, system_prompt, max_new_tokens=150, temperature
46
  {"role": "user", "content": user_input}
47
  ]
48
 
49
- # Tokeniser l'entrée
50
- inputs = tokenizer(user_input, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- # Générer une réponse
53
- outputs = model.generate(inputs.input_ids, max_new_tokens=int(max_new_tokens), temperature=temperature)
 
 
 
54
 
55
- # Décoder la réponse en texte
56
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
57
 
58
  # Fonction pour mettre à jour le message du prompt système en fonction du choix
59
  def update_system_prompt(selected_prompt):
 
46
  {"role": "user", "content": user_input}
47
  ]
48
 
49
+ # Utiliser apply_chat_template
50
+ text = tokenizer.apply_chat_template(
51
+ messages,
52
+ tokenize=False,
53
+ add_generation_prompt=True
54
+ )
55
+
56
+ # Préparer les entrées du modèle
57
+ model_inputs = tokenizer([text], return_tensors="pt").to(device)
58
+
59
+ # Générer la réponse
60
+ generated_ids = model.generate(
61
+ model_inputs.input_ids,
62
+ max_new_tokens=int(max_new_tokens),
63
+ temperature=temperature
64
+ )
65
 
66
+ # Décoder la réponse
67
+ generated_ids = [
68
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
69
+ ]
70
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
71
 
72
+ return response
 
73
 
74
  # Fonction pour mettre à jour le message du prompt système en fonction du choix
75
  def update_system_prompt(selected_prompt):