# InteractiveSurvey / src/demo/asg_clustername.py
# Page residue from the hosting site, kept as a comment so the file parses:
#   technicolor's picture
#   Add Django InteractiveSurvey project (commit a97d040)
#   raw | history | blame | 8.93 kB
import os
import pandas as pd
import re # Import the regular expressions module
from openai import OpenAI
import ast
def generate_cluster_name_qwen_sep(tsv_path, survey_title, cluster_num=3):
    """Ask the chat model for one section title per cluster, one request each.

    Reads a tab-separated file with ``label`` and ``retrieval_result``
    columns, groups the ``retrieval_result`` sentences by label
    ``0 .. cluster_num - 1`` and asks the model for a short title per group.
    The model is configured through the ``OPENAI_API_KEY``,
    ``OPENAI_API_BASE`` and ``MODEL`` environment variables.

    Args:
        tsv_path: Path to the TSV file with ``label`` / ``retrieval_result``.
        survey_title: Title of the survey, injected into the prompts.
        cluster_num: Number of clusters to name; labels are assumed to be the
            integers ``0 .. cluster_num - 1``. Defaults to 3 (the previously
            hard-coded value).

    Returns:
        A list of ``cluster_num`` strings: for each cluster, the content of the
        first ``[...]`` in the model reply with surrounding whitespace and
        quotes stripped, or ``"No Cluster Name Found"`` if the reply has no
        bracketed part.
    """
    data = pd.read_csv(tsv_path, sep='\t')
    # The system prompt is the same for every cluster.
    system_prompt = f'''You are a research assistant working on a survey paper. The survey paper is about "{survey_title}". \
'''
    # One client and one model name for all requests, not one per cluster.
    client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
    )
    model = os.environ.get("MODEL")

    result = []
    for label in range(cluster_num):
        # Sentences belonging to this cluster, in file order.
        sentence_list = data.loc[data['label'] == label, 'retrieval_result'].tolist()
        user_prompt = f'''
Given a list of descriptions of sentences about an aspect of the survey, you need to use one phrase (within 8 words) to summarize it and treat it as a section title of your survey paper. \
Your response should be a list with only one element and without any other information, for example, ["Post-training of LLMs"] \
Your response must contain one keyword of the survey title, unspecified or irrelevant results are not allowed. \
The description list is:{sentence_list}'''
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        chat_response = client.chat.completions.create(
            model=model,
            max_tokens=768,
            temperature=0.5,
            stop="<|im_end|>",
            stream=True,
            messages=messages,
        )
        # Collect the streamed deltas. A chunk may carry an empty ``choices``
        # list, depending on the provider, so guard before indexing.
        parts = []
        for chunk in chat_response:
            if chunk.choices and chunk.choices[0].delta.content:
                parts.append(chunk.choices[0].delta.content)
        text = "".join(parts)

        # First [...] in the reply; DOTALL lets it span line breaks.
        match = re.search(r'\[(.*?)\]', text, re.DOTALL)
        if match:
            # Strip whitespace, then any quotes around the name.
            cluster_name = match.group(1).strip().strip('"').strip("'")
            result.append(cluster_name)
        else:
            result.append("No Cluster Name Found")
    return result
# Example usage:
# result = generate_cluster_name_qwen_sep('path_to_your_file.tsv', 'Your Survey Title')
# print(result) # Output might look like ["Cluster One", "Cluster Two", "Cluster Three"]
def refine_cluster_name(cluster_names, survey_title):
    """Ask the chat model to make a set of section titles coherent.

    The model is configured through the ``OPENAI_API_KEY``,
    ``OPENAI_API_BASE`` and ``MODEL`` environment variables.

    Args:
        cluster_names: The section titles to refine, normally a list of
            strings. It is embedded in the prompt via ``str()``.
        survey_title: Title of the survey, injected into both prompts.

    Returns:
        A list of title strings parsed from the first ``[...]`` list literal
        in the model reply. If no non-empty list can be parsed, three generic
        titles are returned instead: ``"<survey_title>: Definition"``,
        ``"<survey_title>: Methods"`` and ``"<survey_title>: Evaluation"``.
        If the chat request fails, a message is printed and one
        ``"Refinement Error"`` entry per input title is returned.
    """
    # The error fallback has one placeholder per input title, so count the
    # titles themselves, not the characters of their string form.
    if isinstance(cluster_names, (list, tuple)):
        error_count = len(cluster_names)
    else:
        error_count = 1
    names_text = str(cluster_names)

    system_prompt = f'''You are a research assistant tasked with optimizing and refining a set of section titles for a survey paper. The survey paper is about "{survey_title}".
'''
    user_prompt = f'''
Here is a set of section titles generated for the survey topic "{survey_title}":
{names_text}
Please ensure that all cluster names are coherent and consistent with each other, and that each name is clear, concise, and accurately reflects the corresponding section.
Notice to remove the overlapping information between the cluster names.
Each cluster name should be within 8 words and include a keyword from the survey title.
Response with a list of section titles in the following format without any other irrelevant information,
For example, ["Refined Title 1", "Refined Title 2", "Refined Title 3"]
'''
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
    )
    try:
        chat_response = client.chat.completions.create(
            model=os.environ.get("MODEL"),
            max_tokens=256,
            temperature=0.5,
            stop="<|im_end|>",
            stream=True,
            messages=messages,
        )
        # Collect the streamed deltas. A chunk may carry an empty ``choices``
        # list, depending on the provider, so guard before indexing.
        parts = []
        for chunk in chat_response:
            if chunk.choices and chunk.choices[0].delta.content:
                parts.append(chunk.choices[0].delta.content)
        text = "".join(parts)
    except Exception as e:
        # Best-effort: a failed request degrades to placeholder titles.
        print(f"An error occurred while refining cluster names: {e}")
        return ["Refinement Error"] * error_count

    # Evaluate the whole bracketed span, brackets included, so literal_eval
    # returns a list. Evaluating only the inner text would give a tuple, or a
    # bare string for a single title. literal_eval only accepts literals, so
    # the model output is never executed.
    match = re.search(r'\[.*?\]', text, re.DOTALL)
    if match:
        try:
            parsed = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError, TypeError):
            parsed = None
        if isinstance(parsed, list) and parsed:
            return [str(item) for item in parsed]

    # No usable list in the reply: fall back to generic section titles.
    return [
        survey_title + ": Definition",
        survey_title + ": Methods",
        survey_title + ": Evaluation",
    ]
def generate_cluster_name_new(tsv_path, survey_title, cluster_num=3):
    """Name all clusters with a single chat-model request.

    Reads a tab-separated file with ``label`` and ``retrieval_result``
    columns, groups the ``retrieval_result`` sentences by label
    ``0 .. cluster_num - 1`` and asks the model for one distinct name per
    cluster. The model is configured through the ``OPENAI_API_KEY``,
    ``OPENAI_API_BASE`` and ``MODEL`` environment variables.

    Args:
        tsv_path: Path to the TSV file with ``label`` / ``retrieval_result``.
        survey_title: Title of the survey, injected into the prompts.
        cluster_num: Number of clusters to name; labels are assumed to be the
            integers ``0 .. cluster_num - 1``.

    Returns:
        A list of exactly ``cluster_num`` title strings. Names are taken from
        the first ``[...]`` list literal in the model reply. Extra names are
        dropped, and missing ones are filled with fallback titles of the form
        ``"<survey_title>: <Section>"``, where ``<Section>`` comes from a fixed
        list of generic section names. Once that list is exhausted, the form
        ``"<survey_title>: Section k"`` is used.
    """
    data = pd.read_csv(tsv_path, sep='\t')
    # Sentences of each cluster, in file order.
    desp = [
        data.loc[data['label'] == i, 'retrieval_result'].tolist()
        for i in range(cluster_num)
    ]
    system_prompt = f'''
You are a research assistant working on a survey paper. The survey paper is about "{survey_title}". '''
    cluster_info = "\n".join(
        f'Cluster {i + 1}: "{sentences}"' for i, sentences in enumerate(desp)
    )
    user_prompt = f'''
Your task is to generate {cluster_num} distinctive cluster names (e.g., "Pre-training of LLMs") of the given clusters of reference papers, each reference paper is described by a sentence.
The clusters of reference papers are:
{cluster_info}
Your output should be a single list of {cluster_num} cluster names, e.g., ["Pre-training of LLMs", "Fine-tuning of LLMs", "Evaluation of LLMs"]
Do not output any other text or information.
'''
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    client = OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
    )
    chat_response = client.chat.completions.create(
        model=os.environ.get("MODEL"),
        max_tokens=768,
        temperature=0.5,
        stop="<|im_end|>",
        stream=True,
        messages=messages,
    )
    # Collect the streamed deltas. A chunk may carry an empty ``choices``
    # list, depending on the provider, so guard before indexing.
    parts = []
    for chunk in chat_response:
        if chunk.choices and chunk.choices[0].delta.content:
            parts.append(chunk.choices[0].delta.content)
    text = "".join(parts)

    # Fallback titles: the first cluster_num generic sections, continuing
    # with numbered sections when cluster_num exceeds the predefined list.
    predefined_sections = [
        "Definition", "Methods", "Evaluation", "Applications",
        "Challenges", "Future Directions", "Comparisons", "Case Studies"
    ]
    fallback = [
        f"{survey_title}: {predefined_sections[i]}"
        if i < len(predefined_sections)
        else f"{survey_title}: Section {i + 1}"
        for i in range(cluster_num)
    ]

    # Evaluate the whole bracketed span, brackets included, so literal_eval
    # returns a list rather than a tuple or bare string. literal_eval only
    # accepts literals, so the model output is never executed.
    names = []
    match = re.search(r'\[.*?\]', text, re.DOTALL)
    if match:
        try:
            parsed = ast.literal_eval(match.group(0))
        except (ValueError, SyntaxError, TypeError):
            parsed = None
        if isinstance(parsed, list):
            names = [str(item) for item in parsed[:cluster_num]]

    # Pad with fallback titles so exactly cluster_num names are returned.
    names.extend(fallback[len(names):])
    return names
if __name__ == "__main__":
    # Manual smoke run. refine_cluster_name reads OPENAI_API_KEY,
    # OPENAI_API_BASE and MODEL from the environment.
    refined_result = refine_cluster_name(
        ["Pre-training of LLMs", "Fine-tuning of LLMs", "Evaluation of LLMs"],
        'Survey of LLMs',
    )
    # Show the result; otherwise the demo makes an API call and discards it.
    print(refined_result)