Spaces:
Sleeping
Sleeping
import anthropic | |
import base64 | |
import pandas as pd | |
import requests | |
import os | |
import numpy as np | |
from openai import OpenAI | |
import io | |
import tiktoken | |
import PyPDF2 | |
import prompts | |
from typing import List, Literal | |
from pydantic import BaseModel | |
import time | |
import gradio as gr | |
# API credentials, read once from the process environment at import time.
# Each is None when the corresponding variable is not set.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
def _read_pdf_bytes(pdf_path: str) -> bytes:
    """Return the raw bytes of the PDF at ``pdf_path`` (local path or http(s) URL).

    Raises:
        requests.HTTPError: If a URL responds with a 4xx/5xx status.
        requests.Timeout: If a URL does not respond in time.
        OSError: If a local file cannot be opened.
    """
    if pdf_path.startswith(("http://", "https://")):
        # Bounded wait so a stalled server cannot hang the request forever.
        response = requests.get(pdf_path, timeout=60)
        # Without this, an HTML error page would be base64-encoded and sent to
        # Claude as if it were the PDF.
        response.raise_for_status()
        return response.content
    with open(pdf_path, "rb") as pdf_file:
        return pdf_file.read()


def ask_claude(
    query: str,
    pdf_path: str = None,
    use_cache: bool = False,
    system: str = None,
    max_tokens: int = 1024,
    model: str = "claude-3-5-sonnet-20241022"
) -> str:
    """Send a single-turn request to Claude, optionally attaching a PDF.

    Args:
        query: Question/prompt for Claude.
        pdf_path: Optional PDF to attach, either a local file path or an
            http(s) URL. It is sent base64-encoded as a ``document`` block
            ahead of the query, using the ``pdfs-2024-09-25`` beta.
        use_cache: Enable prompt caching. The ``prompt-caching-2024-07-31``
            beta is requested and a ``cache_control`` breakpoint is placed on
            the PDF block, or on the query block when there is no PDF. Repeated
            calls on the same PDF can therefore reuse the cached prefix.
        system: Optional system prompt.
        max_tokens: Maximum tokens in the response (default 1024).
        model: Claude model to use (default claude-3-5-sonnet-20241022).

    Returns:
        The text of the first content block of Claude's response.

    Raises:
        requests.HTTPError / requests.Timeout: If downloading a URL PDF fails.
        OSError: If a local PDF cannot be read.
        Any error raised by the Anthropic client (auth, rate limit, ...).
    """
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
    betas = []
    # The prompt-caching beta only caches prefixes that end in a block marked
    # with cache_control; sending the beta header alone caches nothing.
    cache_control = {"type": "ephemeral"} if use_cache else None

    if pdf_path:
        pdf_data = base64.standard_b64encode(_read_pdf_bytes(pdf_path)).decode("utf-8")
        document_block = {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": pdf_data,
            },
        }
        if cache_control:
            # Breakpoint right after the document, so different questions about
            # the same PDF share the cached prefix.
            document_block["cache_control"] = cache_control
        content = [document_block, {"type": "text", "text": query}]
        betas.append("pdfs-2024-09-25")
    elif cache_control:
        # A plain-string message cannot carry cache_control, so wrap it in a block.
        content = [{"type": "text", "text": query, "cache_control": cache_control}]
    else:
        content = query

    if use_cache:
        betas.append("prompt-caching-2024-07-31")

    kwargs = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [{"role": "user", "content": content}],
    }
    # Only send optional parameters when they were actually provided.
    if system:
        kwargs["system"] = system
    if betas:
        kwargs["betas"] = betas

    message = client.beta.messages.create(**kwargs)
    return message.content[0].text
class Point(BaseModel):
    # One bullet of the structured review returned by the OpenAI formatting
    # step (nested inside Review).
    # NOTE: use # comments rather than a docstring here -- pydantic copies a
    # class docstring into the JSON schema "description", which would change
    # the schema sent to the model.
    content: str  # the bullet text shown to the user
    # "critical" points rank above "minor" ones when requested changes are
    # sorted (see importance_mapping).
    importance: Literal["critical", "minor"]
class Review(BaseModel):
    # Structured-output schema passed as response_format to the OpenAI parse
    # call. After model_dump(), these field names are the dict keys that
    # parse_final reads.
    # NOTE: use # comments rather than a docstring here -- pydantic would put a
    # docstring into the JSON schema sent to the model.
    contributions: str
    strengths: List[Point]
    weaknesses: List[Point]
    requested_changes: List[Point]  # sorted by importance before display
    impact_concerns: str
# Sort rank for Point.importance; higher values are listed first when
# requested changes are ordered (sorted with reverse=True).
importance_mapping = dict(critical=2, minor=1)

# Module-wide OpenAI client, created once at import time.
client = OpenAI(api_key=OPENAI_API_KEY)

# Pinned GPT-4o snapshot identifier.
model_name = "gpt-4o-2024-08-06"
def format_gpt(prompt, model=None):
    """Have OpenAI restructure free text into a ``Review`` and return it as a dict.

    Args:
        prompt: The full user prompt (formatting instructions plus review text).
        model: OpenAI model name. Defaults to the module-level ``model_name``
            snapshot. Per OpenAI's structured-outputs docs, parsing into a
            pydantic ``response_format`` needs a snapshot that supports
            structured outputs.

    Returns:
        ``Review.model_dump()``: a dict with keys ``contributions``,
        ``strengths``, ``weaknesses``, ``requested_changes`` and
        ``impact_concerns``. The three list fields hold dicts with
        ``content`` and ``importance``.

    Raises:
        ValueError: If the model refused or produced no parsable output.
    """
    completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        # Use the pinned snapshot defined alongside the client instead of a
        # floating alias, so the structured-output model is explicit.
        model=model or model_name,
        response_format=Review,
    )
    message = completion.choices[0].message
    if message.parsed is None:
        # On a refusal, parse() leaves .parsed empty; report why instead of
        # failing with "'NoneType' object has no attribute 'model_dump'".
        reason = getattr(message, "refusal", None) or "no parsed output"
        raise ValueError(f"OpenAI did not return a structured review: {reason}")
    return message.parsed.model_dump()
def parse_final(parsed, max_strengths=3, max_weaknesses=5, max_requested_changes=5):
    """Flatten a structured review dict into Markdown bullet-list strings.

    Args:
        parsed: Dict with ``contributions`` and ``impact_concerns`` strings plus
            ``strengths``, ``weaknesses`` and ``requested_changes`` lists of
            ``{"content": ..., "importance": ...}`` dicts (the shape
            ``format_gpt`` returns).
        max_strengths: Keep at most this many strengths, in their given order.
        max_weaknesses: Keep at most this many weaknesses, in their given order.
        max_requested_changes: Keep at most this many requested changes, after
            ordering them by ``importance_mapping`` (highest first).

    Returns:
        Dict with the same five keys. The two free-text fields are passed
        through unchanged; each list becomes newline-joined ``- item`` lines.
    """
    def as_bullets(points):
        return "\n".join([f"- {point['content']}" for point in points])

    # The dict literal is evaluated top to bottom, so key order and lookup
    # order follow the fields as written. sorted() is stable even with
    # reverse=True, so equally important changes keep the model's order.
    return {
        "contributions": parsed["contributions"],
        "impact_concerns": parsed["impact_concerns"],
        "strengths": as_bullets(parsed["strengths"][:max_strengths]),
        "weaknesses": as_bullets(parsed["weaknesses"][:max_weaknesses]),
        "requested_changes": as_bullets(
            sorted(
                parsed["requested_changes"],
                key=lambda change: importance_mapping[change["importance"]],
                reverse=True,
            )[:max_requested_changes]
        ),
    }
# (dict key, Markdown heading) for each review section, in the order process()
# returns them and the UI displays them.
_REVIEW_SECTIONS = (
    ("contributions", "Contributions"),
    ("strengths", "Strengths"),
    ("weaknesses", "Weaknesses"),
    ("requested_changes", "Requested Changes"),
    ("impact_concerns", "Impact Concerns"),
)


def _run_review_pipeline(pdf_path, progress, log):
    """Run every LLM stage against the PDF at ``pdf_path``; return ``parse_final``'s dict."""
    progress(0, desc="Starting review process...")
    log("Starting review process...")

    num_reviews = 3
    all_reviews = []
    for i in range(num_reviews):
        # The drafts occupy the 0.0-0.4 band of the bar, so it never moves
        # backwards when the later stages report 0.4, 0.6, ...
        progress(0.4 * i / num_reviews, desc=f"Generating review {i+1}/{num_reviews}")
        log(f"Generating review {i+1}/{num_reviews}...")
        all_reviews.append(ask_claude(prompts.review_prompt, pdf_path=pdf_path))
    all_reviews_string = "\n\n".join(
        f"Review {i+1}:\n{review}" for i, review in enumerate(all_reviews)
    )

    progress(0.4, desc="Combining reviews...")
    log("Combining reviews...")
    combined_review = ask_claude(
        prompts.combine_prompt.format(
            all_reviews_string=all_reviews_string,
            review_format=prompts.review_format,
        ),
        pdf_path=pdf_path,
    )

    progress(0.6, desc="Defending paper...")
    log("Defending paper...")
    rebuttal = ask_claude(
        prompts.defend_prompt.format(combined_review=combined_review),
        pdf_path=pdf_path,
    )

    progress(0.8, desc="Revising review...")
    log("Revising review...")
    revised_review = ask_claude(
        prompts.revise_prompt.format(
            review_format=prompts.review_format,
            combined_review=combined_review,
            defended_paper=rebuttal,
        ),
        pdf_path=pdf_path,
    )

    progress(0.85, desc="Humanizing review...")
    log("Humanizing review...")
    humanized_review = ask_claude(
        prompts.human_style.format(review=revised_review), pdf_path=pdf_path
    )

    progress(0.9, desc="Formatting review...")
    log("Formatting review...")
    formatted_review = parse_final(
        format_gpt(prompts.formatting_prompt.format(review=humanized_review))
    )
    log("Finished!")
    return formatted_review


def process(file_content, progress=gr.Progress()):
    """Run the multi-stage review pipeline on an uploaded PDF.

    Stages: three review drafts from Claude, a combined review, a rebuttal
    ("defending" the paper), a revised review and a "humanized" rewrite. Each
    of these calls ``ask_claude`` with the PDF attached. The result is then
    structured by ``format_gpt`` and flattened by ``parse_final``.

    Args:
        file_content: Raw bytes of the uploaded PDF (the upload component uses
            ``type="binary"``). ``None``/empty when nothing was uploaded.
        progress: Gradio progress tracker (injected by Gradio via this default).

    Returns:
        A 6-tuple of strings: Markdown sections for contributions, strengths,
        weaknesses, requested changes and impact concerns, each prefixed with
        a ``# Heading``, followed by the newline-joined progress log.

    Raises:
        gr.Error: If no file was uploaded.
    """
    import contextlib
    import tempfile

    if not file_content:
        raise gr.Error("Please upload a PDF before generating a review.")

    log_messages = []

    def log(msg):
        print(msg)
        log_messages.append(msg)

    # ask_claude reads the PDF by path, so persist the upload. A unique temp
    # name avoids collisions between concurrent requests. The file is deleted
    # afterwards so the cache directory does not grow without bound.
    os.makedirs("cache", exist_ok=True)
    fd, pdf_path = tempfile.mkstemp(suffix=".pdf", dir="cache")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(file_content)
        formatted_review = _run_review_pipeline(pdf_path, progress, log)
    finally:
        with contextlib.suppress(OSError):
            os.remove(pdf_path)

    sections = tuple(
        f"# {heading}\n\n{formatted_review[key]}" for key, heading in _REVIEW_SECTIONS
    )
    return (*sections, "\n".join(log_messages))
def gradio_interface():
    """Assemble the TMLR Reviewer Gradio app.

    Layout:
        - A narrow left column with the PDF upload, the "Generate Review"
          button and a read-only progress log.
        - A wider right column with one Markdown pane per review section.

    Clicking the button runs ``process`` on the uploaded bytes. The result
    fills the five section panes and then the log box.

    Returns:
        The ``gr.Blocks`` app, with ``queue()`` already called, ready for
        ``.launch()``.
    """
    with gr.Blocks() as app:
        gr.Markdown("# TMLR Reviewer")
        gr.Markdown("This tool helps you generate high-quality reviews for the Transactions on Machine Learning Research (TMLR).")

        with gr.Row():
            # Inputs and run log.
            with gr.Column(scale=1):
                pdf_upload = gr.File(label="Upload PDF", type="binary")
                generate_button = gr.Button("Generate Review")
                log_box = gr.Textbox(label="Progress Log", interactive=False, lines=10)

            # One pane per section, in the order process() returns them.
            with gr.Column(scale=2):
                section_panes = [
                    gr.Markdown(label=title)
                    for title in (
                        "Contributions",
                        "Strengths",
                        "Weaknesses",
                        "Requested Changes",
                        "Impact Concerns",
                    )
                ]

        generate_button.click(
            fn=process,
            inputs=pdf_upload,
            outputs=[*section_panes, log_box],
        )

    app.queue()
    return app
if __name__ == "__main__":
    demo = gradio_interface()  # build the Blocks UI
    demo.launch(
        share=False,  # serve locally only; no public share link
    )