""" Streamlit app containing the UI and the application logic. """ import datetime import logging import pathlib import random import tempfile from typing import List, Union import huggingface_hub import json5 import requests import streamlit as st from langchain_community.chat_message_histories import StreamlitChatMessageHistory from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate from global_config import GlobalConfig from helpers import llm_helper, pptx_helper, text_helper @st.cache_data def _load_strings() -> dict: """ Load various strings to be displayed in the app. :return: The dictionary of strings. """ with open(GlobalConfig.APP_STRINGS_FILE, 'r', encoding='utf-8') as in_file: return json5.loads(in_file.read()) @st.cache_data def _get_prompt_template(is_refinement: bool) -> str: """ Return a prompt template. :param is_refinement: Whether this is the initial or refinement prompt. :return: The prompt template as f-string. """ if is_refinement: with open(GlobalConfig.REFINEMENT_PROMPT_TEMPLATE, 'r', encoding='utf-8') as in_file: template = in_file.read() else: with open(GlobalConfig.INITIAL_PROMPT_TEMPLATE, 'r', encoding='utf-8') as in_file: template = in_file.read() return template def are_all_inputs_valid( user_prompt: str, selected_provider: str, selected_model: str, user_key: str, ) -> bool: """ Validate user input and LLM selection. :param user_prompt: The prompt. :param selected_provider: The LLM provider. :param selected_model: Name of the model. :param user_key: User-provided API key. :return: `True` if all inputs "look" OK; `False` otherwise. """ if not text_helper.is_valid_prompt(user_prompt): handle_error( 'Not enough information provided!' ' Please be a little more descriptive and type a few words' ' with a few characters :)', False ) return False if not selected_provider or not selected_model: handle_error('No valid LLM provider and/or model name found!', False) return False if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key): handle_error( 'The LLM settings do not look correct. Make sure that an API key/access token' ' is provided if the selected LLM requires it. An API key should be 6-64 characters' ' long, only containing alphanumeric characters, hyphens, and underscores.', False ) return False return True def handle_error(error_msg: str, should_log: bool): """ Display an error message in the app. :param error_msg: The error message to be displayed. :param should_log: If `True`, log the message. """ if should_log: logger.error(error_msg) st.error(error_msg) def reset_api_key(): """ Clear API key input when a different LLM is selected from the dropdown list. """ st.session_state.api_key_input = '' APP_TEXT = _load_strings() # Session variables CHAT_MESSAGES = 'chat_messages' DOWNLOAD_FILE_KEY = 'download_file_name' IS_IT_REFINEMENT = 'is_it_refinement' logger = logging.getLogger(__name__) texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys()) captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts] with st.sidebar: # The PPT templates pptx_template = st.sidebar.radio( '1: Select a presentation template:', texts, captions=captions, horizontal=True ) # The LLMs llm_provider_to_use = st.sidebar.selectbox( label='2: Select an LLM to use:', options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()], index=GlobalConfig.DEFAULT_MODEL_INDEX, help=GlobalConfig.LLM_PROVIDER_HELP, on_change=reset_api_key ).split(' ')[0] # The API key/access token api_key_token = st.text_input( label=( '3: Paste your API key/access token:\n\n' '*Mandatory* for Cohere and Gemini LLMs.' ' *Optional* for HF Mistral LLMs but still encouraged.\n\n' ), type='password', key='api_key_input' ) def build_ui(): """ Display the input elements for content generation. """ st.title(APP_TEXT['app_name']) st.subheader(APP_TEXT['caption']) st.markdown( '![Visitors](https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fbarunsaha%2Fslide-deck-ai&countColor=%23263759)' # noqa: E501 ) with st.expander('Usage Policies and Limitations'): st.text(APP_TEXT['tos'] + '\n\n' + APP_TEXT['tos2']) set_up_chat_ui() def set_up_chat_ui(): """ Prepare the chat interface and related functionality. """ with st.expander('Usage Instructions'): st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS) st.info(APP_TEXT['like_feedback']) st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings'])) history = StreamlitChatMessageHistory(key=CHAT_MESSAGES) prompt_template = ChatPromptTemplate.from_template( _get_prompt_template( is_refinement=_is_it_refinement() ) ) # Since Streamlit app reloads at every interaction, display the chat history # from the save session state for msg in history.messages: st.chat_message(msg.type).code(msg.content, language='json') if prompt := st.chat_input( placeholder=APP_TEXT['chat_placeholder'], max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH ): provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use) if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token): return logger.info( 'User input: %s | #characters: %d | LLM: %s', prompt, len(prompt), llm_name ) st.chat_message('user').write(prompt) if _is_it_refinement(): user_messages = _get_user_messages() user_messages.append(prompt) list_of_msgs = [ f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages) ] formatted_template = prompt_template.format( **{ 'instructions': '\n'.join(list_of_msgs), 'previous_content': _get_last_response(), } ) else: formatted_template = prompt_template.format(**{'question': prompt}) progress_bar = st.progress(0, 'Preparing to call LLM...') response = '' try: llm = llm_helper.get_langchain_llm( provider=provider, model=llm_name, max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'], api_key=api_key_token.strip(), ) if not llm: handle_error( 'Failed to create an LLM instance! Make sure that you have selected the' ' correct model from the dropdown list and have provided correct API key' ' or access token.', False ) return for _ in llm.stream(formatted_template): response += _ # Update the progress bar with an approx progress percentage progress_bar.progress( min( len(response) / GlobalConfig.VALID_MODELS[ llm_provider_to_use ]['max_new_tokens'], 0.95 ), text='Streaming content...this might take a while...' ) except requests.exceptions.ConnectionError: handle_error( 'A connection error occurred while streaming content from the LLM endpoint.' ' Unfortunately, the slide deck cannot be generated. Please try again later.' ' Alternatively, try selecting a different LLM from the dropdown list.', True ) return except huggingface_hub.errors.ValidationError as ve: handle_error( f'An error occurred while trying to generate the content: {ve}' '\nPlease try again with a significantly shorter input text.', True ) return except Exception as ex: handle_error( f'An unexpected error occurred while generating the content: {ex}' '\nPlease try again later, possibly with different inputs.' ' Alternatively, try selecting a different LLM from the dropdown list.' ' If you are using Cohere or Gemini models, make sure that you have provided' ' a correct API key.', True ) return history.add_user_message(prompt) history.add_ai_message(response) # The content has been generated as JSON # There maybe trailing ``` at the end of the response -- remove them # To be careful: ``` may be part of the content as well when code is generated response = text_helper.get_clean_json(response) logger.info( 'Cleaned JSON length: %d', len(response) ) # Now create the PPT file progress_bar.progress( GlobalConfig.LLM_PROGRESS_MAX, text='Finding photos online and generating the slide deck...' ) progress_bar.progress(1.0, text='Done!') st.chat_message('ai').code(response, language='json') if path := generate_slide_deck(response): _display_download_button(path) logger.info( '#messages in history / 2: %d', len(st.session_state[CHAT_MESSAGES]) / 2 ) def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]: """ Create a slide deck and return the file path. In case there is any error creating the slide deck, the path may be to an empty file. :param json_str: The content in *valid* JSON format. :return: The path to the .pptx file or `None` in case of error. """ try: parsed_data = json5.loads(json_str) except ValueError: handle_error( 'Encountered error while parsing JSON...will fix it and retry', True ) try: parsed_data = json5.loads(text_helper.fix_malformed_json(json_str)) except ValueError: handle_error( 'Encountered an error again while fixing JSON...' 'the slide deck cannot be created, unfortunately ☹' '\nPlease try again later.', True ) return None except RecursionError: handle_error( 'Encountered a recursion error while parsing JSON...' 'the slide deck cannot be created, unfortunately ☹' '\nPlease try again later.', True ) return None except Exception: handle_error( 'Encountered an error while parsing JSON...' 'the slide deck cannot be created, unfortunately ☹' '\nPlease try again later.', True ) return None if DOWNLOAD_FILE_KEY in st.session_state: path = pathlib.Path(st.session_state[DOWNLOAD_FILE_KEY]) else: temp = tempfile.NamedTemporaryFile(delete=False, suffix='.pptx') path = pathlib.Path(temp.name) st.session_state[DOWNLOAD_FILE_KEY] = str(path) if temp: temp.close() try: logger.debug('Creating PPTX file: %s...', st.session_state[DOWNLOAD_FILE_KEY]) pptx_helper.generate_powerpoint_presentation( parsed_data, slides_template=pptx_template, output_file_path=path ) except Exception as ex: st.error(APP_TEXT['content_generation_error']) logger.error('Caught a generic exception: %s', str(ex)) return path def _is_it_refinement() -> bool: """ Whether it is the initial prompt or a refinement. :return: True if it is the initial prompt; False otherwise. """ if IS_IT_REFINEMENT in st.session_state: return True if len(st.session_state[CHAT_MESSAGES]) >= 2: # Prepare for the next call st.session_state[IS_IT_REFINEMENT] = True return True return False def _get_user_messages() -> List[str]: """ Get a list of user messages submitted until now from the session state. :return: The list of user messages. """ return [ msg.content for msg in st.session_state[CHAT_MESSAGES] if isinstance(msg, HumanMessage) ] def _get_last_response() -> str: """ Get the last response generated by AI. :return: The response text. """ return st.session_state[CHAT_MESSAGES][-1].content def _display_messages_history(view_messages: st.expander): """ Display the history of messages. :param view_messages: The list of AI and Human messages. """ with view_messages: view_messages.json(st.session_state[CHAT_MESSAGES]) def _display_download_button(file_path: pathlib.Path): """ Display a download button to download a slide deck. :param file_path: The path of the .pptx file. """ with open(file_path, 'rb') as download_file: st.download_button( 'Download PPTX file ⬇️', data=download_file, file_name='Presentation.pptx', key=datetime.datetime.now() ) def main(): """ Trigger application run. """ build_ui() if __name__ == '__main__': main()