shigeru saito commited on
Commit
e87d7b4
·
1 Parent(s): f56a350

エラー処理追加, 実行例追加, csvの一括処理を追加

Browse files
Files changed (2) hide show
  1. app.py +76 -4
  2. sample.csv +4 -0
app.py CHANGED
@@ -2,11 +2,14 @@ import os
2
  import dotenv
3
  import sys
4
  import gradio as gr
 
5
 
6
  from langchain import LLMChain
7
  from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
8
  from langchain.chat_models import ChatOpenAI
9
  from langchain.utilities import GoogleSearchAPIWrapper
 
 
10
 
11
  dotenv.load_dotenv()
12
 
@@ -40,22 +43,91 @@ def search_and_generate(question, prefix = "次の質問にできる限り答え
40
  )
41
 
42
  # エージェントの準備
43
- llm = ChatOpenAI(model_name="gpt-4")
44
  llm_chain = LLMChain(llm=llm, prompt=prompt)
45
  agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
46
  agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
47
 
48
- result = agent_executor.run(question)
 
 
 
49
 
50
  return result
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def main():
53
- if len(sys.argv) > 1:
 
 
 
 
 
 
 
 
 
 
54
  question = sys.argv[1]
55
  result = search_and_generate(question)
56
  print(result)
57
  else:
58
- gr.Interface(fn=search_and_generate, inputs="text", outputs="text").launch()
 
 
 
 
 
 
 
 
59
 
60
  if __name__ == "__main__":
61
  main()
 
2
  import dotenv
3
  import sys
4
  import gradio as gr
5
+ import pandas as pd
6
 
7
  from langchain import LLMChain
8
  from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
9
  from langchain.chat_models import ChatOpenAI
10
  from langchain.utilities import GoogleSearchAPIWrapper
11
+ from langchain.schema.output_parser import OutputParserException
12
+ from googleapiclient.errors import HttpError
13
 
14
  dotenv.load_dotenv()
15
 
 
43
  )
44
 
45
  # エージェントの準備
46
+ llm = ChatOpenAI(model_name="gpt-3.5-turbo")
47
  llm_chain = LLMChain(llm=llm, prompt=prompt)
48
  agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
49
  agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
50
 
51
+ try:
52
+ result = agent_executor.run(question)
53
+ except OutputParserException as e:
54
+ result = '例外が発生しました: OutputParserException ' + str(e.args[0])
55
 
56
  return result
57
 
58
+ def search_and_generate_csv(csv_file):
59
+
60
+ output_csv_file = csv_file.replace(".csv", "_output.csv")
61
+ questions_df = pd.read_csv(csv_file)
62
+
63
+ # 結果を格納するDataFrameを作成
64
+ results_df = pd.DataFrame(columns=["question", "answer"])
65
+
66
+ for i, row in questions_df.iterrows():
67
+
68
+ question = ""
69
+ if i == 0 and "question" not in row:
70
+ # 1行目はヘッダーなのでquestionがある場合は1行目を無視する
71
+ continue
72
+ elif "question" not in row:
73
+ # questionがない場合は1列目をquestionとして扱う
74
+ question = row[0]
75
+ else:
76
+ # questionがある場合はquestionの列を参照する
77
+ question = row["question"]
78
+
79
+ # 質問に対する回答を取得
80
+ answer = search_and_generate(question)
81
+
82
+ # 結果をDataFrameに追加
83
+ results_df.loc[i] = [question, answer]
84
+
85
+ # 結果をCSVファイルに保存
86
+ result_csv = results_df.to_csv(output_csv_file, index=False)
87
+
88
+ return result_csv
89
+
90
+ def process_input(str_question, file_csv):
91
+ # file_csvが空でない場合はCSVファイルとして扱う
92
+ try:
93
+ if file_csv:
94
+ return search_and_generate_csv(file_csv.name)
95
+ elif str_question:
96
+ return search_and_generate(str_question)
97
+ except HttpError as e:
98
+ print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content))
99
+ # reasonがrateLimitExceededの場合
100
+ if "rateLimitExceeded" in e.content.decode():
101
+ result = "検索APIリクエストの上限に達しました。"
102
+ else:
103
+ result = "検索API実行時に通信エラーが発生しました。"
104
+ return result
105
+
106
  def main():
107
+ # パラメータに --csv がある場合、かつ、その次の引数がファイル名の場合
108
+ if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1:
109
+ csv_file = sys.argv[sys.argv.index("--csv") + 1]
110
+
111
+ # 引数のファイル名からCSVファイルを読み込む
112
+ result_csv = search_and_generate_csv(csv_file)
113
+
114
+ # csvファイルの内容をprintで表示する
115
+ print(result_csv)
116
+
117
+ elif len(sys.argv) > 1:
118
  question = sys.argv[1]
119
  result = search_and_generate(question)
120
  print(result)
121
  else:
122
+ gr.Interface(fn=process_input,
123
+ inputs=[gr.Text(), gr.File()],
124
+ outputs="text",
125
+ examples=[
126
+ ["", "sample.csv"],
127
+ ["日下部民藝館に伝わる伝承は?", None],
128
+ ["醍醐寺に伝わる伝承は?", None],
129
+ ],
130
+ ).launch()
131
 
132
  if __name__ == "__main__":
133
  main()
sample.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ question, answer
2
+ 葛飾区に伝わる伝承は?50文字以内で答えてください,
3
+ 品川区に伝わる伝承は?50文字以内で答えてください,
4
+ 箱根に伝わる伝承は?50文字以内で答えてください,