analist commited on
Commit
12a623d
·
verified ·
1 Parent(s): 8152e76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -3
app.py CHANGED
@@ -26,11 +26,11 @@ bnb_config = BitsAndBytesConfig(
26
 
27
  # Chargement du modèle et du tokenizer
28
  print("Chargement du modèle de base et du tokenizer...")
29
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
30
 
31
  print("Chargement du modèle de base quantifié...")
32
  base_model = AutoModelForCausalLM.from_pretrained(
33
- BASE_MODEL_NAME,
34
  quantization_config=bnb_config,
35
  device_map="auto",
36
  trust_remote_code=True
@@ -39,7 +39,6 @@ base_model = AutoModelForCausalLM.from_pretrained(
39
  print("Application des adaptateurs...")
40
  model = PeftModel.from_pretrained(
41
  base_model,
42
- ADAPTER_MODEL_NAME,
43
  device_map="auto",
44
  )
45
 
@@ -47,6 +46,49 @@ print("Modèle et tokenizer chargés avec succès!")
47
 
48
  # Fonction pour générer une réponse
49
  def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Construction du contexte de chat
52
  chat_context = []
 
26
 
27
  # Chargement du modèle et du tokenizer
28
  print("Chargement du modèle de base et du tokenizer...")
29
+ tokenizer = AutoTokenizer.from_pretrained(ADAPTER_MODEL_NAME)
30
 
31
  print("Chargement du modèle de base quantifié...")
32
  base_model = AutoModelForCausalLM.from_pretrained(
33
+ ADAPTER_MODEL_NAME,
34
  quantization_config=bnb_config,
35
  device_map="auto",
36
  trust_remote_code=True
 
39
  print("Application des adaptateurs...")
40
  model = PeftModel.from_pretrained(
41
  base_model,
 
42
  device_map="auto",
43
  )
44
 
 
46
 
47
  # Fonction pour générer une réponse
48
  def generate_response(message, chat_history, system_prompt, temperature=TEMPERATURE, max_tokens=MAX_NEW_TOKENS):
49
+ # Construction du contexte de chat
50
+ chat_context = []
51
+
52
+ # Ajout du prompt système s'il existe
53
+ if system_prompt.strip():
54
+ chat_context.append({"role": "system", "content": system_prompt})
55
+ else:
56
+ chat_context.append({"role": "system", "content": DEFAULT_SYSTEM_PROMPT})
57
+
58
+ # Ajout de l'historique des conversations
59
+ for user_msg, assistant_msg in chat_history:
60
+ chat_context.append({"role": "user", "content": user_msg})
61
+ chat_context.append({"role": "assistant", "content": assistant_msg})
62
+
63
+ # Ajout du message actuel
64
+ chat_context.append({"role": "user", "content": message})
65
+
66
+ # Préparation du texte d'entrée avec le template de chat
67
+ input_text = tokenizer.apply_chat_template(
68
+ chat_context,
69
+ tokenize=False,
70
+ add_generation_prompt=True
71
+ )
72
+
73
+ # Tokenisation de l'entrée
74
+ model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
75
+
76
+ # Mise à jour des paramètres de génération
77
+ model.generation_config.temperature = temperature
78
+ model.generation_config.max_new_tokens = max_tokens
79
+
80
+ # Génération de la réponse
81
+ with torch.no_grad():
82
+ generated_ids = model.generate(
83
+ **model_inputs,
84
+ use_cache=True,
85
+ )
86
+
87
+ # Extraction uniquement de la nouvelle partie générée
88
+ new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:]
89
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True)
90
+
91
+ return response
92
 
93
  # Construction du contexte de chat
94
  chat_context = []