Spaces:
Sleeping
Sleeping
import os | |
import pandas as pd | |
import re # Import the regular expressions module | |
from openai import OpenAI | |
import ast | |
def generate_cluster_name_qwen_sep(tsv_path, survey_title): | |
data = pd.read_csv(tsv_path, sep='\t') | |
# Define the system prompt once, outside the loop | |
system_prompt = f'''You are a research assistant working on a survey paper. The survey paper is about "{survey_title}". \ | |
''' | |
result = [] # Initialize the result list | |
for i in range(3): # Assuming labels are 0, 1, 2 | |
sentence_list = [] # Reset sentence_list for each label | |
for j in range(len(data)): | |
if data['label'][j] == i: | |
sentence_list.append(data['retrieval_result'][j]) | |
# Convert the sentence list to a string representation | |
user_prompt = f''' | |
Given a list of descriptions of sentences about an aspect of the survey, you need to use one phrase (within 8 words) to summarize it and treat it as a section title of your survey paper. \ | |
Your response should be a list with only one element and without any other information, for example, ["Post-training of LLMs"] \ | |
Your response must contain one keyword of the survey title, unspecified or irrelevant results are not allowed. \ | |
The description list is:{sentence_list}''' | |
messages = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
] | |
openai_api_key = os.getenv("OPENAI_API_KEY") | |
openai_api_base = os.getenv("OPENAI_API_BASE") | |
client = OpenAI( | |
api_key=openai_api_key, | |
base_url=openai_api_base, | |
) | |
chat_response = client.chat.completions.create( | |
model=os.environ.get("MODEL"), | |
max_tokens=768, | |
temperature=0.5, | |
stop="<|im_end|>", | |
stream=True, | |
messages=messages | |
) | |
# Stream the response to a single text string | |
text = "" | |
for chunk in chat_response: | |
if chunk.choices[0].delta.content: | |
text += chunk.choices[0].delta.content | |
# Use regex to extract the first content within [] | |
match = re.search(r'\[(.*?)\]', text) | |
if match: | |
cluster_name = match.group(1).strip() # Extract and clean the cluster name | |
# 去除集群名称两侧的引号(如果存在) | |
cluster_name = cluster_name.strip('"').strip("'") | |
result.append(cluster_name) | |
else: | |
result.append("No Cluster Name Found") # Handle cases where pattern isn't found | |
# print("The generated cluster names are:") | |
# print(result) | |
return result # This will be a list with three elements | |
# Example usage: | |
# result = generate_cluster_name_qwen_sep('path_to_your_file.tsv', 'Your Survey Title') | |
# print(result) # Output might look like ["Cluster One", "Cluster Two", "Cluster Three"] | |
def refine_cluster_name(cluster_names, survey_title): | |
cluster_names = str(cluster_names) # Convert to string to handle list input | |
# Define the system prompt to set the context | |
system_prompt = f'''You are a research assistant tasked with optimizing and refining a set of section titles for a survey paper. The survey paper is about "{survey_title}". | |
''' | |
# Construct the user prompt, including all cluster names | |
user_prompt = f''' | |
Here is a set of section titles generated for the survey topic "{survey_title}": | |
{cluster_names} | |
Please ensure that all cluster names are coherent and consistent with each other, and that each name is clear, concise, and accurately reflects the corresponding section. | |
Notice to remove the overlapping information between the cluster names. | |
Each cluster name should be within 8 words and include a keyword from the survey title. | |
Response with a list of section titles in the following format without any other irrelevant information, | |
For example, ["Refined Title 1", "Refined Title 2", "Refined Title 3"] | |
''' | |
messages = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
] | |
# Initialize OpenAI client | |
openai_api_key = os.getenv("OPENAI_API_KEY") | |
openai_api_base = os.getenv("OPENAI_API_BASE") | |
client = OpenAI( | |
api_key=openai_api_key, | |
base_url=openai_api_base, | |
) | |
try: | |
chat_response = client.chat.completions.create( | |
model=os.environ.get("MODEL"), | |
max_tokens=256, | |
temperature=0.5, | |
stop="<|im_end|>", | |
stream=True, | |
messages=messages | |
) | |
# Stream the response and concatenate into a complete text | |
text = "" | |
for chunk in chat_response: | |
if chunk.choices[0].delta.content: | |
text += chunk.choices[0].delta.content | |
# print("The raw response text is:") | |
# print(text) | |
# Use regex to extract content within square brackets | |
match = re.search(r'\[(.*?)\]', text) | |
if match: | |
refined_cluster_names = match.group(1).strip() # Extract and clean the cluster name | |
else: | |
refined_cluster_names = [ | |
survey_title + ": Definition", | |
survey_title + ": Methods", | |
survey_title + ": Evaluation" | |
] # Handle cases where pattern isn't found | |
except Exception as e: | |
print(f"An error occurred while refining cluster names: {e}") | |
refined_cluster_names = ["Refinement Error"] * len(cluster_names) | |
refined_cluster_names = ast.literal_eval(refined_cluster_names) # Convert string to list | |
# print("The refined cluster names are:") | |
# print(refined_cluster_names) | |
return refined_cluster_names # Returns a list with the refined cluster names、 | |
def generate_cluster_name_new(tsv_path, survey_title, cluster_num = 3): | |
data = pd.read_csv(tsv_path, sep='\t') | |
desp=[] | |
for i in range(cluster_num): # Assuming labels are 0, 1, 2 | |
sentence_list = [] # Initialize the sentence list | |
for j in range(len(data)): | |
if data['label'][j] == i: | |
sentence_list.append(data['retrieval_result'][j]) | |
desp.append(sentence_list) | |
system_prompt = f''' | |
You are a research assistant working on a survey paper. The survey paper is about "{survey_title}". ''' | |
cluster_info = "\n".join([f'Cluster {i+1}: "{desp[i]}"' for i in range(cluster_num)]) | |
user_prompt = f''' | |
Your task is to generate {cluster_num} distinctive cluster names (e.g., "Pre-training of LLMs") of the given clusters of reference papers, each reference paper is described by a sentence. | |
The clusters of reference papers are: | |
{cluster_info} | |
Your output should be a single list of {cluster_num} cluster names, e.g., ["Pre-training of LLMs", "Fine-tuning of LLMs", "Evaluation of LLMs"] | |
Do not output any other text or information. | |
''' | |
messages = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_prompt}, | |
] | |
openai_api_key = os.getenv("OPENAI_API_KEY") | |
openai_api_base = os.getenv("OPENAI_API_BASE") | |
client = OpenAI( | |
api_key=openai_api_key, | |
base_url=openai_api_base, | |
) | |
chat_response = client.chat.completions.create( | |
model=os.environ.get("MODEL"), | |
max_tokens=768, | |
temperature=0.5, | |
stop="<|im_end|>", | |
stream=True, | |
messages=messages | |
) | |
# Stream the response to a single text string | |
text = "" | |
for chunk in chat_response: | |
if chunk.choices[0].delta.content: | |
text += chunk.choices[0].delta.content | |
# print("The raw response text is:") | |
# print(text) | |
# Use regex to extract content within square brackets | |
match = re.search(r'\[(.*?)\]', text) | |
if match: | |
refined_cluster_names = match.group(1).strip() # Extract and clean the cluster name | |
else: | |
predefined_sections = [ | |
"Definition", "Methods", "Evaluation", "Applications", | |
"Challenges", "Future Directions", "Comparisons", "Case Studies" | |
] | |
# 根据 cluster_num 选择前 cluster_num 个预定义类别 | |
refined_cluster_names = [ | |
f"{survey_title}: {predefined_sections[i]}" for i in range(cluster_num) | |
] | |
refined_cluster_names = ast.literal_eval(refined_cluster_names) # Convert string to list | |
# print("The refined cluster names are:") | |
# print(refined_cluster_names) | |
return refined_cluster_names # Returns a list with the refined cluster names、 | |
if __name__ == "__main__": | |
refined_result = refine_cluster_name(["Pre-training of LLMs", "Fine-tuning of LLMs", "Evaluation of LLMs"], 'Survey of LLMs') | |
# print(refined_result) | |