import anthropic
import base64
import requests
import os
from openai import OpenAI
import prompts
from typing import List, Literal, Optional
from pydantic import BaseModel
import time
import gradio as gr

ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")


def ask_claude(
    query: str,
    pdf_path: Optional[str] = None,
    use_cache: bool = False,
    system: Optional[str] = None,
    max_tokens: int = 1024,
    model: str = "claude-3-5-sonnet-20241022",
) -> str:
    """
    Unified function to query the Claude API with various options.

    Args:
        query: Question/prompt for Claude
        pdf_path: Optional path to a PDF file (local or URL)
        use_cache: Whether to enable prompt caching
        system: Optional system prompt
        max_tokens: Maximum tokens in the response (default 1024)
        model: Claude model to use (default claude-3-5-sonnet)

    Returns:
        Claude's response as a string
    """
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

    # Build the message content; attach the PDF as a base64 document if provided
    content = query
    betas = []
    if pdf_path:
        if pdf_path.startswith(("http://", "https://")):
            response = requests.get(pdf_path)
            response.raise_for_status()
            binary_data = response.content
        else:
            with open(pdf_path, "rb") as pdf_file:
                binary_data = pdf_file.read()

        pdf_data = base64.standard_b64encode(binary_data).decode("utf-8")
        content = [
            {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": pdf_data,
                },
            },
            {"type": "text", "text": query},
        ]
        betas.append("pdfs-2024-09-25")

    # Enable prompt caching if requested
    if use_cache:
        betas.append("prompt-caching-2024-07-31")
        if isinstance(content, list):
            # Mark the (large) PDF block as cacheable so repeated calls reuse it
            content[0]["cache_control"] = {"type": "ephemeral"}

    # Prepare API call kwargs
    kwargs = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [{"role": "user", "content": content}],
    }
    if system:
        kwargs["system"] = system
    if betas:
        kwargs["betas"] = betas

    message = client.beta.messages.create(**kwargs)
    return message.content[0].text


class Point(BaseModel):
    content: str
    importance: Literal["critical", "minor"]


class Review(BaseModel):
    contributions: str
    strengths: List[Point]
    weaknesses: List[Point]
    requested_changes: List[Point]
    impact_concerns: str


importance_mapping = {"critical": 2, "minor": 1}

client = OpenAI(api_key=OPENAI_API_KEY)
model_name = "gpt-4o-2024-08-06"


def format_gpt(prompt):
    # Use structured outputs to parse the free-form review into a Review object
    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model=model_name,
        response_format=Review,
    )
    return chat_completion.choices[0].message.parsed.model_dump()


def parse_final(parsed, max_strengths=3, max_weaknesses=5, max_requested_changes=5):
    # Trim each section to its cap and render the points as Markdown bullets
    new_parsed = {}
    new_parsed["contributions"] = parsed["contributions"]
    new_parsed["impact_concerns"] = parsed["impact_concerns"]
    new_parsed["strengths"] = "\n".join(
        f'- {point["content"]}' for point in parsed["strengths"][:max_strengths]
    )
    new_parsed["weaknesses"] = "\n".join(
        f'- {point["content"]}' for point in parsed["weaknesses"][:max_weaknesses]
    )
    # Surface critical requests before minor ones
    requested_changes_sorted = sorted(
        parsed["requested_changes"],
        key=lambda x: importance_mapping[x["importance"]],
        reverse=True,
    )
    new_parsed["requested_changes"] = "\n".join(
        f"- {point['content']}"
        for point in requested_changes_sorted[:max_requested_changes]
    )
    return new_parsed
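
# Illustrative usage (not executed at import time): a minimal sketch of calling
# ask_claude directly, assuming ANTHROPIC_API_KEY is set and "paper.pdf" is a
# hypothetical local file.
#
#     summary = ask_claude(
#         "Summarize the key contributions of this paper.",
#         pdf_path="paper.pdf",
#         use_cache=True,  # caches the PDF block across repeated queries
#     )
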
def process(file_content, progress=gr.Progress()):
    # Collect log messages so the UI can display them as the pipeline runs
    log_messages = []

    def log(msg):
        print(msg)
        log_messages.append(msg)
        return "\n".join(log_messages)

    if not os.path.exists("cache"):
        os.makedirs("cache")
    pdf_path = f"cache/{time.time()}.pdf"
    with open(pdf_path, "wb") as f:
        f.write(file_content)

    progress(0, desc="Starting review process...")
    log("Starting review process...")

    # Draft several independent reviews of the paper
    all_reviews = []
    for i in range(3):
        progress(0.1 * (i + 1), desc=f"Generating review {i + 1}/3")
        log(f"Generating review {i + 1}/3...")
        all_reviews.append(ask_claude(prompts.review_prompt, pdf_path=pdf_path))
    all_reviews_string = "\n\n".join(
        f"Review {i + 1}:\n{review}" for i, review in enumerate(all_reviews)
    )

    progress(0.4, desc="Combining reviews...")
    log("Combining reviews...")
    combined_review = ask_claude(
        prompts.combine_prompt.format(
            all_reviews_string=all_reviews_string,
            review_format=prompts.review_format,
        ),
        pdf_path=pdf_path,
    )

    progress(0.6, desc="Defending paper...")
    log("Defending paper...")
    rebuttal = ask_claude(
        prompts.defend_prompt.format(combined_review=combined_review),
        pdf_path=pdf_path,
    )

    progress(0.8, desc="Revising review...")
    log("Revising review...")
    revised_review = ask_claude(
        prompts.revise_prompt.format(
            review_format=prompts.review_format,
            combined_review=combined_review,
            defended_paper=rebuttal,
        ),
        pdf_path=pdf_path,
    )

    progress(0.85, desc="Humanizing review...")
    log("Humanizing review...")
    humanized_review = ask_claude(
        prompts.human_style.format(review=revised_review), pdf_path=pdf_path
    )

    progress(0.9, desc="Formatting review...")
    log("Formatting review...")
    formatted_review = parse_final(
        format_gpt(prompts.formatting_prompt.format(review=humanized_review))
    )
    log("Finished!")

    contributions = f"# Contributions\n\n{formatted_review['contributions']}"
    strengths = f"# Strengths\n\n{formatted_review['strengths']}"
    weaknesses = f"# Weaknesses\n\n{formatted_review['weaknesses']}"
    requested_changes = f"# Requested Changes\n\n{formatted_review['requested_changes']}"
    impact_concerns = f"# Impact Concerns\n\n{formatted_review['impact_concerns']}"

    return (
        contributions,
        strengths,
        weaknesses,
        requested_changes,
        impact_concerns,
        "\n".join(log_messages),  # Return the log messages
    )


def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# TMLR Reviewer")
        gr.Markdown(
            "This tool helps you generate high-quality reviews for "
            "Transactions on Machine Learning Research (TMLR)."
        )
        with gr.Row():
            # Left column: upload, submit button, and progress log
            with gr.Column(scale=1):
                upload_component = gr.File(label="Upload PDF", type="binary")
                submit_btn = gr.Button("Generate Review")
                progress_log = gr.Textbox(label="Progress Log", interactive=False, lines=10)

            # Right column: the review sections
            with gr.Column(scale=2):
                output_component_contributions = gr.Markdown(label="Contributions")
                output_component_strengths = gr.Markdown(label="Strengths")
                output_component_weaknesses = gr.Markdown(label="Weaknesses")
                output_component_requested_changes = gr.Markdown(label="Requested Changes")
                output_component_impact_concerns = gr.Markdown(label="Impact Concerns")

        submit_btn.click(
            fn=process,
            inputs=upload_component,
            outputs=[
                output_component_contributions,
                output_component_strengths,
                output_component_weaknesses,
                output_component_requested_changes,
                output_component_impact_concerns,
                progress_log,
            ],
        )

    demo.queue()
    return demo


if __name__ == "__main__":
    demo = gradio_interface()
    demo.launch(share=False)