Update app.py
app.py CHANGED
@@ -5,42 +5,24 @@ import os
 import openai
 import re
 import gradio as gr
-from google.oauth2 import service_account
-from googleapiclient.discovery import build
-from googleapiclient.http import MediaIoBaseUpload, MediaIoBaseDownload
-import io
+import gspread
+from oauth2client.service_account import ServiceAccountCredentials
 import json
 
-def google_drive_authenticate():
-    """Authenticate with the Google Drive API."""
-    credentials_json = os.getenv('GOOGLE_CREDENTIALS')
-    credentials_dict = json.loads(credentials_json)
-    credentials = service_account.Credentials.from_service_account_info(credentials_dict)
-    return build('drive', 'v3', credentials=credentials)
+def connect_gspread(spread_sheet_key):
+    """Connect to the Google spreadsheet."""
+    credentials_json = os.getenv('GOOGLE_CREDENTIALS')
+    credentials_dict = json.loads(credentials_json)
+    scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
+
+    credentials = ServiceAccountCredentials.from_json_keyfile_dict(credentials_dict, scope)
+    gc = gspread.authorize(credentials)
+    SPREADSHEET_KEY = spread_sheet_key
+    worksheet = gc.open_by_key(SPREADSHEET_KEY).sheet1
+    return worksheet
 
-def save_to_google_drive(service, folder_id, filename, content):
-    """Save content as a text file in Google Drive."""
-    file_metadata = {'name': filename, 'parents': [folder_id]}
-    media = MediaIoBaseUpload(io.BytesIO(content.encode()), mimetype='text/plain')
-    file = service.files().create(body=file_metadata, media_body=media, fields='id').execute()
-    return file.get('id')
-
-def find_in_google_drive(service, folder_id, paper_id):
-    """Search Google Drive for a file and return its contents."""
-    query = f"parents='{folder_id}' and name contains '{paper_id}' and trashed=false"
-    response = service.files().list(q=query, spaces='drive', fields='files(id, name)').execute()
-    if not response.get('files'):
-        return None
-    file_id = response.get('files')[0].get('id')
-    request = service.files().get_media(fileId=file_id)
-    fh = io.BytesIO()
-    downloader = MediaIoBaseDownload(fh, request)
-    done = False
-    while done is False:
-        _, done = downloader.next_chunk()
-    fh.seek(0)
-    content = fh.read().decode('utf-8')
-    return content
+spread_sheet_key = "1nSh6D_Gqdbhi1CB3wvD4OJUU6bji8-LE6HET7NTEjrM"
+worksheet = connect_gspread(spread_sheet_key)
 
 def download_paper(paper_url):
     """Download the paper PDF and save it locally."""

@@ -62,7 +44,7 @@ def summarize_text_with_chat(text, max_length=10000):
     """Summarize text with the OpenAI Chat API."""
     openai.api_key = os.getenv('OPENAI_API_KEY')
     trimmed_text = text[:max_length]
-    response = openai.ChatCompletion.create(
+    response = openai.chat.completions.create(
        model="gpt-3.5-turbo-0125",
        messages=[
            {"role": "system", "content": "次の文書を要約してください。必ず'## タイトル', '## 要約', '## 専門用語解説'を記載してください。"},

@@ -71,8 +53,9 @@ def summarize_text_with_chat(text, max_length=10000):
        temperature=0.7,
        max_tokens=1000
    )
-    summary_text = response.choices[0].message
-    return summary_text
+    summary_text = response.choices[0].message.content
+    total_token = response.usage.total_tokens
+    return summary_text, total_token
 
 def fetch_paper_links(url):
     """Extract paper links from the given URL, removing duplicates."""

@@ -86,33 +69,37 @@ def fetch_paper_links(url):
             links.append(href)
     return links
 
-def summarize_paper_and_save(service, folder_id, paper_id):
-    """Summarize a paper and cache the result in Google Drive."""
-    existing_summary = find_in_google_drive(service, folder_id, paper_id)
-    if existing_summary:
-        return existing_summary
+def summarize_paper_and_save_to_sheet(paper_id):
+    """Summarize a paper and save the result to the Google spreadsheet."""
     paper_url = f"https://arxiv.org/pdf/{paper_id}.pdf"
     pdf_path = download_paper(paper_url)
     text = extract_text_from_pdf(pdf_path)
     summary = summarize_text_with_chat(text)
     os.remove(pdf_path)
-    filename = f"{paper_id}.txt"
-    save_to_google_drive(service, folder_id, filename, summary)
+    worksheet.append_row([paper_id, paper_url, summary])
     return summary
 
 def gradio_interface():
-    service = google_drive_authenticate()
-    folder_id = '1yOXimp4kk7eohWKGtVo-gn93M0A404TM'
-    summaries = []
     paper_links = fetch_paper_links("https://huggingface.co/papers")
-    paper_ids = [link.split('/')[-1] for link in paper_links]
+    paper_ids = set(link.split('/')[-1] for link in paper_links)
+
+    total_tokens_used = 0
+    summaries = []
 
     for paper_id in paper_ids:
-        summary = summarize_paper_and_save(service, folder_id, paper_id)
-        summaries.append(f"論文: https://arxiv.org/pdf/{paper_id}.pdf\n{summary}\n")
+        summary_info = ""
+        try:
+            summary, tokens_used = summarize_paper_and_save_to_sheet(paper_id)
+            total_tokens_used += tokens_used
+            paper_id_url = f"https://arxiv.org/pdf/{paper_id}.pdf"
+            summary_info += f'論文: {paper_id_url}\n{summary}\n'
+        except Exception as e:
+            summary_info += f"Error processing paper ID {paper_id}: {e}\n"
+
+        summaries.append(summary_info)
 
-    summaries_markdown = "\n---\n".join(summaries)
-    return summaries_markdown
+    summaries_markdown = "\n---\n".join(summaries)  # separate summaries with horizontal rules
+    return summaries_markdown + f"\n全ての要約で使用されたトータルトークン数: {total_tokens_used}"
 
 iface = gr.Interface(
     fn=gradio_interface,
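For reference, a minimal standalone sketch of the gspread flow the new code relies on, assuming GOOGLE_CREDENTIALS holds a service-account JSON string with access to the target sheet; the spreadsheet key and row values below are placeholders, not values from the commit.

import os
import json

import gspread
from oauth2client.service_account import ServiceAccountCredentials

# Authorize a gspread client from the GOOGLE_CREDENTIALS env var,
# following the same pattern as connect_gspread() above.
credentials_dict = json.loads(os.getenv('GOOGLE_CREDENTIALS'))
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
credentials = ServiceAccountCredentials.from_json_keyfile_dict(credentials_dict, scope)
gc = gspread.authorize(credentials)

# Open the first worksheet by key and append one row, as
# summarize_paper_and_save_to_sheet() does; key and values are placeholders.
worksheet = gc.open_by_key('YOUR_SPREADSHEET_KEY').sheet1
worksheet.append_row(['2401.00001', 'https://arxiv.org/pdf/2401.00001.pdf', 'summary text'])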