hotchpotch committed on
Commit
8278a9f
·
1 Parent(s): 4e64bcf
Files changed (1)
  1. app.py +14 -2
app.py CHANGED
@@ -138,6 +138,7 @@ def to_contexts(passages):
138
 
139
 
140
  def qa(
 
141
  question: str,
142
  passages: list,
143
  model_name: str,
@@ -145,7 +146,7 @@ def qa(
145
  qa_prompt: str,
146
  max_tokens=2000,
147
  ):
148
- client = OpenAI()
149
  contexts = to_contexts(passages)
150
  prompt = qa_prompt.format(contexts=contexts, question=question)
151
  response = client.chat.completions.create(
@@ -164,16 +165,25 @@ def qa(
164
 
165
 
166
  def generate_answer(
167
- buf, question, passages, model_name, temperature, qa_prompt, max_tokens
 
 
 
 
 
 
 
168
  ):
169
  buf.write("⏳回答の生成中...")
170
  texts = ""
171
  for char in qa(
 
172
  question=question,
173
  passages=passages,
174
  model_name=model_name,
175
  temperature=temperature,
176
  qa_prompt=qa_prompt,
 
177
  ):
178
  texts += char
179
  buf.write(texts)
@@ -287,12 +297,14 @@ def app():
287
 
288
  openai_api_key = st.session_state.openai_api_key
289
  if openai_api_key:
 
290
  answer_header.subheader("Answer: ")
291
  openai_model_name = st.session_state.openai_model_name
292
  temperature = st.session_state.temperature
293
  qa_prompt = st.session_state.qa_prompt
294
  max_tokens = st.session_state.max_tokens
295
  generate_answer(
 
296
  buf=answer_text_buffer,
297
  question=question,
298
  passages=passages,
 
138
 
139
 
140
  def qa(
141
+ openai_api_key: str,
142
  question: str,
143
  passages: list,
144
  model_name: str,
 
146
  qa_prompt: str,
147
  max_tokens=2000,
148
  ):
149
+ client = OpenAI(api_key=openai_api_key)
150
  contexts = to_contexts(passages)
151
  prompt = qa_prompt.format(contexts=contexts, question=question)
152
  response = client.chat.completions.create(
 
165
 
166
 
167
  def generate_answer(
168
+ openai_api_key,
169
+ buf,
170
+ question,
171
+ passages,
172
+ model_name,
173
+ temperature,
174
+ qa_prompt,
175
+ max_tokens,
176
  ):
177
  buf.write("⏳回答の生成中...")
178
  texts = ""
179
  for char in qa(
180
+ openai_api_key=openai_api_key,
181
  question=question,
182
  passages=passages,
183
  model_name=model_name,
184
  temperature=temperature,
185
  qa_prompt=qa_prompt,
186
+ max_tokens=max_tokens,
187
  ):
188
  texts += char
189
  buf.write(texts)
 
297
 
298
  openai_api_key = st.session_state.openai_api_key
299
  if openai_api_key:
300
+ openai_api_key = openai_api_key.strip()
301
  answer_header.subheader("Answer: ")
302
  openai_model_name = st.session_state.openai_model_name
303
  temperature = st.session_state.temperature
304
  qa_prompt = st.session_state.qa_prompt
305
  max_tokens = st.session_state.max_tokens
306
  generate_answer(
307
+ openai_api_key=openai_api_key,
308
  buf=answer_text_buffer,
309
  question=question,
310
  passages=passages,