Spaces:
Running
Running
""" | |
Streamlit app containing the UI and the application logic. | |
""" | |
import datetime | |
import logging | |
import pathlib | |
import random | |
import tempfile | |
from typing import List, Union | |
import huggingface_hub | |
import json5 | |
import requests | |
import streamlit as st | |
from langchain_community.chat_message_histories import StreamlitChatMessageHistory | |
from langchain_core.messages import HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate | |
from global_config import GlobalConfig | |
from helpers import llm_helper, pptx_helper, text_helper | |
def _load_strings() -> dict:
    """
    Read the strings shown in the app from the configured strings file.

    :return: The parsed content of the strings file as a dictionary.
    """

    strings_file = GlobalConfig.APP_STRINGS_FILE

    with open(strings_file, 'r', encoding='utf-8') as strings_fp:
        raw_text = strings_fp.read()

    return json5.loads(raw_text)
def _get_prompt_template(is_refinement: bool) -> str:
    """
    Read the prompt template matching the current stage of the conversation.

    :param is_refinement: `True` to load the refinement template; `False` to load the
     initial template.
    :return: The full text of the selected template file.
    """

    # Pick the file first so that the read logic exists only once
    template_path = (
        GlobalConfig.REFINEMENT_PROMPT_TEMPLATE
        if is_refinement
        else GlobalConfig.INITIAL_PROMPT_TEMPLATE
    )

    with open(template_path, 'r', encoding='utf-8') as template_fp:
        return template_fp.read()
def are_all_inputs_valid(
        user_prompt: str,
        selected_provider: str,
        selected_model: str,
        user_key: str,
) -> bool:
    """
    Check the user's prompt and the LLM selection; display an error for the first failed check.

    The checks run in order: the prompt, then the presence of provider and model names, then
    the provider/model/key combination. Later checks are skipped once one fails.

    :param user_prompt: The prompt typed by the user.
    :param selected_provider: The LLM provider.
    :param selected_model: Name of the model.
    :param user_key: User-provided API key.
    :return: `True` if every check passes; `False` otherwise.
    """

    problem = None

    if not text_helper.is_valid_prompt(user_prompt):
        problem = (
            'Not enough information provided!'
            ' Please be a little more descriptive and type a few words'
            ' with a few characters :)'
        )
    elif not (selected_provider and selected_model):
        problem = 'No valid LLM provider and/or model name found!'
    elif not llm_helper.is_valid_llm_provider_model(
            selected_provider, selected_model, user_key
    ):
        problem = (
            'The LLM settings do not look correct. Make sure that an API key/access token'
            ' is provided if the selected LLM requires it. An API key should be 6-64 characters'
            ' long, only containing alphanumeric characters, hyphens, and underscores.'
        )

    if problem is None:
        return True

    # Validation failures are shown to the user but not written to the log
    handle_error(problem, False)
    return False
def handle_error(error_msg: str, should_log: bool):
    """
    Show an error message in the Streamlit UI, optionally writing it to the log first.

    :param error_msg: The text of the error.
    :param should_log: When truthy, the message is also logged at the error level.
    """

    if should_log:
        # Log before rendering so the message is recorded ahead of the UI call
        logger.error(error_msg)

    st.error(error_msg)
# The UI strings are loaded once, when the module is executed
APP_TEXT = _load_strings()

# Keys under which values are kept in Streamlit's session state
CHAT_MESSAGES = 'chat_messages'
DOWNLOAD_FILE_KEY = 'download_file_name'
IS_IT_REFINEMENT = 'is_it_refinement'

logger = logging.getLogger(__name__)

# Template names and their captions, in matching order, for the radio widget below
texts = list(GlobalConfig.PPTX_TEMPLATE_FILES)
captions = [GlobalConfig.PPTX_TEMPLATE_FILES[name]['caption'] for name in texts]

with st.sidebar:
    # Step 1: the presentation template
    pptx_template = st.sidebar.radio(
        '1: Select a presentation template:',
        texts,
        captions=captions,
        horizontal=True,
    )

    # Step 2: the LLM; every option is rendered as "<key> (<description>)"
    llm_options = [
        f'{name} ({details["description"]})'
        for name, details in GlobalConfig.VALID_MODELS.items()
    ]
    selected_llm_option = st.sidebar.selectbox(
        label='2: Select an LLM to use:',
        options=llm_options,
        index=GlobalConfig.DEFAULT_MODEL_INDEX,
        help=GlobalConfig.LLM_PROVIDER_HELP,
    )
    # Keep only the text before the first space
    # (presumably the key in `VALID_MODELS`, assuming keys contain no spaces -- TODO confirm)
    llm_provider_to_use = selected_llm_option.split(' ')[0]

    # Step 3: the API key/access token, masked as a password field
    api_key_token = st.text_input(
        label=(
            '3: Paste your API key/access token:\n\n'
            '*Mandatory* for Cohere and Gemini LLMs.'
            ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
        ),
        type='password',
    )
def build_ui():
    """
    Render the page header, the usage policies, and then the chat interface.
    """

    st.title(APP_TEXT['app_name'])
    st.subheader(APP_TEXT['caption'])

    # Visitor counter badge, rendered as a Markdown image
    visitors_badge = (
        '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)'  # noqa: E501
    )
    st.markdown(visitors_badge)

    # Both parts of the terms are shown together, separated by a blank line
    terms_of_use = '\n\n'.join((APP_TEXT['tos'], APP_TEXT['tos2']))
    with st.expander('Usage Policies and Limitations'):
        st.text(terms_of_use)

    set_up_chat_ui()
def set_up_chat_ui():
    """
    Prepare the chat interface and related functionality.

    Displays the usage instructions, a greeting, and the chat history kept in the session
    state. When the user submits a prompt, the function validates the inputs, fills in the
    initial or the refinement prompt template, streams the response from the selected LLM,
    stores both messages in the history, and finally generates the slide deck from the cleaned
    JSON and offers it for download. Any failure is reported via `handle_error()` and ends the
    current run early.
    """

    with st.expander('Usage Instructions'):
        st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)

    st.info(APP_TEXT['like_feedback'])
    st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))

    history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
    prompt_template = ChatPromptTemplate.from_template(
        _get_prompt_template(is_refinement=_is_it_refinement())
    )

    # Since Streamlit app reloads at every interaction, display the chat history
    # from the saved session state
    for msg in history.messages:
        st.chat_message(msg.type).code(msg.content, language='json')

    prompt = st.chat_input(
        placeholder=APP_TEXT['chat_placeholder'],
        max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
    )
    if not prompt:
        # Nothing has been submitted in this run
        return

    provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
    # Strip once, so that the key is validated in exactly the form in which it is later
    # passed to the LLM (a key pasted with surrounding whitespace should not be rejected)
    api_key = api_key_token.strip()

    if not are_all_inputs_valid(prompt, provider, llm_name, api_key):
        return

    logger.info(
        'User input: %s | #characters: %d | LLM: %s',
        prompt, len(prompt), llm_name
    )
    st.chat_message('user').write(prompt)

    if _is_it_refinement():
        # A refinement sends all the user instructions so far, numbered from 1,
        # together with the previously generated content
        user_messages = _get_user_messages()
        user_messages.append(prompt)
        list_of_msgs = [
            f'{idx}. {msg}' for idx, msg in enumerate(user_messages, start=1)
        ]
        formatted_template = prompt_template.format(
            instructions='\n'.join(list_of_msgs),
            previous_content=_get_last_response(),
        )
    else:
        formatted_template = prompt_template.format(question=prompt)

    progress_bar = st.progress(0, 'Preparing to call LLM...')
    response = ''

    try:
        # Looked up inside the `try` so that a bad configuration is reported like any
        # other unexpected error; read once instead of at every streamed chunk
        max_new_tokens = GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens']
        llm = llm_helper.get_langchain_llm(
            provider=provider,
            model=llm_name,
            max_new_tokens=max_new_tokens,
            api_key=api_key,
        )

        if not llm:
            handle_error(
                'Failed to create an LLM instance! Make sure that you have selected the'
                ' correct model from the dropdown list and have provided correct API key'
                ' or access token.',
                False
            )
            return

        for chunk in llm.stream(formatted_template):
            response += chunk

            # Update the progress bar with an approx progress percentage; it is capped
            # below 100% because the final length of the response is not known
            progress_bar.progress(
                min(len(response) / max_new_tokens, 0.95),
                text='Streaming content...this might take a while...'
            )
    except requests.exceptions.ConnectionError:
        handle_error(
            'A connection error occurred while streaming content from the LLM endpoint.'
            ' Unfortunately, the slide deck cannot be generated. Please try again later.'
            ' Alternatively, try selecting a different LLM from the dropdown list.',
            True
        )
        return
    except huggingface_hub.errors.ValidationError as ve:
        handle_error(
            f'An error occurred while trying to generate the content: {ve}'
            '\nPlease try again with a significantly shorter input text.',
            True
        )
        return
    except Exception as ex:
        handle_error(
            f'An unexpected error occurred while generating the content: {ex}'
            '\nPlease try again later, possibly with different inputs.'
            ' Alternatively, try selecting a different LLM from the dropdown list.'
            ' If you are using Cohere or Gemini models, make sure that you have provided'
            ' a correct API key.',
            True
        )
        return

    # The history keeps the response as received, i.e., before the clean-up below
    history.add_user_message(prompt)
    history.add_ai_message(response)

    # The content has been generated as JSON
    # There maybe trailing ``` at the end of the response -- remove them
    # To be careful: ``` may be part of the content as well when code is generated
    response = text_helper.get_clean_json(response)
    logger.info('Cleaned JSON length: %d', len(response))

    # Now create the PPT file
    progress_bar.progress(
        GlobalConfig.LLM_PROGRESS_MAX,
        text='Finding photos online and generating the slide deck...'
    )
    st.chat_message('ai').code(response, language='json')

    path = generate_slide_deck(response)
    # Report completion only after the slide deck generation has actually returned
    progress_bar.progress(1.0, text='Done!')

    if path:
        _display_download_button(path)

    logger.info(
        '#messages in history / 2: %d',
        len(st.session_state[CHAT_MESSAGES]) // 2
    )
def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
    """
    Create a slide deck and return the file path. In case there is any error creating the slide
    deck, the path may be to an empty file.

    The JSON is parsed first; if that fails with a `ValueError`, one more attempt is made after
    trying to repair the malformed JSON. The output file path is created once and then reused
    from the session state in later calls.

    :param json_str: The content in *valid* JSON format.
    :return: The path to the .pptx file or `None` in case the JSON could not be parsed.
    """

    # The outer `try` covers the retry as well, so that any error raised by the second
    # parsing attempt is also reported to the user instead of propagating
    try:
        try:
            parsed_data = json5.loads(json_str)
        except ValueError:
            handle_error(
                'Encountered error while parsing JSON...will fix it and retry',
                True
            )
            parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
    except ValueError:
        handle_error(
            'Encountered an error again while fixing JSON...'
            'the slide deck cannot be created, unfortunately ☹'
            '\nPlease try again later.',
            True
        )
        return None
    except RecursionError:
        handle_error(
            'Encountered a recursion error while parsing JSON...'
            'the slide deck cannot be created, unfortunately ☹'
            '\nPlease try again later.',
            True
        )
        return None
    except Exception:
        handle_error(
            'Encountered an error while parsing JSON...'
            'the slide deck cannot be created, unfortunately ☹'
            '\nPlease try again later.',
            True
        )
        return None

    if DOWNLOAD_FILE_KEY in st.session_state:
        path = pathlib.Path(st.session_state[DOWNLOAD_FILE_KEY])
    else:
        # Only the name of the file is needed here; `delete=False` keeps the file on disk
        # after the handle is closed on leaving the `with` block
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pptx') as temp:
            path = pathlib.Path(temp.name)

        st.session_state[DOWNLOAD_FILE_KEY] = str(path)

    try:
        logger.debug('Creating PPTX file: %s...', st.session_state[DOWNLOAD_FILE_KEY])
        pptx_helper.generate_powerpoint_presentation(
            parsed_data,
            slides_template=pptx_template,
            output_file_path=path
        )
    except Exception as ex:
        st.error(APP_TEXT['content_generation_error'])
        # `exception()` logs at the error level and appends the traceback
        logger.exception('Caught a generic exception: %s', str(ex))

    # The path is returned even if the generation failed (see the docstring)
    return path
def _is_it_refinement() -> bool:
    """
    Tell whether the current prompt is a refinement or the initial prompt.

    Once the chat history holds at least two messages, a flag is stored in the session state
    so that all later calls return `True` right away.

    :return: `True` if it is a refinement; `False` if it is the initial prompt.
    """

    if IS_IT_REFINEMENT in st.session_state:
        return True

    # A history that has not been created yet is treated as empty
    if len(st.session_state.get(CHAT_MESSAGES, [])) >= 2:
        # Prepare for the next call
        st.session_state[IS_IT_REFINEMENT] = True
        return True

    return False
def _get_user_messages() -> List[str]:
    """
    Collect the content of every human message stored in the chat history so far.

    :return: The message texts, in the order in which they appear in the history.
    """

    contents = []

    for message in st.session_state[CHAT_MESSAGES]:
        # AI responses share the same history; keep only what the user typed
        if isinstance(message, HumanMessage):
            contents.append(message.content)

    return contents
def _get_last_response() -> str:
    """
    Return the content of the most recent message in the chat history.

    :return: The content of the last message stored in the session state.
    """

    all_messages = st.session_state[CHAT_MESSAGES]
    latest_message = all_messages[-1]

    return latest_message.content
def _display_messages_history(view_messages: st.expander):
    """
    Render the chat messages stored in the session state as JSON.

    :param view_messages: The container (e.g., an expander) in which the messages are shown.
    """

    with view_messages:
        stored_messages = st.session_state[CHAT_MESSAGES]
        view_messages.json(stored_messages)
def _display_download_button(file_path: pathlib.Path):
    """
    Show a button that lets the user download the generated slide deck.

    :param file_path: The path of the .pptx file to be offered for download.
    """

    # The current timestamp is used as the widget key, so each button gets a different key
    button_key = datetime.datetime.now()

    with open(file_path, 'rb') as pptx_fp:
        st.download_button(
            'Download PPTX file ⬇️',
            data=pptx_fp,
            file_name='Presentation.pptx',
            key=button_key
        )
def main():
    """
    Entry point of the app: build the user interface.
    """

    build_ui()


# Run the app only when this file is executed as a script
if __name__ == '__main__':
    main()