Spaces:
Running
Running
Use LangChain to get streaming response from the LLM; update progress bar to display the current status
Browse files
app.py
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
import datetime
|
2 |
import logging
|
3 |
import pathlib
|
@@ -7,9 +10,7 @@ from typing import List
|
|
7 |
|
8 |
import json5
|
9 |
import streamlit as st
|
10 |
-
from langchain_community.chat_message_histories import
|
11 |
-
StreamlitChatMessageHistory
|
12 |
-
)
|
13 |
from langchain_core.messages import HumanMessage
|
14 |
from langchain_core.prompts import ChatPromptTemplate
|
15 |
|
@@ -47,17 +48,9 @@ def _get_prompt_template(is_refinement: bool) -> str:
|
|
47 |
return template
|
48 |
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
# Get Mistral tokenizer for counting tokens.
|
54 |
-
#
|
55 |
-
# :return: The tokenizer.
|
56 |
-
# """
|
57 |
-
#
|
58 |
-
# return AutoTokenizer.from_pretrained(
|
59 |
-
# pretrained_model_name_or_path=GlobalConfig.HF_LLM_MODEL_NAME
|
60 |
-
# )
|
61 |
|
62 |
|
63 |
APP_TEXT = _load_strings()
|
@@ -66,9 +59,10 @@ APP_TEXT = _load_strings()
|
|
66 |
CHAT_MESSAGES = 'chat_messages'
|
67 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
68 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
|
|
|
|
69 |
|
70 |
logger = logging.getLogger(__name__)
|
71 |
-
progress_bar = st.progress(0, text='Setting up SlideDeck AI...')
|
72 |
|
73 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
74 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
@@ -110,7 +104,6 @@ def build_ui():
|
|
110 |
with st.expander('Usage Policies and Limitations'):
|
111 |
display_page_footer_content()
|
112 |
|
113 |
-
progress_bar.progress(50, text='Setting up chat interface...')
|
114 |
set_up_chat_ui()
|
115 |
|
116 |
|
@@ -131,8 +124,6 @@ def set_up_chat_ui():
|
|
131 |
st.chat_message('ai').write(
|
132 |
random.choice(APP_TEXT['ai_greetings'])
|
133 |
)
|
134 |
-
progress_bar.progress(100, text='Done!')
|
135 |
-
progress_bar.empty()
|
136 |
|
137 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
138 |
|
@@ -188,66 +179,51 @@ def set_up_chat_ui():
|
|
188 |
}
|
189 |
)
|
190 |
|
191 |
-
|
192 |
-
|
193 |
-
expanded=False
|
194 |
-
) as status:
|
195 |
-
response: dict = llm_helper.hf_api_query({
|
196 |
-
'inputs': formatted_template,
|
197 |
-
'parameters': {
|
198 |
-
'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
|
199 |
-
'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
|
200 |
-
'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
201 |
-
'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
|
202 |
-
'num_return_sequences': 1,
|
203 |
-
'return_full_text': False,
|
204 |
-
# "repetition_penalty": 0.0001
|
205 |
-
},
|
206 |
-
'options': {
|
207 |
-
'wait_for_model': True,
|
208 |
-
'use_cache': True
|
209 |
-
}
|
210 |
-
})
|
211 |
|
212 |
-
|
213 |
-
|
214 |
|
215 |
-
|
|
|
|
|
216 |
|
217 |
-
|
218 |
-
|
219 |
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
expanded=False
|
236 |
-
)
|
237 |
-
generate_slide_deck(response_cleaned)
|
238 |
-
status.update(label='Done!', state='complete', expanded=True)
|
239 |
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
244 |
|
245 |
|
246 |
-
def generate_slide_deck(json_str: str):
|
247 |
"""
|
248 |
-
Create a slide deck.
|
|
|
249 |
|
250 |
:param json_str: The content in *valid* JSON format.
|
|
|
251 |
"""
|
252 |
|
253 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
@@ -269,17 +245,6 @@ def generate_slide_deck(json_str: str):
|
|
269 |
output_file_path=path
|
270 |
)
|
271 |
except ValueError:
|
272 |
-
# st.error(
|
273 |
-
# f"{APP_TEXT['json_parsing_error']}"
|
274 |
-
# f"\n\nAdditional error info: {ve}"
|
275 |
-
# f"\n\nHere are some sample instructions that you could try to possibly fix this error;"
|
276 |
-
# f" if these don't work, try rephrasing or refreshing:"
|
277 |
-
# f"\n\n"
|
278 |
-
# "- Regenerate content and fix the JSON error."
|
279 |
-
# "\n- Regenerate content and fix the JSON error. Quotes inside quotes should be escaped."
|
280 |
-
# )
|
281 |
-
# logger.error('%s', APP_TEXT['json_parsing_error'])
|
282 |
-
# logger.error('Additional error info: %s', str(ve))
|
283 |
st.error(
|
284 |
'Encountered error while parsing JSON...will fix it and retry'
|
285 |
)
|
@@ -295,8 +260,8 @@ def generate_slide_deck(json_str: str):
|
|
295 |
except Exception as ex:
|
296 |
st.error(APP_TEXT['content_generation_error'])
|
297 |
logger.error('Caught a generic exception: %s', str(ex))
|
298 |
-
|
299 |
-
|
300 |
|
301 |
|
302 |
def _is_it_refinement() -> bool:
|
|
|
1 |
+
"""
|
2 |
+
Streamlit app containing the UI and the application logic.
|
3 |
+
"""
|
4 |
import datetime
|
5 |
import logging
|
6 |
import pathlib
|
|
|
10 |
|
11 |
import json5
|
12 |
import streamlit as st
|
13 |
+
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
|
|
|
|
|
14 |
from langchain_core.messages import HumanMessage
|
15 |
from langchain_core.prompts import ChatPromptTemplate
|
16 |
|
|
|
48 |
return template
|
49 |
|
50 |
|
51 |
+
@st.cache_resource
|
52 |
+
def _get_llm():
|
53 |
+
return llm_helper.get_hf_endpoint()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
|
56 |
APP_TEXT = _load_strings()
|
|
|
59 |
CHAT_MESSAGES = 'chat_messages'
|
60 |
DOWNLOAD_FILE_KEY = 'download_file_name'
|
61 |
IS_IT_REFINEMENT = 'is_it_refinement'
|
62 |
+
APPROX_TARGET_LENGTH = GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH / 2
|
63 |
+
|
64 |
|
65 |
logger = logging.getLogger(__name__)
|
|
|
66 |
|
67 |
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
|
68 |
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
|
|
|
104 |
with st.expander('Usage Policies and Limitations'):
|
105 |
display_page_footer_content()
|
106 |
|
|
|
107 |
set_up_chat_ui()
|
108 |
|
109 |
|
|
|
124 |
st.chat_message('ai').write(
|
125 |
random.choice(APP_TEXT['ai_greetings'])
|
126 |
)
|
|
|
|
|
127 |
|
128 |
history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
|
129 |
|
|
|
179 |
}
|
180 |
)
|
181 |
|
182 |
+
progress_bar = st.progress(0, 'Preparing to call LLM...')
|
183 |
+
response = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
for chunk in _get_llm().stream(formatted_template):
|
186 |
+
response += chunk
|
187 |
|
188 |
+
# Update the progress bar
|
189 |
+
progress_percentage = min(len(response) / APPROX_TARGET_LENGTH, 0.95)
|
190 |
+
progress_bar.progress(progress_percentage, text='Streaming content...')
|
191 |
|
192 |
+
history.add_user_message(prompt)
|
193 |
+
history.add_ai_message(response)
|
194 |
|
195 |
+
# The content has been generated as JSON
|
196 |
+
# There maybe trailing ``` at the end of the response -- remove them
|
197 |
+
# To be careful: ``` may be part of the content as well when code is generated
|
198 |
+
response_cleaned = text_helper.get_clean_json(response)
|
199 |
|
200 |
+
logger.info(
|
201 |
+
'Cleaned JSON response:: original length: %d | cleaned length: %d',
|
202 |
+
len(response), len(response_cleaned)
|
203 |
+
)
|
204 |
+
logger.debug('Cleaned JSON: %s', response_cleaned)
|
205 |
|
206 |
+
# Now create the PPT file
|
207 |
+
progress_bar.progress(0.95, text='Searching photos and generating the slide deck...')
|
208 |
+
path = generate_slide_deck(response_cleaned)
|
209 |
+
progress_bar.progress(1.0, text='Done!')
|
|
|
|
|
|
|
|
|
210 |
|
211 |
+
st.chat_message('ai').code(response, language='json')
|
212 |
+
_display_download_button(path)
|
213 |
+
|
214 |
+
logger.info(
|
215 |
+
'#messages in history / 2: %d',
|
216 |
+
len(st.session_state[CHAT_MESSAGES]) / 2
|
217 |
+
)
|
218 |
|
219 |
|
220 |
+
def generate_slide_deck(json_str: str) -> pathlib.Path:
|
221 |
"""
|
222 |
+
Create a slide deck and return the file path. In case there is any error creating the slide
|
223 |
+
deck, the path may be to an empty file.
|
224 |
|
225 |
:param json_str: The content in *valid* JSON format.
|
226 |
+
:return: The file of the .pptx file.
|
227 |
"""
|
228 |
|
229 |
if DOWNLOAD_FILE_KEY in st.session_state:
|
|
|
245 |
output_file_path=path
|
246 |
)
|
247 |
except ValueError:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
st.error(
|
249 |
'Encountered error while parsing JSON...will fix it and retry'
|
250 |
)
|
|
|
260 |
except Exception as ex:
|
261 |
st.error(APP_TEXT['content_generation_error'])
|
262 |
logger.error('Caught a generic exception: %s', str(ex))
|
263 |
+
|
264 |
+
return path
|
265 |
|
266 |
|
267 |
def _is_it_refinement() -> bool:
|