Kevin Wu commited on
Commit
21a5880
·
0 Parent(s):

Initial commit

Browse files
Files changed (3) hide show
  1. README.md +14 -0
  2. requirements.txt +6 -0
  3. run_claude.py +261 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TMLR Paper Reviewer
3
+ emoji: 📝
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: run_claude.py
9
+ pinned: false
10
+ ---
11
+
12
+ # TMLR Paper Reviewer
13
+
14
+ This tool helps generate high-quality reviews for papers submitted to the Transactions on Machine Learning Research (TMLR).
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ anthropic
3
+ openai
4
+ PyPDF2
5
+ tiktoken
6
+ pydantic
run_claude.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import base64
import io
import os
import time
from typing import List, Literal, Optional

# Third-party
import anthropic
import gradio as gr
import numpy as np
import pandas as pd
import PyPDF2
import requests
import tiktoken
from openai import OpenAI
from pydantic import BaseModel

# Local
import prompts
20
+
21
+ ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
22
+ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
23
+
24
+
25
def ask_claude(
    query: str,
    pdf_path: Optional[str] = None,
    use_cache: bool = False,
    system: Optional[str] = None,
    max_tokens: int = 1024,
    model: str = "claude-3-5-sonnet-20241022"
) -> str:
    """
    Unified function to query Claude API with various options.

    Args:
        query: Question/prompt for Claude
        pdf_path: Optional path to PDF file (local path or http(s) URL)
        use_cache: Whether to enable prompt caching
        system: Optional system prompt
        max_tokens: Maximum tokens in response (default 1024)
        model: Claude model to use (default claude-3-5-sonnet)

    Returns:
        Claude's response as a string (text of the first content block).

    Raises:
        requests.HTTPError: If downloading a PDF URL returns an error status.
        OSError: If a local PDF path cannot be read.
    """
    client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)

    # Default payload is the plain text query; replaced by a
    # document-plus-text pair when a PDF is attached.
    content = query
    betas = []
    if pdf_path:
        # Get PDF content from a URL or a local file.
        if pdf_path.startswith(('http://', 'https://')):
            # Bounded download; fail fast on HTTP errors instead of
            # base64-encoding an error page as the "PDF".
            response = requests.get(pdf_path, timeout=120)
            response.raise_for_status()
            binary_data = response.content
        else:
            with open(pdf_path, "rb") as pdf_file:
                binary_data = pdf_file.read()

        pdf_data = base64.standard_b64encode(binary_data).decode("utf-8")
        content = [
            {
                "type": "document",
                "source": {
                    "type": "base64",
                    "media_type": "application/pdf",
                    "data": pdf_data
                }
            },
            {
                "type": "text",
                "text": query
            }
        ]
        # Beta flag required for PDF document inputs.
        betas.append("pdfs-2024-09-25")

    # Add prompt caching if requested.
    if use_cache:
        betas.append("prompt-caching-2024-07-31")

    # Prepare API call kwargs.
    kwargs = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [{"role": "user", "content": content}]
    }

    # Add optional parameters only when provided.
    if system:
        kwargs["system"] = system
    if betas:
        kwargs["betas"] = betas

    message = client.beta.messages.create(**kwargs)
    return message.content[0].text
97
+
98
class Point(BaseModel):
    """A single review remark (strength, weakness, or requested change)."""

    # The remark text itself.
    content: str
    # Severity label; mapped to a numeric sort rank by `importance_mapping`.
    importance: Literal["critical", "minor"]
101
+
102
+
103
class Review(BaseModel):
    """Structured TMLR review; used as the GPT structured-output schema."""

    # Free-text summary of the paper's contributions.
    contributions: str
    strengths: List[Point]
    weaknesses: List[Point]
    requested_changes: List[Point]
    # Free-text broader-impact / ethics concerns.
    impact_concerns: str
109
+
110
+
111
# Numeric sort rank per importance label: higher ranks sort first.
importance_mapping = {"critical": 2, "minor": 1}

# OpenAI client used for structured-output parsing of the final review.
client = OpenAI(api_key=OPENAI_API_KEY)

# Pinned GPT model snapshot (supports structured outputs).
model_name = "gpt-4o-2024-08-06"
116
+
117
def format_gpt(prompt):
    """Parse a free-form review into the structured `Review` schema via GPT.

    Args:
        prompt: Formatting prompt containing the review text.

    Returns:
        dict: The parsed `Review` model dumped to a plain dictionary.
    """
    chat_completion = client.beta.chat.completions.parse(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        # Use the module-level pinned model name; it was defined above but
        # the call previously hard-coded 'gpt-4o' and never used it.
        model=model_name,
        response_format=Review,
    )
    return chat_completion.choices[0].message.parsed.model_dump()
129
+
130
def parse_final(parsed, max_strengths=3, max_weaknesses=5, max_requested_changes=5,
                importance_ranks=None):
    """Convert a structured review dict into markdown bullet-list sections.

    Args:
        parsed: Dict produced by `format_gpt` (a dumped `Review` model).
        max_strengths: Maximum number of strength bullets kept.
        max_weaknesses: Maximum number of weakness bullets kept.
        max_requested_changes: Maximum number of requested-change bullets kept.
        importance_ranks: Optional mapping from importance label to numeric
            rank; defaults to the module-level `importance_mapping`.

    Returns:
        Dict with the same section keys, where the list sections become
        newline-joined "- ..." bullet strings and requested changes are
        ordered most important first (stable within equal ranks).
    """
    ranks = importance_mapping if importance_ranks is None else importance_ranks

    def _bullets(points):
        # One "- " bullet line per point.
        return "\n".join(f'- {point["content"]}' for point in points)

    # sorted() is stable, so items with equal importance keep their
    # original relative order; reverse puts "critical" before "minor".
    changes_sorted = sorted(
        parsed["requested_changes"],
        key=lambda point: ranks[point["importance"]],
        reverse=True,
    )

    return {
        "contributions": parsed["contributions"],
        "impact_concerns": parsed["impact_concerns"],
        "strengths": _bullets(parsed["strengths"][:max_strengths]),
        "weaknesses": _bullets(parsed["weaknesses"][:max_weaknesses]),
        "requested_changes": _bullets(changes_sorted[:max_requested_changes]),
    }
152
+
153
def process(file_content, progress=gr.Progress()):
    """Run the full multi-pass review pipeline on an uploaded PDF.

    Args:
        file_content: Raw bytes of the uploaded PDF.
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        Tuple of five markdown section strings (contributions, strengths,
        weaknesses, requested changes, impact concerns) followed by the
        accumulated progress log.
    """
    log_messages = []

    def log(msg):
        # Echo to stdout and keep a transcript for the UI textbox.
        print(msg)
        log_messages.append(msg)
        return "\n".join(log_messages)

    # Persist the upload so ask_claude can re-attach it on every pass.
    if not os.path.exists("cache"):
        os.makedirs("cache")
    pdf_path = f"cache/{time.time()}.pdf"
    with open(pdf_path, "wb") as f:
        f.write(file_content)

    progress(0, desc="Starting review process...")
    log("Starting review process...")

    # Pass 1: three independent draft reviews of the same paper.
    all_reviews = []
    for i in range(3):
        progress((i + 1) / 3, desc=f"Generating review {i+1}/3")
        log(f"Generating review {i+1}/3...")
        all_reviews.append(ask_claude(prompts.review_prompt, pdf_path=pdf_path))
    all_reviews_string = "\n\n".join(
        f"Review {i+1}:\n{review}" for i, review in enumerate(all_reviews)
    )

    # Pass 2: merge the drafts into one review.
    progress(0.4, desc="Combining reviews...")
    log("Combining reviews...")
    combined_review = ask_claude(
        prompts.combine_prompt.format(
            all_reviews_string=all_reviews_string,
            review_format=prompts.review_format,
        ),
        pdf_path=pdf_path,
    )

    # Pass 3: generate an author-style rebuttal against the merged review.
    progress(0.6, desc="Defending paper...")
    log("Defending paper...")
    rebuttal = ask_claude(
        prompts.defend_prompt.format(combined_review=combined_review),
        pdf_path=pdf_path,
    )

    # Pass 4: revise the review in light of the rebuttal, then humanize it.
    progress(0.8, desc="Revising review...")
    log("Revising review...")
    revised_review = ask_claude(
        prompts.revise_prompt.format(
            review_format=prompts.review_format,
            combined_review=combined_review,
            defended_paper=rebuttal,
        ),
        pdf_path=pdf_path,
    )
    log("Humanizing review...")
    humanized_review = ask_claude(
        prompts.human_style.format(review=revised_review), pdf_path=pdf_path
    )

    # Pass 5: structure the final text via GPT and flatten to markdown.
    progress(0.9, desc="Formatting review...")
    log("Formatting review...")
    formatted_review = parse_final(
        format_gpt(prompts.formatting_prompt.format(review=humanized_review))
    )

    log("Finished!")

    headers = {
        "contributions": "# Contributions",
        "strengths": "# Strengths",
        "weaknesses": "# Weaknesses",
        "requested_changes": "# Requested Changes",
        "impact_concerns": "# Impact Concerns",
    }
    order = (
        "contributions",
        "strengths",
        "weaknesses",
        "requested_changes",
        "impact_concerns",
    )
    sections = tuple(f"{headers[key]}\n\n{formatted_review[key]}" for key in order)
    # Final element is the full transcript for the progress textbox.
    return sections + ("\n".join(log_messages),)
219
+
220
def gradio_interface():
    """Build and return the Gradio Blocks UI for the reviewer app."""
    with gr.Blocks() as demo:
        gr.Markdown("# TMLR Reviewer")
        gr.Markdown("This tool helps you generate high-quality reviews for the Transactions on Machine Learning Research (TMLR).")

        with gr.Row():
            pdf_upload = gr.File(label="Upload PDF", type="binary")

        with gr.Row():
            generate_btn = gr.Button("Generate Review")

        # Live transcript of pipeline progress.
        with gr.Row():
            log_box = gr.Textbox(label="Progress Log", interactive=False, lines=10)

        # One markdown pane per review section, side by side.
        with gr.Row():
            md_contributions = gr.Markdown(label="Contributions")
            md_strengths = gr.Markdown(label="Strengths")
            md_weaknesses = gr.Markdown(label="Weaknesses")
            md_requested_changes = gr.Markdown(label="Requested Changes")
            md_impact_concerns = gr.Markdown(label="Impact Concerns")

        # Output order must match the tuple returned by `process`.
        generate_btn.click(
            fn=process,
            inputs=pdf_upload,
            outputs=[
                md_contributions,
                md_strengths,
                md_weaknesses,
                md_requested_changes,
                md_impact_concerns,
                log_box,
            ],
        )

        demo.queue()
    return demo
257
+
258
if __name__ == "__main__":
    # Build the UI and serve it locally; share=False disables the public
    # Gradio tunnel link.
    demo = gradio_interface()
    demo.launch(share=False)
261
+