# ECIR_DEMO_FinPersona / components / chat_conversation.py
# TTsamurai
# update
# 31c7f82
from components.induce_personality import construct_big_five_words
from components.constant import (
ACCESS,
QUERY_REWRITING,
RAG,
PERSONALITY,
PERSONALITY_LIST,
REWRITE_PASSAGES,
NUM_PASSAGES,
DEVICE,
RESPONSE_GENERATOR,
TEMPLATE_PAYLOAD,
)
from components.prompt import SYSTEM_INSTRUCTION, RAG_INSTRUCTION, PERSONALITY_INSTRUCTION
import requests
import together
def generate_response_debugging(history):
    """Produce a deterministic stand-in reply for debugging.

    The reply is the ``"content"`` of every message in ``history`` joined by
    single spaces; no model or API is called.

    Args:
        history: List of message dicts, each with a ``"content"`` key.

    Returns:
        A ``(reply, new_history)`` tuple where ``new_history`` is a new list
        made of ``history`` followed by the assistant reply.
    """
    echoed = " ".join(turn["content"] for turn in history)
    assistant_turn = {"role": "assistant", "content": echoed}
    return echoed, history + [assistant_turn]
# REWRITER = "castorini/t5-base-canard"
def generate_response_together_api(history, max_tokens, client, model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"):
    """Request a non-streaming chat completion and append it to the history.

    Args:
        history: List of chat message dicts sent as ``messages``.
        max_tokens: Value forwarded as ``max_tokens``.
        client: Object exposing ``chat.completions.create(**kwargs)``.
        model: Model identifier forwarded as ``model``.

    Returns:
        A ``(reply, new_history)`` tuple. ``reply`` is the content of the
        first choice's message; ``new_history`` is a new list made of
        ``history`` followed by the assistant reply.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=history,
        stream=False,
        logprobs=False,
        stop=["<eos>", "<unk>", "<sep>", "<pad>", "<cls>", "<mask>"],
        max_tokens=max_tokens,
    )
    reply = completion.choices[0].message.content
    return reply, history + [{"role": "assistant", "content": reply}]
def make_local_api_call(payload, api_url, timeout=120):
    """POST a generation request to the local API and return the generated text.

    Args:
        payload: JSON-serialisable request body.
        api_url: URL of the generation endpoint.
        timeout: Seconds to wait for the server before giving up. Pass
            ``None`` to disable the timeout and wait indefinitely.

    Returns:
        The first element of the ``"text"`` entry in the JSON response (an
        empty string when the key is absent), or ``None`` when the request
        fails, the status code is not 200, the body is not valid JSON, or the
        ``"text"`` entry is empty. Failures are reported on stdout.
    """
    try:
        # Without a timeout a stalled server would block the caller forever.
        response = requests.post(api_url, json=payload, timeout=timeout)
    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        return None

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        print(response.text)
        return None

    try:
        result = response.json()
    except ValueError as e:
        # Covers both json.JSONDecodeError and requests' own JSONDecodeError.
        print(f"Request failed: {e}")
        return None

    if not isinstance(result, dict):
        print(f"Error: unexpected response body: {result!r}")
        return None

    texts = result.get("text", [""])
    if not texts:
        # An empty "text" list would otherwise raise IndexError below.
        print("Error: response contained no generated text")
        return None
    return texts[0]
def generate_response_local_api(history, terminator, max_tokens, api_url):
    """Generate a reply through the local API and append it to the history.

    The request body is a copy of ``TEMPLATE_PAYLOAD`` with ``"prompt"``,
    ``"max_tokens"`` and ``"stop_token_ids"`` overridden.

    Args:
        history: Conversation so far; sent as ``"prompt"``.
        terminator: Value sent as ``"stop_token_ids"``.
        max_tokens: Value sent as ``"max_tokens"``.
        api_url: Endpoint passed to ``make_local_api_call``.

    Returns:
        ``(reply, new_history)`` on success, where ``new_history`` is a new
        list ending with the assistant reply. When the API call yields a
        falsy result, ``("Generation failed", history)`` with the history
        unchanged.
    """
    request_body = TEMPLATE_PAYLOAD.copy()
    request_body["prompt"] = history
    request_body["max_tokens"] = max_tokens
    request_body["stop_token_ids"] = terminator

    reply = make_local_api_call(request_body, api_url)
    if not reply:
        print("Failed to generate a response.")
        # Hand back the untouched history so the caller can retry.
        return "Generation failed", history

    return reply, history + [{"role": "assistant", "content": reply}]
def conversation_window(history, N=100):
    """Trim the oldest two entries once the history exceeds ``N`` items.

    Args:
        history: Sequence of conversation entries.
        N: Length threshold; histories of at most ``N`` items are returned
            unchanged.

    Returns:
        ``history[2:]`` when ``len(history) > N``, otherwise ``history``
        itself.
    """
    return history[2:] if len(history) > N else history
def format_message_history(message, history):
    """Render past turns plus the new user message as one prompt string.

    Args:
        message: The new user message.
        history: Sequence of past turns, where index 0 of each turn is the
            user text and index 1 the assistant text. May be empty or
            ``None``.

    Returns:
        The past turns rendered as ``"\\n<user>:..."`` / ``"\\n<assistant>:..."``
        lines, followed by the new user message and an open ``<assistant>``
        tag.
    """
    prompt_tail = f"\n<user>: {message}\n<assistant>"
    if not history:
        return prompt_tail

    pieces = []
    for turn in history:
        pieces.append("\n<user>:" + turn[0])
        pieces.append("\n<assistant>:" + turn[1])
    return "".join(pieces) + prompt_tail
def format_user_message(message, history):
    """Return a new history list with ``message`` appended as a user turn.

    Args:
        message: Text of the user message.
        history: Existing list of message dicts; not modified.

    Returns:
        A new list: ``history`` followed by ``{"role": "user", "content": message}``.
    """
    user_turn = {"role": "user", "content": message}
    return history + [user_turn]
def format_context(message, history):
    """Return a new history list with ``message`` prepended as a system turn.

    Args:
        message: Text of the system message.
        history: Existing list of message dicts; not modified.

    Returns:
        A new list: ``{"role": "system", "content": message}`` followed by
        ``history``.
    """
    system_turn = {"role": "system", "content": message}
    return [system_turn] + history
def prepare_tokenizer(tokenizer):
    """Register missing special tokens and derive the stop-token ids.

    Each special token for which ``convert_tokens_to_ids`` returns ``None``
    is added through ``add_tokens``. If the tokenizer has no
    ``eos_token_id``, it is set to the id of ``"<eos>"``.

    Args:
        tokenizer: Object providing ``convert_tokens_to_ids``, ``add_tokens``
            and an ``eos_token_id`` attribute; modified in place.

    Returns:
        ``(tokenizer, terminators)`` where ``terminators`` is a one-element
        list holding the tokenizer's ``eos_token_id``.
    """
    for special in ("<eos>", "<unk>", "<sep>", "<pad>", "<cls>", "<mask>"):
        already_known = tokenizer.convert_tokens_to_ids(special) is not None
        if already_known:
            continue
        tokenizer.add_tokens([special])

    if tokenizer.eos_token_id is None:
        tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<eos>")

    return tokenizer, [tokenizer.eos_token_id]
def gradio_to_huggingface_message(gradio_message):
    """Flatten ``(user, bot)`` pairs into a list of role/content dicts.

    Args:
        gradio_message: Iterable of two-item ``(user, bot)`` pairs.

    Returns:
        A list with, for every pair in order, a ``"user"`` message dict
        followed by an ``"assistant"`` message dict.
    """
    return [
        {"role": speaker, "content": text}
        for user, bot in gradio_message
        for speaker, text in (("user", user), ("assistant", bot))
    ]
def huggingface_to_gradio_message(huggingface_message):
    """Group role/content dicts into lists that close on each assistant turn.

    User contents are accumulated until an assistant message arrives; the
    assistant content is then appended and the group is emitted. Messages
    with any other role are skipped, and user messages not followed by an
    assistant message are not emitted.

    Args:
        huggingface_message: Iterable of dicts with ``"role"`` and
            ``"content"`` keys.

    Returns:
        A list of lists, each ending with an assistant message's content.
    """
    exchanges = []
    pending = []
    for utter in huggingface_message:
        role = utter["role"]
        if role not in ("user", "assistant"):
            continue
        pending.append(utter["content"])
        if role == "assistant":
            # An assistant turn closes the current exchange.
            exchanges.append(pending)
            pending = []
    return exchanges
def get_personality_instruction(personality):
    """Fill the ``PERSONALITY_INSTRUCTION`` template with ``personality``.

    Args:
        personality: Value substituted into the template's first positional
            placeholder.

    Returns:
        The formatted instruction string.
    """
    instruction = PERSONALITY_INSTRUCTION.format(personality)
    return instruction
def get_system_instruction(rag=RAG, personality_list=None):
    """Assemble the system prompt from its optional parts.

    The prompt always starts with ``SYSTEM_INSTRUCTION``. ``RAG_INSTRUCTION``
    is appended when ``rag`` is truthy, then the personality instruction is
    appended when ``personality_list`` is truthy.

    Args:
        rag: Whether to include the RAG instruction; defaults to the
            module-level ``RAG`` constant.
        personality_list: Input for ``construct_big_five_words``; its result
            is passed to ``get_personality_instruction``. Falsy values skip
            the personality part.

    Returns:
        The concatenated system instruction string.
    """
    instruction = SYSTEM_INSTRUCTION
    if rag:
        instruction = instruction + RAG_INSTRUCTION
    if personality_list:
        persona_words = construct_big_five_words(personality_list)
        instruction = instruction + get_personality_instruction(persona_words)
    return instruction
def format_rag_context(rag_context):
    """Render retrieved passages as a text block for the prompt.

    Args:
        rag_context: Iterable of dicts with ``"passage_id"`` and
            ``"passage_text"`` keys, e.g.
            ``[{"passage_id": clue_web, "passage_text": "abc"}, ...]``.

    Returns:
        ``"Context: \\n"`` followed by one ``"Passage ID: ..., Text: ..."``
        entry per passage, each terminated by a blank line.
    """
    entries = (
        f"Passage ID: {info['passage_id']}, Text: {info['passage_text']}\n\n"
        for info in rag_context
    )
    return "Context: \n" + "".join(entries)