"""Survey-generation pipeline: parse PDFs, cluster them, and write a survey.

The ``ASG_system`` class drives the stages in order (see the ``__main__``
block at the bottom of the file); the module-level helpers post-process the
citation markers of the generated Markdown.
"""

import csv
import json
import os
import re
import time
from typing import Any

import pandas as pd
import torch
import transformers
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from markdown_pdf import MarkdownPdf, Section  # Assuming you are using markdown_pdf

from asg_clustername import generate_cluster_name_new
from asg_generator import generate, generate_sentence_patterns
from asg_loader import DocumentLoading
from asg_outline import OutlineGenerator, generateSurvey_qwen_new
from asg_retriever import Retriever, legal_pdf, query_embeddings_new_new
from category_and_tsne import clustering

from .path_utils import get_path, setup_hf_cache

# Set up the Hugging Face cache directory.
cache_dir = setup_hf_cache()

# Literal substitutions applied by ``clean_str``, in this exact order.
# Entries are repeated on purpose: a second pass shrinks runs that the first
# pass only halved (e.g. "...." -> ".." -> ".").
_CLEAN_STR_REPLACEMENTS = (
    ('\\n', ' '),
    ('\n', ' '),
    ('\r', ' '),
    ('——', ' '),
    ('——', ' '),
    ('__', ' '),
    ('__', ' '),
    ('........', '.'),
    ('....', '.'),
    ('....', '.'),
    ('..', '.'),
    ('..', '.'),
    ('..', '.'),
    ('. . . . . . . . ', '. '),
    ('. . . . ', '. '),
    ('. . . . ', '. '),
    ('. . ', '. '),
    ('. . ', '. '),
)

# ``[name\]`` citation marker, optionally preceded by the backslash that
# escapes the opening bracket in Markdown (``\[name\]``).  The name may not
# contain brackets, so a match can never start at an unrelated earlier "[".
_ESCAPED_CITATION_RE = re.compile(r"\\?\[([^\[\]]*?)\\\]")

# Any bracketed marker on a single line, e.g. ``[name\]`` or ``[3]``.
_ANY_CITATION_RE = re.compile(r'\[.*?\]')


def clean_str(input_str):
    """Return a lower-cased, whitespace-normalised copy of ``input_str``.

    The value is converted with ``str()``, stripped and lower-cased.  The
    placeholders ``"none"`` / ``"nan"`` and the empty string all yield ``""``.
    Otherwise newlines, long dashes, double underscores and dot leaders are
    reduced, literal ``\\uXXXX`` escape residue is blanked out, and runs of
    spaces are collapsed to a single space.

    :param input_str: Any value; it is stringified first.
    :return: The cleaned text.
    """
    text = str(input_str).strip().lower()
    if text in ("", "none", "nan"):
        return ""
    for old, new in _CLEAN_STR_REPLACEMENTS:
        text = text.replace(old, new)
    # The text is already lower-cased, so [0-9a-z] covers the hex digits.
    text = re.sub(r'\\u[0-9a-z]{4}', ' ', text)
    # The substitutions above can leave several spaces in a row.
    return re.sub(r' {2,}', ' ', text)


def remove_invalid_citations(text, valid_collection_names):
    r"""Delete ``[name\]`` markers whose name is not a known collection.

    Only markers that end in ``\]`` are considered.  A marker is kept
    untouched when its name (with trailing backslashes stripped) is in
    ``valid_collection_names``; otherwise it is removed, together with a
    backslash directly in front of the opening bracket so that removing
    ``\[name\]`` does not leave a stray ``\`` behind.

    :param text: The generated paper text.
    :param valid_collection_names: Iterable of accepted collection names.
    :return: ``text`` with the unknown markers removed.
    """
    valid_names = set(valid_collection_names)

    def _keep_or_drop(match):
        name = match.group(1).rstrip('\\')  # drop the trailing "\"
        return match.group(0) if name in valid_names else ""

    return _ESCAPED_CITATION_RE.sub(_keep_or_drop, text)


def normalize_citations_with_mapping(paper_text):
    r"""Renumber every bracketed marker as ``[1]``, ``[2]``, ... .

    Markers are numbered in order of first appearance; repeated markers share
    a number.  The replacement is done in a single pass so that a marker that
    already looks like ``[1]`` cannot be confused with a freshly assigned
    number.

    :param paper_text: Text containing markers such as ``[name\]``.
    :return: ``(normalized_text, reverse_mapping)`` where ``reverse_mapping``
        maps each number to the original marker content with the surrounding
        brackets and trailing backslashes removed.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_markers = list(dict.fromkeys(_ANY_CITATION_RE.findall(paper_text)))
    marker_to_number = {
        marker: f'[{number}]'
        for number, marker in enumerate(unique_markers, start=1)
    }
    normalized_text = _ANY_CITATION_RE.sub(
        lambda match: marker_to_number[match.group(0)], paper_text
    )
    reverse_mapping = {
        number: marker.strip('[]').rstrip('\\')
        for number, marker in enumerate(unique_markers, start=1)
    }
    return normalized_text, reverse_mapping


def generate_references_section(citation_mapping, collection_pdf_mapping):
    """Build the Markdown "References" section.

    :param citation_mapping: ``{number: collection_name}`` as returned by
        ``normalize_citations_with_mapping``.
    :param collection_pdf_mapping: ``{collection_name: pdf_file_name}``.
        Unknown collections are listed as ``"Unknown PDF"``.
    :return: The section as one string, one reference per line.
    """
    references = ["# References"]
    for num in sorted(citation_mapping):
        pdf_name = collection_pdf_mapping.get(citation_mapping[num], "Unknown PDF")
        if pdf_name.endswith(".pdf"):
            pdf_name = pdf_name[:-4]
        # Two trailing spaces make Markdown render a hard line break.
        references.append(f"[{num}] {pdf_name}  ")
    return "\n".join(references)


def fix_citation_punctuation_md(text):
    r"""Move a sentence-ending period behind the citation that follows it.

    Turns ``'some text. \[1]'`` into ``'some text \[1].'``.  Only numeric,
    Markdown-escaped citations (``\[1]``, ``\[2]``, ...) are handled, so the
    text must already have gone through ``normalize_citations_with_mapping``.

    :param text: Markdown text with numeric citations.
    :return: The text with the period placed after the citation.
    """
    # A period, optional whitespace, then an escaped numeric citation.
    return re.sub(r'\.\s*(\\\[\d+\])', r' \1.', text)


def finalize_survey_paper(paper_text, Global_collection_names, Global_file_names):
    r"""Clean up the citations of a generated paper and append its references.

    :param paper_text: The generated Markdown.
    :param Global_collection_names: Collection names that may be cited.
    :param Global_file_names: File names, parallel to the collection names.
    :return: The final Markdown: body, blank line, "References" section.
    """
    # 1) Drop `[name\]` markers that do not refer to a known collection.
    paper_text = remove_invalid_citations(paper_text, Global_collection_names)
    # 2) Normalise the remaining markers to [1][2]...
    normalized_text, citation_mapping = normalize_citations_with_mapping(paper_text)
    # 3) Fix punctuation, e.g. ". \[1]" => " \[1]."
    normalized_text = fix_citation_punctuation_md(normalized_text)
    # 4) Build the {collection_name: pdf_file_name} dictionary.
    collection_pdf_mapping = dict(zip(Global_collection_names, Global_file_names))
    # 5) Generate the References section.
    references_section = generate_references_section(citation_mapping, collection_pdf_mapping)
    # 6) Join body and References.
    return normalized_text.strip() + "\n\n" + references_section


class ASG_system:
    """Runs the survey-generation stages for one survey.

    The stages are meant to be called in this order: ``parsing_pdfs``,
    ``description_generation``, ``agglomerative_clustering``,
    ``outline_generation``, ``section_generation``, ``citation_generation``.
    Intermediate results are kept on the instance and written below
    ``root_path`` (``txt``, ``tsv``, ``md``, ``info``, ``result``).
    """

    def __init__(self, root_path: str, survey_id: str, pdf_path: str,
                 survey_title: str, cluster_standard: str) -> None:
        """Load the models and create the working directories.

        :param root_path: Directory under which all outputs are written.
        :param survey_id: Identifier used for sub-directories and file names.
        :param pdf_path: Directory containing the input PDF files.
        :param survey_title: Title of the survey to generate.
        :param cluster_standard: Query used for description generation.
        """
        load_dotenv()
        self.pdf_path = pdf_path
        self.txt_path = root_path + "/txt"
        self.tsv_path = root_path + "/tsv"
        self.md_path = root_path + "/md"
        self.info_path = root_path + "/info"
        self.result_path = root_path + "/result"

        self.survey_id = survey_id
        self.survey_title = survey_title
        self.cluster_standard = cluster_standard

        self.collection_names = []
        self.file_names = []
        self.citation_data = []
        self.description_list = []
        self.ref_list = []
        self.cluster_names = []
        self.collection_names_clustered = []
        # Placeholder; replaced by a DataFrame in agglomerative_clustering().
        self.df_selected = ''

        model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
        embedder_name = "sentence-transformers/all-MiniLM-L6-v2"
        try:
            self.embedder = HuggingFaceEmbeddings(model_name=embedder_name, cache_folder=cache_dir)
        except Exception as e:
            # Best effort: retry without the custom cache folder.
            print(f"Error initializing embedder: {e}")
            self.embedder = HuggingFaceEmbeddings(model_name=embedder_name)

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            token=os.getenv('HF_API_KEY'),
            device_map="auto",
        )
        adapters = (
            ("technicolor/llama3.1_8b_outline_generation", "outline"),
            ("technicolor/llama3.1_8b_abstract_generation", "abstract"),
            ("technicolor/llama3.1_8b_conclusion_generation", "conclusion"),
        )
        for peft_model_id, adapter_name in adapters:
            self.pipeline.model.load_adapter(peft_model_id=peft_model_id, adapter_name=adapter_name)

        # makedirs also creates the parent directory of each survey folder.
        for base_dir in (self.txt_path, self.md_path, self.info_path, self.result_path):
            os.makedirs(f'{base_dir}/{self.survey_id}', exist_ok=True)
        os.makedirs(self.tsv_path, exist_ok=True)

    def parsing_pdfs(self, mode="intro") -> None:
        """Convert, embed and index every file in ``pdf_path``; write the TSV.

        :param mode: ``"intro"`` uses ``process_md_file``; ``"full"`` uses
            ``process_md_file_full``.
        :raises ValueError: If ``mode`` is neither ``"intro"`` nor ``"full"``.
        :raises FileNotFoundError: If a converted Markdown file is missing.
        """
        if mode not in ("intro", "full"):
            # Fail before any conversion work instead of hitting an unbound
            # variable after the first PDF has been converted.
            raise ValueError(f"Unsupported mode {mode!r}; expected 'intro' or 'full'.")

        loader = DocumentLoading()
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=30,
            length_function=len,
            is_separator_regex=False,
        )
        for pdf_name in os.listdir(self.pdf_path):
            self._index_pdf(loader, text_splitter, os.path.join(self.pdf_path, pdf_name), mode)

        print(self.collection_names)
        print(self.file_names)
        self._write_reference_tsv()

    def _index_pdf(self, loader, text_splitter, pdf_file, mode) -> None:
        """Convert one PDF to Markdown, chunk and embed it, and store it."""
        split_start_time = time.time()
        base_name = os.path.splitext(os.path.basename(pdf_file))[0]
        md_dir = os.path.join(self.md_path, self.survey_id)
        md_file_path = os.path.join(md_dir, base_name, "auto", f"{base_name}.md")
        loader.convert_pdf_to_md(pdf_file, md_dir)
        print(md_file_path)
        print("*"*24)
        if not os.path.exists(md_file_path):
            raise FileNotFoundError(f"Markdown file {md_file_path} does not exist. Conversion might have failed.")

        if mode == "intro":
            doc = loader.process_md_file(md_file_path, self.survey_id, self.txt_path)
        else:
            doc = loader.process_md_file_full(md_file_path, self.survey_id, self.txt_path)

        chunks = text_splitter.create_documents([doc])
        documents_list = [chunk.page_content.replace('\n', ' ') for chunk in chunks]
        print(f"Splitting took {time.time() - split_start_time} seconds.")

        embed_start_time = time.time()
        doc_results = self.embedder.embed_documents(documents_list)
        if isinstance(doc_results, torch.Tensor):
            embeddings_list = doc_results.tolist()
        else:
            embeddings_list = doc_results
        print(f"Embedding took {time.time() - embed_start_time} seconds.")

        # One metadata record per chunk, all pointing at the source PDF.
        doc_name = os.path.basename(pdf_file)
        metadata_list = [{"doc_name": doc_name} for _ in documents_list]

        title_new = base_name.strip()
        invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*', '_']
        for char in invalid_chars:
            title_new = title_new.replace(char, ' ')
        print("============================")
        print(title_new)

        collection_name = legal_pdf(title_new)
        retriever = Retriever()
        retriever.list_collections_chroma()
        retriever.create_collection_chroma(collection_name)
        retriever.add_documents_chroma(
            collection_name=collection_name,
            embeddings_list=embeddings_list,
            documents_list=documents_list,
            metadata_list=metadata_list,
        )
        self.collection_names.append(collection_name)
        self.file_names.append(title_new)

    def _write_reference_tsv(self) -> None:
        """Collect the per-paper JSON files into ``<tsv_path>/<survey_id>.tsv``."""
        col_title = "reference paper title"
        col_citation = "reference paper citation information (can be collected from Google scholar/DBLP)"
        col_abstract = "reference paper abstract (Please copy the text AND paste here)"
        col_intro = "reference paper introduction (Please copy the text AND paste here)"
        col_doi = "reference paper doi link (optional)"
        col_label = "reference paper category label (optional)"

        txt_dir = os.path.join(self.txt_path, self.survey_id)
        json_files = os.listdir(txt_dir)
        ref_paper_num = len(json_files)
        print(f'The length of the json files is {ref_paper_num}')
        if ref_paper_num == 0:
            return

        rows = []
        for json_name in json_files:
            with open(os.path.join(txt_dir, json_name), 'r', encoding="utf-8") as file:
                data = json.load(file)
            rows.append({
                col_title: data.get("title", ""),
                col_citation: data.get("authors", ""),
                col_abstract: data.get("abstract", ""),
                col_intro: data.get("introduction", ""),
                col_doi: "",
                col_label: "",
            })
        # Build the frame once instead of concatenating row by row.
        input_pd = pd.DataFrame(rows)

        # NOTE(review): this pairs file_names (PDF listing order) with the
        # JSON files (txt directory listing order) by position — confirm that
        # both listings always line up.
        input_pd['ref_title'] = list(self.file_names)
        input_pd["ref_context"] = [""]*ref_paper_num
        input_pd["ref_entry"] = input_pd[col_citation]
        input_pd["abstract"] = input_pd[col_abstract].apply(
            lambda x: clean_str(x) if len(str(x)) > 0 else 'Invalid abstract')
        input_pd["intro"] = input_pd[col_intro].apply(
            lambda x: clean_str(x) if len(str(x)) > 0 else 'Invalid introduction')
        # Optional column.
        input_pd["label"] = input_pd[col_label].apply(str)

        output_tsv_filename = os.path.join(self.tsv_path, self.survey_id + '.tsv')
        output_df = input_pd[["ref_title", "ref_context", "ref_entry", "abstract", "intro", 'label']]
        output_df.to_csv(output_tsv_filename, sep='\t')

    def description_generation(self) -> None:
        """Generate one description per collection and add it to the TSV.

        Writes the accumulated citation data to ``citation_data.json`` and
        appends a ``retrieval_result`` column to the survey TSV.
        """
        query = self.cluster_standard
        query_list = generate_sentence_patterns(query)
        for name in self.collection_names:
            context, citation_data = query_embeddings_new_new(name, query_list)
            self.citation_data.extend(citation_data)
            self.description_list.append(generate(context, query, name))

        citation_path = f'{self.info_path}/{self.survey_id}/citation_data.json'
        os.makedirs(os.path.dirname(citation_path), exist_ok=True)
        with open(citation_path, 'w', encoding="utf-8") as outfile:
            json.dump(self.citation_data, outfile, indent=4, ensure_ascii=False)

        file_path = f'{self.tsv_path}/{self.survey_id}.tsv'
        with open(file_path, 'r', newline='', encoding='utf-8') as infile:
            rows = list(csv.reader(infile, delimiter='\t'))
        if not rows:
            print('Input file is empty.')
            return

        headers = rows[0]
        headers.append('retrieval_result')
        updated_rows = [headers]
        # zip stops at the shorter input: data rows without a description
        # are not written back.
        for row, description in zip(rows[1:], self.description_list):
            row.append(description)
            updated_rows.append(row)
        with open(file_path, 'w', newline='', encoding='utf-8') as outfile:
            csv.writer(outfile, delimiter='\t').writerows(updated_rows)
        print('Updated file has been saved to', file_path)

    def agglomerative_clustering(self) -> None:
        """Cluster the papers, name the clusters and save ``cluster_info.json``."""
        tsv_file = f'{self.tsv_path}/{self.survey_id}.tsv'
        df = pd.read_csv(tsv_file, sep='\t', index_col=0, encoding='utf-8')
        df_selected, _ = clustering(df, 3, self.survey_id, self.info_path, self.tsv_path)
        self.df_selected = df_selected

        # One list of paper titles per cluster label.
        df_tmp = df_selected.reset_index()
        ref_titles = list(df_tmp.groupby('label')['ref_title'].apply(list))

        category_label_summarized = generate_cluster_name_new(tsv_file, self.survey_title)
        self.cluster_names = category_label_summarized

        cluster_info = {
            category_label_summarized[i]: ref_titles[i]
            for i in range(len(category_label_summarized))
        }
        for cluster_name, titles in cluster_info.items():
            collections = [legal_pdf(title) for title in titles]
            cluster_info[cluster_name] = collections
            self.collection_names_clustered.append(collections)

        cluster_info_path = f'{self.info_path}/{self.survey_id}/cluster_info.json'
        with open(cluster_info_path, 'w', encoding="utf-8") as outfile:
            json.dump(cluster_info, outfile, indent=4, ensure_ascii=False)

    def outline_generation(self) -> None:
        """Generate the survey outline and save it to ``outline.json``."""
        print(self.df_selected)
        print(self.cluster_names)
        outline_generator = OutlineGenerator(self.pipeline, self.df_selected, self.cluster_names)
        outline_generator.get_cluster_info()
        messages, outline = outline_generator.generate_outline_qwen(self.survey_title)
        outline_json = {'messages': messages, 'outline': outline}

        output_path = f'{self.info_path}/{self.survey_id}/outline.json'
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w', encoding="utf-8") as outfile:
            json.dump(outline_json, outfile, indent=4, ensure_ascii=False)

    def section_generation(self) -> None:
        """Generate the survey sections from the clustered collections."""
        # Use the instance's directories so a root_path other than "." is
        # honoured; with root_path "." these equal './txt' and './info'.
        generateSurvey_qwen_new(
            self.survey_id,
            self.survey_title,
            self.collection_names_clustered,
            self.pipeline,
            self.citation_data,
            self.txt_path,
            self.info_path,
        )

    def citation_generation(self) -> None:
        """
        Generate citation Markdown and PDF files from JSON and store them in
        the specified result path.

        :raises ValueError: If the generated Markdown is empty.
        :raises RuntimeError: If the Markdown or the PDF file cannot be written.
        """
        json_filepath = os.path.join(self.info_path, self.survey_id, "generated_result.json")
        markdown_dir = f'{self.result_path}/{self.survey_id}'
        markdown_filepath = os.path.join(markdown_dir, f'survey_{self.survey_id}.md')
        pdf_filepath = os.path.join(markdown_dir, f'survey_{self.survey_id}.pdf')

        markdown_content = self.get_markdown_content(json_filepath)
        if not markdown_content:
            raise ValueError("Markdown content is empty. Cannot generate citation files.")

        try:
            with open(markdown_filepath, 'w', encoding='utf-8') as markdown_file:
                markdown_file.write(markdown_content)
            print(f"Markdown content saved to: {markdown_filepath}")
        except Exception as e:
            raise RuntimeError(f"Failed to save Markdown file: {e}") from e

        try:
            pdf = MarkdownPdf()
            pdf.meta["title"] = "Citation Results"
            pdf.add_section(Section(markdown_content, toc=False))
            pdf.save(pdf_filepath)
            print(f"PDF content saved to: {pdf_filepath}")
        except Exception as e:
            raise RuntimeError(f"Failed to generate PDF file: {e}") from e

        print(f"Files generated successfully: \nMarkdown: {markdown_filepath}\nPDF: {pdf_filepath}")

    def get_markdown_content(self, json_filepath: str) -> str:
        """
        Read a JSON file and generate Markdown content based on its data.

        :param json_filepath: Path to the JSON file containing survey data.
        :return: A string containing the generated Markdown content.
        :raises RuntimeError: If the JSON file cannot be read or parsed.
        """
        try:
            with open(json_filepath, 'r', encoding='utf-8') as json_file:
                survey_data = json.load(json_file)
        except Exception as e:
            raise RuntimeError(f"Failed to read JSON file: {e}") from e

        topic = survey_data.get('survey_title', 'Default Topic')
        content = survey_data.get('content', 'No content available.')

        markdown_content = f"# A Survey of {topic}\n\n" + content + "\n\n"
        return finalize_survey_paper(markdown_content, self.collection_names, self.file_names)


if __name__ == "__main__":
    # NOTE(review): the module uses a relative import (`from .path_utils ...`),
    # so it presumably has to be started as part of its package
    # (`python -m ...`) rather than as a plain script — confirm.
    root_path = "."
    pdf_path = "./pdfs/test"
    survey_title = "Automating Literature Review Generation with LLM"
    cluster_standard = "method"

    asg_system = ASG_system(root_path, 'test', pdf_path, survey_title, cluster_standard)
    asg_system.parsing_pdfs()
    asg_system.description_generation()
    asg_system.agglomerative_clustering()
    asg_system.outline_generation()
    asg_system.section_generation()
    asg_system.citation_generation()