ofermend committed
Commit 2a5b875 · verified · 1 Parent(s): d0177c0

Update app.py

Files changed (1):
  app.py  +20 -29
app.py CHANGED
@@ -1,7 +1,6 @@
 from omegaconf import OmegaConf
 from query import VectaraQuery
 import os
-import json
 
 import streamlit as st
 from PIL import Image
@@ -20,33 +19,16 @@ def launch_bot():
         response = vq.submit_query_streaming(question)
         return response
 
-    def generate_and_display_response(question):
-        if cfg.streaming:
-            stream = generate_streaming_response(question)
-            response = st.write_stream(stream)
-        else:
-            with st.spinner("Thinking..."):
-                response = generate_response(question)
-            st.write(response)
-        message = {"role": "assistant", "content": response}
-        st.session_state.messages.append(message)
-
-    def submit_question(question):
-        st.session_state.messages.append({"role": "user", "content": question})
-        with st.chat_message("user"):
-            st.write(question)
-        generate_and_display_response(question)
-
     if 'cfg' not in st.session_state:
+        corpus_ids = str(os.environ['corpus_ids']).split(',')
         cfg = OmegaConf.create({
             'customer_id': str(os.environ['customer_id']),
-            'corpus_ids': [str(cid) for cid in json.loads(os.environ['corpus_ids'])],
+            'corpus_ids': corpus_ids,
             'api_key': str(os.environ['api_key']),
             'title': os.environ['title'],
             'description': os.environ['description'],
             'source_data_desc': os.environ['source_data_desc'],
             'streaming': isTrue(os.environ.get('streaming', False)),
-            'questions': list(eval(os.environ['questions'])),
             'prompt_name': os.environ.get('prompt_name', None)
         })
         st.session_state.cfg = cfg
@@ -77,9 +59,6 @@ def launch_bot():
 
     if "messages" not in st.session_state.keys():
         st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
-        for question in cfg.questions:
-            st.button(question, on_click=lambda q=question: submit_question(q))
-
 
     # Display chat messages
     for message in st.session_state.messages:
@@ -88,11 +67,23 @@ def launch_bot():
 
     # User-provided prompt
     if prompt := st.chat_input():
-        print(f"chat msg = {prompt}")
-        submit_question(prompt)
-
-
-
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.write(prompt)
+
+    # Generate a new response if last message is not from assistant
+    if st.session_state.messages[-1]["role"] != "assistant":
+        with st.chat_message("assistant"):
+            if cfg.streaming:
+                stream = generate_streaming_response(prompt)
+                response = st.write_stream(stream)
+            else:
+                with st.spinner("Thinking..."):
+                    response = generate_response(prompt)
+                st.write(response)
+            message = {"role": "assistant", "content": response}
+            st.session_state.messages.append(message)
+
 if __name__ == "__main__":
     launch_bot()
-
+
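A minimal sketch of the environment the updated config block expects, using the new comma-separated corpus_ids format; every value below is a hypothetical placeholder and is not taken from this commit:

import os

# Hypothetical placeholder values, shown only to illustrate how the updated app.py parses its config.
os.environ['customer_id'] = '1234567890'
os.environ['corpus_ids'] = '3,7'            # previously a JSON-encoded list such as "[3, 7]"
os.environ['api_key'] = 'zqt_demo_key'
os.environ['title'] = 'Demo chatbot'
os.environ['description'] = 'Chat with your data'
os.environ['source_data_desc'] = 'sample documents'
os.environ['streaming'] = 'True'

# The updated config block then splits the string into a list of corpus id strings:
corpus_ids = str(os.environ['corpus_ids']).split(',')   # -> ['3', '7']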