Spaces:
Sleeping
Sleeping
Kevin Wu
committed on
Commit
·
21a5880
0
Parent(s):
Initial commit
Browse files- README.md +14 -0
- requirements.txt +6 -0
- run_claude.py +261 -0
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: TMLR Paper Reviewer
|
3 |
+
emoji: 📝
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.19.2
|
8 |
+
app_file: run_claude.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
# TMLR Paper Reviewer
|
13 |
+
|
14 |
+
This tool helps generate high-quality reviews for papers submitted to the Transactions on Machine Learning Research (TMLR).
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
anthropic
|
3 |
+
openai
|
4 |
+
PyPDF2
|
5 |
+
tiktoken
|
6 |
+
pydantic
|
run_claude.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import anthropic
|
2 |
+
import base64
|
3 |
+
import pandas as pd
|
4 |
+
import requests
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
from openai import OpenAI
|
8 |
+
import io
|
9 |
+
import tiktoken
|
10 |
+
import PyPDF2
|
11 |
+
|
12 |
+
import prompts
|
13 |
+
|
14 |
+
from typing import List, Literal
|
15 |
+
from pydantic import BaseModel
|
16 |
+
|
17 |
+
import time
|
18 |
+
|
19 |
+
import gradio as gr
|
20 |
+
|
21 |
+
# API credentials come from the environment so no secrets live in source;
# either may be None here, in which case the corresponding client call fails.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
23 |
+
|
24 |
+
|
25 |
+
def ask_claude(
    query: str,
    pdf_path: str = None,
    use_cache: bool = False,
    system: str = None,
    max_tokens: int = 1024,
    model: str = "claude-3-5-sonnet-20241022"
) -> str:
    """
    Unified function to query Claude API with various options.

    Args:
        query: Question/prompt for Claude
        pdf_path: Optional path to PDF file (local or URL)
        use_cache: Whether to enable prompt caching
        system: Optional system prompt
        max_tokens: Maximum tokens in response (default 1024)
        model: Claude model to use (default claude-3-5-sonnet)

    Returns:
        Claude's response as a string

    Raises:
        requests.HTTPError: if a URL `pdf_path` does not return 2xx.
        OSError: if a local `pdf_path` cannot be read.
    """
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

    # Handle PDF if provided
    content = query
    betas = []
    if pdf_path:
        # Get PDF content
        if pdf_path.startswith(('http://', 'https://')):
            # Fix: bound the download and fail loudly on HTTP errors.
            # Previously there was no timeout (a stalled server would hang
            # the request forever) and no status check (an error page would
            # be base64-encoded and sent to the API as a "PDF").
            response = requests.get(pdf_path, timeout=60)
            response.raise_for_status()
            binary_data = response.content
        else:
            with open(pdf_path, "rb") as pdf_file:
                binary_data = pdf_file.read()

        pdf_data = base64.standard_b64encode(binary_data).decode("utf-8")
        # Document part must precede the text part in the message content.
        content = [
            {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": pdf_data
                }
            },
            {
                "type": "text",
                "text": query
            }
        ]
        betas.append("pdfs-2024-09-25")

    # Add prompt caching if requested
    if use_cache:
        betas.append("prompt-caching-2024-07-31")

    # Prepare API call kwargs
    kwargs = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [{"role": "user", "content": content}]
    }

    # Add optional parameters if provided
    if system:
        kwargs["system"] = system
    if betas:
        kwargs["betas"] = betas

    # Beta endpoint is required because the PDF / prompt-caching betas
    # are passed via the `betas` kwarg.
    message = client.beta.messages.create(**kwargs)
    return message.content[0].text
|
97 |
+
|
98 |
+
class Point(BaseModel):
    # One review remark (a strength, weakness, or requested change).
    content: str
    # Rank used when sorting requested changes; see importance_mapping.
    importance: Literal["critical", "minor"]
|
101 |
+
|
102 |
+
|
103 |
+
class Review(BaseModel):
    # Structured review schema used as the OpenAI structured-output target
    # in format_gpt; field names match the TMLR review sections.
    contributions: str
    strengths: List[Point]
    weaknesses: List[Point]
    requested_changes: List[Point]
    impact_concerns: str
|
109 |
+
|
110 |
+
|
111 |
+
# Numeric rank for requested changes: higher sorts first in parse_final.
importance_mapping = {"critical": 2, "minor": 1}

# Module-level OpenAI client, reused for every formatting call.
client = OpenAI(api_key=OPENAI_API_KEY)

# Pinned GPT-4o snapshot intended for the structured formatting step.
model_name = "gpt-4o-2024-08-06"
|
116 |
+
|
117 |
+
def format_gpt(prompt):
    """Convert a free-text review into a structured dict via OpenAI.

    Uses structured outputs (`response_format=Review`) so the reply is
    guaranteed to parse into the `Review` schema.

    Args:
        prompt: Formatting prompt containing the review text.

    Returns:
        A plain dict produced by `Review.model_dump()`.
    """
    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        # Fix: use the pinned snapshot declared in `model_name` instead of
        # the hard-coded 'gpt-4o' alias, which silently ignored that setting.
        model=model_name,
        response_format=Review,
    )
    return chat_completion.choices[0].message.parsed.model_dump()
|
129 |
+
|
130 |
+
def parse_final(parsed, max_strengths=3, max_weaknesses=5, max_requested_changes=5,
                importance_rank=None):
    """Flatten a structured review dict into display-ready bullet strings.

    Args:
        parsed: Dict shaped like `Review.model_dump()` (keys `contributions`,
            `impact_concerns`, and lists of {"content", "importance"} points).
        max_strengths: Maximum strengths to keep.
        max_weaknesses: Maximum weaknesses to keep.
        max_requested_changes: Maximum requested changes to keep.
        importance_rank: Optional mapping from importance label to a numeric
            rank (higher sorts first). Defaults to the module-level
            `importance_mapping`, preserving the original behavior.

    Returns:
        Dict with the same section keys, each value a markdown bullet string
        (or a passthrough string for contributions / impact_concerns).
    """
    if importance_rank is None:
        importance_rank = importance_mapping

    def _bullets(points):
        # Render a list of point dicts as "- content" lines.
        return "\n".join(f"- {point['content']}" for point in points)

    # Stable sort: among equal ranks, original order is preserved.
    ranked_changes = sorted(
        parsed["requested_changes"],
        key=lambda point: importance_rank[point["importance"]],
        reverse=True,
    )
    return {
        "contributions": parsed["contributions"],
        "impact_concerns": parsed["impact_concerns"],
        "strengths": _bullets(parsed["strengths"][:max_strengths]),
        "weaknesses": _bullets(parsed["weaknesses"][:max_weaknesses]),
        "requested_changes": _bullets(ranked_changes[:max_requested_changes]),
    }
|
152 |
+
|
153 |
+
def process(file_content, progress=gr.Progress()):
    """Run the full multi-stage review pipeline on an uploaded PDF.

    Pipeline: 3 independent Claude reviews -> combine -> author rebuttal ->
    revise -> humanize -> GPT structured formatting.

    Args:
        file_content: Raw PDF bytes from the Gradio file upload.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Tuple of five markdown section strings (contributions, strengths,
        weaknesses, requested changes, impact concerns) plus the progress log.
    """
    # Create a list to store log messages
    log_messages = []

    def log(msg):
        # Mirror to stdout for server logs and accumulate for the UI textbox.
        print(msg)
        log_messages.append(msg)
        return "\n".join(log_messages)

    # exist_ok avoids a race when two requests create the cache dir at once.
    os.makedirs("cache", exist_ok=True)
    pdf_path = f"cache/{time.time()}.pdf"
    with open(pdf_path, "wb") as f:
        f.write(file_content)

    try:
        progress(0, desc="Starting review process...")
        log("Starting review process...")

        all_reviews = []
        for i in range(3):
            # Fix: scale the loop into [0, 0.4) so the bar stays monotonic.
            # Previously it reached 1.0 at review 3/3 and then jumped back
            # down to 0.4 for the combine step.
            progress(0.4 * i / 3, desc=f"Generating review {i+1}/3")
            log(f"Generating review {i+1}/3...")
            all_reviews.append(ask_claude(prompts.review_prompt, pdf_path=pdf_path))
        all_reviews_string = "\n\n".join(
            f"Review {i+1}:\n{review}" for i, review in enumerate(all_reviews)
        )

        progress(0.4, desc="Combining reviews...")
        log("Combining reviews...")
        combined_review = ask_claude(
            prompts.combine_prompt.format(
                all_reviews_string=all_reviews_string,
                review_format=prompts.review_format,
            ),
            pdf_path=pdf_path,
        )

        progress(0.6, desc="Defending paper...")
        log("Defending paper...")
        rebuttal = ask_claude(
            prompts.defend_prompt.format(combined_review=combined_review),
            pdf_path=pdf_path,
        )

        progress(0.8, desc="Revising review...")
        log("Revising review...")
        revised_review = ask_claude(
            prompts.revise_prompt.format(
                review_format=prompts.review_format,
                combined_review=combined_review,
                defended_paper=rebuttal,
            ),
            pdf_path=pdf_path,
        )
        log("Humanizing review...")
        humanized_review = ask_claude(
            prompts.human_style.format(review=revised_review), pdf_path=pdf_path
        )

        progress(0.9, desc="Formatting review...")
        log("Formatting review...")
        formatted_review = parse_final(
            format_gpt(prompts.formatting_prompt.format(review=humanized_review))
        )

        log("Finished!")
    finally:
        # Fix: the uploaded PDF is only needed for this request; remove it so
        # cache/ does not grow without bound across requests.
        try:
            os.remove(pdf_path)
        except OSError:
            pass

    return (
        f"# Contributions\n\n{formatted_review['contributions']}",
        f"# Strengths\n\n{formatted_review['strengths']}",
        f"# Weaknesses\n\n{formatted_review['weaknesses']}",
        f"# Requested Changes\n\n{formatted_review['requested_changes']}",
        f"# Impact Concerns\n\n{formatted_review['impact_concerns']}",
        "\n".join(log_messages),  # Return the log messages
    )
|
219 |
+
|
220 |
+
def gradio_interface():
    """Build and return the Gradio Blocks UI for the TMLR reviewer app."""
    with gr.Blocks() as demo:
        gr.Markdown("# TMLR Reviewer")
        gr.Markdown("This tool helps you generate high-quality reviews for the Transactions on Machine Learning Research (TMLR).")

        # Input: the paper PDF, delivered to `process` as raw bytes.
        with gr.Row():
            pdf_upload = gr.File(label="Upload PDF", type="binary")

        with gr.Row():
            generate_btn = gr.Button("Generate Review")

        # Live pipeline log, updated when `process` returns.
        with gr.Row():
            log_box = gr.Textbox(label="Progress Log", interactive=False, lines=10)

        # One markdown panel per review section.
        with gr.Row():
            md_contributions = gr.Markdown(label="Contributions")
            md_strengths = gr.Markdown(label="Strengths")
            md_weaknesses = gr.Markdown(label="Weaknesses")
            md_requested_changes = gr.Markdown(label="Requested Changes")
            md_impact_concerns = gr.Markdown(label="Impact Concerns")

        # Output order must match the tuple returned by `process`.
        generate_btn.click(
            fn=process,
            inputs=pdf_upload,
            outputs=[
                md_contributions,
                md_strengths,
                md_weaknesses,
                md_requested_changes,
                md_impact_concerns,
                log_box,
            ],
        )

    demo.queue()
    return demo
|
257 |
+
|
258 |
+
if __name__ == "__main__":
    # Launch the app locally; share=False avoids creating a public Gradio link.
    demo = gradio_interface()
    demo.launch(share=False)
|
261 |
+
|