shigeru saito committed · e87d7b4
Parent(s): f56a350

Add error handling, example inputs, and batch CSV processing

Files changed:
- app.py (+76, -4)
- sample.csv (+4, -0)
app.py CHANGED
@@ -2,11 +2,14 @@ import os
 import dotenv
 import sys
 import gradio as gr
+import pandas as pd
 
 from langchain import LLMChain
 from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
 from langchain.chat_models import ChatOpenAI
 from langchain.utilities import GoogleSearchAPIWrapper
+from langchain.schema.output_parser import OutputParserException
+from googleapiclient.errors import HttpError
 
 dotenv.load_dotenv()
 
@@ -40,22 +43,91 @@ def search_and_generate(question, prefix = "次の質問にできる限り答え
     )
 
     # Prepare the agent
-    llm = ChatOpenAI(model_name="gpt-
+    llm = ChatOpenAI(model_name="gpt-3.5-turbo")
     llm_chain = LLMChain(llm=llm, prompt=prompt)
     agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
     agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
 
-    result = agent_executor.run(question)
+    try:
+        result = agent_executor.run(question)
+    except OutputParserException as e:
+        result = '例外が発生しました: OutputParserException ' + str(e.args[0])
 
     return result
 
+def search_and_generate_csv(csv_file):
+
+    output_csv_file = csv_file.replace(".csv", "_output.csv")
+    questions_df = pd.read_csv(csv_file)
+
+    # Create a DataFrame to store the results
+    results_df = pd.DataFrame(columns=["question", "answer"])
+
+    for i, row in questions_df.iterrows():
+
+        question = ""
+        if i == 0 and "question" not in row:
+            # The first row is a header, so skip it
+            continue
+        elif "question" not in row:
+            # No "question" column, so treat the first column as the question
+            question = row[0]
+        else:
+            # A "question" column exists, so use it
+            question = row["question"]
+
+        # Get the answer for the question
+        answer = search_and_generate(question)
+
+        # Append the result to the DataFrame
+        results_df.loc[i] = [question, answer]
+
+    # Save the results to a CSV file
+    result_csv = results_df.to_csv(output_csv_file, index=False)
+
+    return result_csv
+
+def process_input(str_question, file_csv):
+    # If file_csv is not empty, treat the input as a CSV file
+    try:
+        if file_csv:
+            return search_and_generate_csv(file_csv.name)
+        elif str_question:
+            return search_and_generate(str_question)
+    except HttpError as e:
+        print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content))
+        # When the reason is rateLimitExceeded
+        if "rateLimitExceeded" in e.content.decode():
+            result = "検索APIリクエストの上限に達しました。"
+        else:
+            result = "検索API実行時に通信エラーが発生しました。"
+        return result
+
 def main():
-    if len(sys.argv) > 1:
+    # When "--csv" is among the arguments and is followed by a file name
+    if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1:
+        csv_file = sys.argv[sys.argv.index("--csv") + 1]
+
+        # Process the CSV file named by the argument
+        result_csv = search_and_generate_csv(csv_file)
+
+        # Print the result of the CSV run
+        print(result_csv)
+
+    elif len(sys.argv) > 1:
         question = sys.argv[1]
         result = search_and_generate(question)
         print(result)
     else:
-        gr.Interface(fn=
+        gr.Interface(fn=process_input,
+                     inputs=[gr.Text(), gr.File()],
+                     outputs="text",
+                     examples=[
+                         ["", "sample.csv"],
+                         ["日下部民藝館に伝わる伝承は?", None],
+                         ["醍醐寺に伝わる伝承は?", None],
+                     ],
+                     ).launch()
 
 if __name__ == "__main__":
     main()
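For reference, a minimal standalone sketch of the new CSV batch flow, with the LangChain agent call replaced by a stub answer function (fake_answer and batch_answer_csv are assumptions for illustration only, not part of the commit), showing the per-row question lookup and the _output.csv naming convention:

import pandas as pd

def fake_answer(question: str) -> str:
    # Stand-in for search_and_generate(); assumed here so the sketch runs without API keys.
    return f"(answer to: {question})"

def batch_answer_csv(csv_file: str) -> str:
    # Same naming convention as search_and_generate_csv(): sample.csv -> sample_output.csv
    output_csv_file = csv_file.replace(".csv", "_output.csv")
    questions_df = pd.read_csv(csv_file)

    results_df = pd.DataFrame(columns=["question", "answer"])
    for i, row in questions_df.iterrows():
        # Prefer the "question" column; otherwise fall back to the first column.
        question = row["question"] if "question" in row else row.iloc[0]
        results_df.loc[i] = [question, fake_answer(question)]

    results_df.to_csv(output_csv_file, index=False)
    # The sketch returns the output path instead of to_csv()'s return value.
    return output_csv_file

if __name__ == "__main__":
    print(batch_answer_csv("sample.csv"))

On the command line, the committed code reaches the same path with "python app.py --csv sample.csv"; a single question can still be passed as "python app.py <question>", and with no arguments the Gradio interface launches with the examples listed above.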
sample.csv ADDED
@@ -0,0 +1,4 @@
+question, answer
+葛飾区に伝わる伝承は?50文字以内で答えてください,
+品川区に伝わる伝承は?50文字以内で答えてください,
+箱根に伝わる伝承は?50文字以内で答えてください,