Update app.py
app.py CHANGED
@@ -362,22 +362,22 @@ def generate_response(user_input, model_id):
         print(f"Generated prompt: {prompt}")  # <-- Log the prompt here
 
         # Add the retrieved knowledge to the prompt
-
+        func_caller.append({"role": "system", "content": prompt})
 
-
-
-
+        for msg in chat_history:
+            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
+            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
 
-
+        highest_label_result = classify_intent(user_input)
 
         # Reformulated prompt based on intent classification
-
+        reformulated_prompt = reformulate_prompt(user_input, highest_label_result)
 
-
-
+        func_caller.append({"role": "user", "content": f'{reformulated_prompt}'})
+        formatted_prompt = "\n".join([f"{m['role']}: {m['content']}" for m in func_caller])
 
         #prompt = user_input
-
+        device = accelerator.device  # Automatically uses GPU or CPU based on accelerator setup
 
         generation_config = GenerationConfig(
             do_sample=(highest_label == "dialog continuation" or highest_label == "recommendation request"),  # True if dialog continuation, else False
@@ -394,40 +394,40 @@ def generate_response(user_input, model_id):
         )
 
         # Generate response
-
-
-
+        gpt_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+        gpt_output = model.generate(gpt_inputs["input_ids"], max_new_tokens=50, generation_config=generation_config)
+        final_response = tokenizer.decode(gpt_output[0], skip_special_tokens=True)
         # Extract AI's response only (omit the prompt)
-
-
-
-
+        ai_response2 = final_response.replace(reformulated_prompt, "").strip()
+        ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()
+        ai_response = re.split(r'(?<=\w[.!?]) +', ai_response)
+        ai_response = [s.strip() for s in ai_response if s]
 
         # Encode the prompt and candidates
-
-
+        prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
+        candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)
 
         # Compute similarity scores between prompt and each candidate
-
+        similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]
 
         # Find the candidate with the highest similarity score
 
-
-
+        best_index = similarities.argmax()
+        best_response = ai_response[best_index]
 
         # Assuming best_response is already defined and contains the generated response
 
-
+        if highest_label == "dialog continuation":
            # Split the response into sentences
-
+            sentences = best_response.split('. ')
            # Take the first three sentences and join them back together
-
+            best_response = '. '.join(sentences[:3]) if len(sentences) > 3 else best_response
 
         # Append the user's message to the chat history
-
-
+        chat_history.append({'role': 'user', 'content': user_input})
+        chat_history.append({'role': 'assistant', 'content': best_response})
 
-        return
+        return best_response
 
     except Exception as e:
         print("Error in generate_response:")
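The do_sample flag ties the decoding strategy to the classified intent: open-ended intents get sampling, everything else falls back to greedy decoding. A reduced sketch of that switch (the extra sampling parameters are assumptions for illustration, not values from this commit):

from transformers import GenerationConfig

highest_label = "recommendation request"  # e.g. what classify_intent returns

generation_config = GenerationConfig(
    # Sample for open-ended intents; greedy decoding otherwise
    do_sample=(highest_label in ("dialog continuation", "recommendation request")),
    temperature=0.7,  # assumed; only consulted when do_sample=True
    top_p=0.9,        # assumed; only consulted when do_sample=True
)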
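The extraction step depends on the model echoing the prompt at the start of the decoded text: the echoed prefix is stripped, then the remainder is split into candidate sentences exactly once (re.split only accepts strings, so the follow-up line must filter the already-split list rather than split again). A self-contained sketch with illustrative strings:

import re

formatted_prompt = "user: hello"
final_response = "user: hello Hi there! How are you? Nice day."

# Remove the echoed prompt, case-insensitively
ai_response = re.sub(re.escape(formatted_prompt), "", final_response, flags=re.IGNORECASE).strip()

# Split once on sentence-ending punctuation, then strip and drop empty pieces
ai_response = [s.strip() for s in re.split(r'(?<=\w[.!?]) +', ai_response) if s]
print(ai_response)  # ['Hi there!', 'How are you?', 'Nice day.']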
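The selection step can also be exercised on its own: encode the prompt and each candidate sentence, score them by cosine similarity, and keep the closest candidate. A minimal sketch, assuming bertmodel is a sentence-transformers SentenceTransformer (the model name and example strings are illustrative):

from sentence_transformers import SentenceTransformer, util

bertmodel = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model choice

formatted_prompt = "user: Recommend a sci-fi movie."
ai_response = ["Blade Runner is a classic.", "I had lunch.", "It may rain."]

# Encode the prompt and the candidate sentences
prompt_embedding = bertmodel.encode(formatted_prompt, convert_to_tensor=True)
candidate_embeddings = bertmodel.encode(ai_response, convert_to_tensor=True)

# Cosine similarity of the prompt against each candidate; row 0 holds the scores
similarities = util.pytorch_cos_sim(prompt_embedding, candidate_embeddings)[0]

# Keep the candidate closest to the prompt
best_response = ai_response[similarities.argmax()]
print(best_response)  # likely "Blade Runner is a classic."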