Spaces:
Sleeping
Sleeping
import fitz | |
from PyPDF2 import PdfReader | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
from anthropic import Anthropic | |
from prompts import INFORMATION_EXTRACTION_PROMPT, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, extract_INFORMATION_EXTRACTION_PROMPT | |
from prompts import verify_all_tags_present | |
from prompts import COMPARISON_INPUT_FORMAT, COMPARISON_PROMPT, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, extract_COMPARISON_PROMPT | |
import pandas as pd | |
from concurrent.futures import ThreadPoolExecutor | |
import streamlit as st | |
from dotenv import load_dotenv | |
load_dotenv() | |
def make_llm_api_call(messages):
    """Send a conversation to the Anthropic Messages API and return the reply.

    Args:
        messages: Conversation turns, forwarded unchanged as the ``messages``
            argument of ``client.messages.create``.

    Returns:
        The response object produced by ``client.messages.create``.
    """
    print("Making LLM api call")
    # Fixed request settings shared by every call made through this helper.
    request_options = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 4096,
        "temperature": 0,
        "messages": messages,
    }
    response = Anthropic().messages.create(**request_options)
    print("LLM response received")
    return response
def loop_verify_format(answer_text, tag_format, messages, verify_func, root_tag):
    """Re-prompt the LLM until its answer passes ``verify_func`` or retries run out.

    Args:
        answer_text: Candidate answer that is validated first.
        tag_format: Description of the expected format; quoted verbatim in the
            corrective prompt sent to the model.
        messages: Conversation so far. The rejected answer and the corrective
            prompt are appended to this list in place on every retry.
        verify_func: Callable that returns a truthy value when the answer is
            well-formed.
        root_tag: Name of the tag wrapping the answer (e.g. ``"answer"``).

    Returns:
        The first answer accepted by ``verify_func``.

    Raises:
        RuntimeError: If the answer is still rejected after ``max_attempts``
            corrective requests.
    """
    max_attempts = 4
    open_tag = f"<{root_tag}>"
    close_tag = f"</{root_tag}>"
    attempts = 0
    while not verify_func(answer_text):
        print("Wrong format")
        # The original counter was never incremented, so the limit could never
        # trigger and a stubbornly malformed reply looped forever.
        if attempts >= max_attempts:
            raise RuntimeError(f"LLM failed to provide a valid answer in {attempts} attempts")
        attempts += 1
        messages.append({"role": "assistant", "content": [{"type": "text", "text": answer_text}]})
        messages.append({"role": "user", "content": [{"type": "text", "text": f"You did not provide your answer in the correct format. Please provide your answer in the following format:\n{tag_format}"}]})
        reply_text = make_llm_api_call(messages).content[0].text
        if open_tag in reply_text:
            inner = reply_text.split(open_tag)[1].split(close_tag)[0].strip()
            answer_text = f"{open_tag}\n{inner}\n{close_tag}"
        else:
            # No root tag at all: keep the raw reply so the verifier judges it
            # (presumably rejecting it) instead of crashing with IndexError.
            answer_text = reply_text
    return answer_text
def loop_verify_all_tags_present(answer_text, tags, user_message, tag_format, verify_func, root_tag):
    """Ask the LLM once to fill in any requested tags missing from its answer.

    Args:
        answer_text: Answer already wrapped in ``<root_tag>`` tags.
        tags: Tags that the answer is expected to cover.
        user_message: The original user message dict that produced the answer.
        tag_format: Expected format, forwarded to ``loop_verify_format``.
        verify_func: Format validator, forwarded to ``loop_verify_format``.
        root_tag: Name of the tag wrapping the answer.

    Returns:
        ``answer_text`` unchanged when nothing is missing; otherwise the
        format-verified answer from the follow-up request.
    """
    missing_tags, _ = verify_all_tags_present(answer_text, tags)
    if not missing_tags:
        return answer_text

    print("There are missing tags", missing_tags)
    open_tag = f"<{root_tag}>"
    close_tag = f"</{root_tag}>"
    missing_tag_lines = "\n".join(f"<tag>{tag}</tag>" for tag in missing_tags)
    assistant_message = {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
    # Must be a single message dict: the original wrapped it in a list, which
    # put a nested list (not a message) into the conversation sent to the API.
    corrective_message = {
        "role": "user",
        "content": [{
            "type": "text",
            "text": (
                "In your response, the following tags are missing:\n"
                + missing_tag_lines
                + "\n\nPlease add information about the above missing tags and give a complete correct response."
            ),
        }],
    }
    messages = [user_message, assistant_message, corrective_message]
    reply_text = make_llm_api_call(messages).content[0].text
    if open_tag in reply_text:
        inner = reply_text.split(open_tag)[1].split(close_tag)[0].strip()
        answer_text = f"{open_tag}\n{inner}\n{close_tag}"
    else:
        # No root tag in the reply: hand the raw text to the format loop below
        # rather than failing with IndexError.
        answer_text = reply_text
    answer_text = loop_verify_format(answer_text, tag_format, [user_message], verify_func, root_tag)

    # Only one corrective round is made; report anything that is still absent.
    missing_tags, _ = verify_all_tags_present(answer_text, tags)
    if missing_tags:
        print("Tags still missing after retry", missing_tags)
    return answer_text
def extract_information_from_pdf(pdf_text, tags):
    """Extract per-tag information from a piece of PDF text via the LLM.

    Args:
        pdf_text: Text to analyse; substituted into the extraction prompt.
        tags: Tags to extract information for.

    Returns:
        The result of ``extract_INFORMATION_EXTRACTION_PROMPT`` applied to the
        verified answer (callers in this file read it with ``.get(tag)``).
    """
    tag_text = "\n".join(f"<tag>{tag}</tag>" for tag in tags)
    prompt = INFORMATION_EXTRACTION_PROMPT.format(TEXT=pdf_text, TAGS=tag_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    messages = [user_message]

    reply_text = make_llm_api_call(messages).content[0].text
    if "<answer>" in reply_text:
        inner = reply_text.split("<answer>")[1].split("</answer>")[0].strip()
        answer_text = f"<answer>\n{inner}\n</answer>"
    else:
        # Missing <answer> tag: let the format loop request a corrected reply
        # instead of raising IndexError before it gets the chance.
        answer_text = reply_text

    answer_text = loop_verify_format(
        answer_text,
        INFORMATION_EXTRACTION_TAG_FORMAT,
        messages,
        verify_INFORMATION_EXTRACTION_PROMPT,
        "answer",
    )
    # Pass the tag format (not the unformatted prompt template) so that any
    # corrective prompt quotes the expected answer layout.
    answer_text = loop_verify_all_tags_present(
        answer_text,
        tags,
        user_message,
        INFORMATION_EXTRACTION_TAG_FORMAT,
        verify_INFORMATION_EXTRACTION_PROMPT,
        "answer",
    )
    return extract_INFORMATION_EXTRACTION_PROMPT(answer_text)
def extract_text_with_pypdf(pdf_path):
    """Return the concatenated text of every page in a PDF.

    Each page's extracted text is followed by a newline; leading and trailing
    whitespace of the combined result is stripped.

    Args:
        pdf_path: Value handed to ``PdfReader`` to open the document.

    Returns:
        The combined page text as a single string.
    """
    pages = PdfReader(pdf_path).pages
    page_texts = [f"{page.extract_text()}\n" for page in pages]
    return "".join(page_texts).strip()
def get_tag_info_for_pdf(pdf, tags):
    """Collect the LLM-extracted information for each tag across a whole PDF.

    The PDF text is split into chunks of up to 100000 characters, each chunk is
    processed separately, and the per-tag results are joined with newlines.

    Args:
        pdf: PDF source passed to ``extract_text_with_pypdf``.
        tags: Tags to extract information for.

    Returns:
        Dict mapping every tag to its combined text; the empty string when no
        chunk produced anything for that tag.
    """
    text = extract_text_with_pypdf(pdf)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
    chunks = text_splitter.split_text(text)
    print("chunk length", len(chunks))

    pieces_by_tag = {tag: [] for tag in tags}
    for chunk in chunks:
        data = extract_information_from_pdf(chunk, tags)
        for tag in tags:
            value = data.get(tag)
            # Skip absent/empty results: the previous string concatenation
            # prefixed every value with "\n" and appended the literal "None".
            if value is None or value == "":
                continue
            pieces_by_tag[tag].append(str(value))
    return {tag: "\n".join(pieces) for tag, pieces in pieces_by_tag.items()}
def do_comparison_process(pdf1_data, pdf2_data, tags):
    """Ask the LLM to compare the per-tag information of two PDFs.

    Args:
        pdf1_data: Mapping of tag to extracted information for the first PDF.
        pdf2_data: Mapping of tag to extracted information for the second PDF.
        tags: Tags to compare.

    Returns:
        The result of ``extract_COMPARISON_PROMPT`` applied to the verified
        comparison (callers in this file read it with ``.get(tag)``).
    """
    tag_data_text = "\n".join(
        COMPARISON_INPUT_FORMAT.format(
            tag=tag,
            pdf1_information=pdf1_data.get(tag),
            pdf2_information=pdf2_data.get(tag),
        )
        for tag in tags
    )
    prompt = COMPARISON_PROMPT.format(TAG_INFO=tag_data_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}

    reply_text = make_llm_api_call([user_message]).content[0].text
    if "<comparison>" in reply_text:
        inner = reply_text.split("<comparison>")[1].split("</comparison>")[0].strip()
        comparison_text = f"<comparison>\n{inner}\n</comparison>"
    else:
        # Missing <comparison> tag: let the format loop request a corrected
        # reply instead of raising IndexError before it gets the chance.
        comparison_text = reply_text

    comparison_text = loop_verify_format(
        comparison_text,
        COMPARISON_TAG_FORMAT,
        [user_message],
        verify_COMPARISON_PROMPT,
        "comparison",
    )
    comparison_text = loop_verify_all_tags_present(
        comparison_text,
        tags,
        user_message,
        COMPARISON_TAG_FORMAT,
        verify_COMPARISON_PROMPT,
        "comparison",
    )
    return extract_COMPARISON_PROMPT(comparison_text)
# def get_pdf_data(pdf1, pdf2, tags): | |
# def get_tag_info_for_pdf(pdf, tags): | |
# text = extract_text_with_pypdf(pdf) | |
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0) | |
# chunks = text_splitter.split_text(text) | |
# tag_data = {tag:"" for tag in tags} | |
# for chunk in chunks: | |
# data = extract_information_from_pdf(chunk, tags) | |
# for tag in tags: | |
# tag_data.update({tag:f"{tag_data.get(tag)}\n{data.get(tag)}"}) | |
# return tag_data | |
# # Create a ThreadPoolExecutor (or ProcessPoolExecutor for CPU-bound tasks) | |
# with ThreadPoolExecutor(max_workers=2) as executor: | |
# # Submit the functions to the executor | |
# pdf1_future = executor.submit(get_tag_info_for_pdf, pdf1, tags) | |
# pdf2_future = executor.submit(get_tag_info_for_pdf, pdf2, tags) | |
# # Collect the results | |
# pdf1_data = pdf1_future.result() | |
# pdf2_data = pdf2_future.result() | |
# return pdf1_data, pdf2_data | |
def process_comparison_data(pdf1, pdf2, tags):
    """Build a tag-indexed table comparing the information found in two PDFs.

    Each stage runs under its own Streamlit spinner: extraction for the first
    PDF, extraction for the second PDF, then the comparison itself.

    Args:
        pdf1: First PDF, passed to ``get_tag_info_for_pdf``.
        pdf2: Second PDF, passed to ``get_tag_info_for_pdf``.
        tags: Tags to extract and compare.

    Returns:
        A pandas DataFrame indexed by ``Tags`` with the columns ``PDF 1``,
        ``PDF 2`` and ``Difference``.
    """
    with st.spinner("Processing PDF 1"):
        first_pdf_info = get_tag_info_for_pdf(pdf1, tags)
    with st.spinner("Processing PDF 2"):
        second_pdf_info = get_tag_info_for_pdf(pdf2, tags)
    with st.spinner("Generating Comparison Data"):
        differences = do_comparison_process(first_pdf_info, second_pdf_info, tags)

    # One row per tag: (tag, value in PDF 1, value in PDF 2, comparison text).
    rows = [
        (tag, first_pdf_info.get(tag), second_pdf_info.get(tag), differences.get(tag))
        for tag in tags
    ]
    frame = pd.DataFrame(rows, columns=['Tags', 'PDF 1', 'PDF 2', 'Difference'])
    return frame.set_index('Tags')