RamAnanth1 commited on
Commit
30e5c47
·
1 Parent(s): 00edfbb

Add API as argument

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -7,8 +7,6 @@ from transformers import pipeline
7
  import torch
8
 
9
  session_token = os.environ.get('SessionToken')
10
- # logger.info(f"session_token_: {session_token}")
11
- api = ChatGPT(session_token)
12
 
13
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
 
@@ -23,28 +21,37 @@ all_special_ids = whisper_model.tokenizer.all_special_ids
23
  transcribe_token_id = all_special_ids[-5]
24
  translate_token_id = all_special_ids[-6]
25
 
 
 
 
 
 
 
 
 
 
26
  def translate_or_transcribe(audio, task):
27
  whisper_model.model.config.forced_decoder_ids = [[2, transcribe_token_id if task=="Transcribe in Spoken Language" else translate_token_id]]
28
  text = whisper_model(audio)["text"]
29
  return text
30
 
31
- def get_response_from_chatbot(text):
 
 
32
  try:
33
- if reset_conversation:
34
- api.refresh_auth()
35
- api.reset_conversation()
36
- resp = api.send_message(text)
37
- response = resp['message']
38
- # logger.info(f"response_: {response}")
39
  except:
40
- response = "Sorry, chatGPT queue is full. Please try again in some time"
41
  return response
42
 
43
- def chat(message, chat_history):
44
  out_chat = []
45
  if chat_history != '':
46
  out_chat = json.loads(chat_history)
47
- response = get_response_from_chatbot(message)
48
  out_chat.append((message, response))
49
  chat_history = json.dumps(out_chat)
50
  logger.info(f"out_chat_: {len(out_chat)}")
@@ -159,8 +166,7 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
159
 
160
  )
161
  translate_btn = gr.Button("Check Whisper first ? 👍")
162
-
163
- reset_conversation = gr.Checkbox(label="Reset conversation?", value=False)
164
  whisper_task = gr.Radio(["Translate to English", "Transcribe in Spoken Language"], value="Translate to English", show_label=False)
165
  with gr.Row(elem_id="prompt_row"):
166
  prompt_input = gr.Textbox(lines=2, label="Input text",show_label=True)
@@ -177,10 +183,11 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
177
  inputs=[prompt_input_audio,whisper_task],
178
  outputs=prompt_input
179
  )
180
-
 
181
  submit_btn.click(fn=chat,
182
- inputs=[prompt_input, chat_history],
183
- outputs=[chatbot, chat_history],
184
  )
185
  gr.HTML('''
186
  <p>Note: Please be aware that audio records from iOS devices will not be decoded as expected by Gradio. For the best experience, record your voice from a computer instead of your smartphone ;)</p>
 
7
  import torch
8
 
9
  session_token = os.environ.get('SessionToken')
 
 
10
 
11
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12
 
 
21
  transcribe_token_id = all_special_ids[-5]
22
  translate_token_id = all_special_ids[-6]
23
 
24
+ def get_api():
25
+ api = None
26
+ # try:
27
+ # api = ChatGPT(session_token)
28
+ # # api.refresh_auth()
29
+ # except:
30
+ # api = None
31
+ return api
32
+
33
  def translate_or_transcribe(audio, task):
34
  whisper_model.model.config.forced_decoder_ids = [[2, transcribe_token_id if task=="Transcribe in Spoken Language" else translate_token_id]]
35
  text = whisper_model(audio)["text"]
36
  return text
37
 
38
+ def get_response_from_chatbot(api,text):
39
+ if api is None:
40
+ return "Sorry, the chatGPT API has some issues. Please try again later"
41
  try:
42
+ resp = api.send_message(text)
43
+ api.refresh_auth()
44
+ # api.reset_conversation()
45
+ response = resp['message']
 
 
46
  except:
47
+ response = "Sorry, the chatGPT queue is full. Please try again later"
48
  return response
49
 
50
+ def chat(api,message, chat_history):
51
  out_chat = []
52
  if chat_history != '':
53
  out_chat = json.loads(chat_history)
54
+ response = get_response_from_chatbot(api,message)
55
  out_chat.append((message, response))
56
  chat_history = json.dumps(out_chat)
57
  logger.info(f"out_chat_: {len(out_chat)}")
 
166
 
167
  )
168
  translate_btn = gr.Button("Check Whisper first ? 👍")
169
+
 
170
  whisper_task = gr.Radio(["Translate to English", "Transcribe in Spoken Language"], value="Translate to English", show_label=False)
171
  with gr.Row(elem_id="prompt_row"):
172
  prompt_input = gr.Textbox(lines=2, label="Input text",show_label=True)
 
183
  inputs=[prompt_input_audio,whisper_task],
184
  outputs=prompt_input
185
  )
186
+
187
+ api = gr.State(value=get_api())
188
  submit_btn.click(fn=chat,
189
+ inputs=[api,prompt_input, chat_history],
190
+ outputs=[api,chatbot, chat_history],
191
  )
192
  gr.HTML('''
193
  <p>Note: Please be aware that audio records from iOS devices will not be decoded as expected by Gradio. For the best experience, record your voice from a computer instead of your smartphone ;)</p>