from typing import List, Optional import gradio as gr from langchain_core.vectorstores import VectorStore from config import ( LLM_MODEL_REPOS, EMBED_MODEL_REPOS, SUBTITLES_LANGUAGES, GENERATE_KWARGS, ) from utils import ( load_llm_model, load_embed_model, load_documents_and_create_db, user_message_to_chatbot, update_user_message_with_context, get_llm_response, get_gguf_model_names, add_new_model_repo, clear_llm_folder, clear_embed_folder, get_memory_usage, ) # ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============ def get_rag_settings(rag_mode: bool, render: bool = True): k = gr.Radio( choices=[1, 2, 3, 4, 5, 'all'], value=2, label='Number of relevant documents for search', visible=rag_mode, render=render, ) score_threshold = gr.Slider( minimum=0, maximum=1, value=0.5, step=0.05, label='relevance_scores_threshold', visible=rag_mode, render=render, ) return k, score_threshold def get_user_message_with_context(text: str, rag_mode: bool) -> gr.component: num_lines = len(text.split('\n')) max_lines = 10 num_lines = max_lines if num_lines > max_lines else num_lines return gr.Textbox( text, visible=rag_mode, interactive=False, label='User Message With Context', lines=num_lines, ) def get_system_prompt_component(interactive: bool) -> gr.Textbox: value = '' if interactive else 'System prompt is not supported by this model' return gr.Textbox(value=value, label='System prompt', interactive=interactive) def get_generate_args(do_sample: bool) -> List[gr.component]: generate_args = [ gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample), gr.Slider(minimum=0.1, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample), gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample), gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample), ] return generate_args def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox: value = visible = db is not None return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible) # ================ LOADING AND INITIALIZING MODELS ======================== start_llm_model, start_support_system_role, load_log = load_llm_model(LLM_MODEL_REPOS[0], 'gemma-2-2b-it-Q8_0.gguf') start_embed_model, load_log = load_embed_model(EMBED_MODEL_REPOS[0]) # ================== APPLICATION WEB INTERFACE ============================ theme = gr.themes.Base(primary_hue='green', secondary_hue='yellow', neutral_hue='zinc').set( loader_color='rgb(0, 255, 0)', slider_color='rgb(0, 200, 0)', body_text_color_dark='rgb(0, 200, 0)', button_secondary_background_fill_dark='green', ) css = '''.gradio-container {width: 60% !important}''' with gr.Blocks(theme=theme, css=css) as interface: # ==================== GRADIO STATES =============================== documents = gr.State([]) db = gr.State(None) user_message_with_context = gr.State('') support_system_role = gr.State(start_support_system_role) llm_model_repos = gr.State(LLM_MODEL_REPOS) embed_model_repos = gr.State(EMBED_MODEL_REPOS) llm_model = gr.State(start_llm_model) embed_model = gr.State(start_embed_model) # ==================== BOT PAGE ================================= with gr.Tab(label='Chatbot'): with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot( show_copy_button=True, bubble_full_width=False, height=480, ) user_message = gr.Textbox(label='User') with gr.Row(): user_message_btn = gr.Button('Send') stop_btn = gr.Button('Stop') clear_btn = gr.Button('Clear') # ------------- GENERATION PARAMETERS ------------------- with gr.Column(scale=1, min_width=80): with gr.Group(): gr.Markdown('History size') history_len = gr.Slider( minimum=0, maximum=5, value=0, step=1, info='Number of previous messages taken into account in history', label='history_len', show_label=False, ) with gr.Group(): gr.Markdown('Generation parameters') do_sample = gr.Checkbox( value=False, label='do_sample', info='Activate random sampling', ) generate_args = get_generate_args(do_sample.value) do_sample.change( fn=get_generate_args, inputs=do_sample, outputs=generate_args, show_progress=False, ) rag_mode = get_rag_mode_component(db=db.value) k, score_threshold = get_rag_settings(rag_mode=rag_mode.value, render=False) rag_mode.change( fn=get_rag_settings, inputs=[rag_mode], outputs=[k, score_threshold], ) with gr.Row(): k.render() score_threshold.render() # ---------------- SYSTEM PROMPT AND USER MESSAGE ----------- with gr.Accordion('Prompt', open=True): system_prompt = get_system_prompt_component(interactive=support_system_role.value) user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value) # ---------------- SEND, CLEAR AND STOP BUTTONS ------------ generate_event = gr.on( triggers=[user_message.submit, user_message_btn.click], fn=user_message_to_chatbot, inputs=[user_message, chatbot], outputs=[user_message, chatbot], queue=False, ).then( fn=update_user_message_with_context, inputs=[chatbot, rag_mode, db, k, score_threshold], outputs=[user_message_with_context], ).then( fn=get_user_message_with_context, inputs=[user_message_with_context, rag_mode], outputs=[user_message_with_context], ).then( fn=get_llm_response, inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt, support_system_role, history_len, do_sample, *generate_args], outputs=[chatbot], ) stop_btn.click( fn=None, inputs=None, outputs=None, cancels=generate_event, queue=False, ) clear_btn.click( fn=lambda: (None, ''), inputs=None, outputs=[chatbot, user_message_with_context], queue=False, ) # ================= FILE DOWNLOAD PAGE ========================= with gr.Tab(label='Load documents'): with gr.Row(variant='compact'): upload_files = gr.File(file_count='multiple', label='Loading text files') web_links = gr.Textbox(lines=6, label='Links to Web sites or YouTube') with gr.Row(variant='compact'): chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size') chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap') subtitles_lang = gr.Radio( SUBTITLES_LANGUAGES, value=SUBTITLES_LANGUAGES[0], label='YouTube subtitle language', ) load_documents_btn = gr.Button(value='Upload documents and initialize database') load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False) load_documents_btn.click( fn=load_documents_and_create_db, inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model], outputs=[documents, db, load_docs_log], ).success( fn=get_rag_mode_component, inputs=[db], outputs=[rag_mode], ) gr.HTML("""

GitHub Repository

""") # ================= VIEW PAGE FOR ALL DOCUMENTS ================= with gr.Tab(label='View documents'): view_documents_btn = gr.Button(value='Show downloaded text chunks') view_documents_textbox = gr.Textbox( lines=1, placeholder='To view chunks, load documents in the Load documents tab', label='Uploaded chunks', ) sep = '=' * 20 view_documents_btn.click( lambda documents: f'\n{sep}\n\n'.join([doc.page_content for doc in documents]), inputs=[documents], outputs=[view_documents_textbox], ) # ============== GGUF MODELS DOWNLOAD PAGE ===================== with gr.Tab('Load LLM model'): new_llm_model_repo = gr.Textbox( value='', label='Add repository', placeholder='Link to repository of HF models in GGUF format', ) new_llm_model_repo_btn = gr.Button('Add repository') curr_llm_model_repo = gr.Dropdown( choices=LLM_MODEL_REPOS, value=None, label='HF Model Repository', ) curr_llm_model_path = gr.Dropdown( choices=[], value=None, label='GGUF model file', ) load_llm_model_btn = gr.Button('Loading and initializing model') load_llm_model_log = gr.Textbox( value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup', label='Model loading status', lines=6, ) with gr.Group(): gr.Markdown('Free up disk space by deleting all models except the currently selected one') clear_llm_folder_btn = gr.Button('Clear folder') new_llm_model_repo_btn.click( fn=add_new_model_repo, inputs=[new_llm_model_repo, llm_model_repos], outputs=[curr_llm_model_repo, load_llm_model_log], ).success( fn=lambda: '', inputs=None, outputs=[new_llm_model_repo], ) curr_llm_model_repo.change( fn=get_gguf_model_names, inputs=[curr_llm_model_repo], outputs=[curr_llm_model_path], ) load_llm_model_btn.click( fn=load_llm_model, inputs=[curr_llm_model_repo, curr_llm_model_path], outputs=[llm_model, support_system_role, load_llm_model_log], queue=True, ).success( fn=lambda log: log + get_memory_usage(), inputs=[load_llm_model_log], outputs=[load_llm_model_log], ).then( fn=get_system_prompt_component, inputs=[support_system_role], outputs=[system_prompt], ) clear_llm_folder_btn.click( fn=clear_llm_folder, inputs=[curr_llm_model_path], outputs=None, ).success( fn=lambda model_path: f'Models other than {model_path} removed', inputs=[curr_llm_model_path], outputs=None, ) # ============== EMBEDDING MODELS DOWNLOAD PAGE ============= with gr.Tab('Load embed model'): new_embed_model_repo = gr.Textbox( value='', label='Add repository', placeholder='Link to HF model repository', ) new_embed_model_repo_btn = gr.Button('Add repository') curr_embed_model_repo = gr.Dropdown( choices=EMBED_MODEL_REPOS, value=None, label='HF model repository', ) load_embed_model_btn = gr.Button('Loading and initializing model') load_embed_model_log = gr.Textbox( value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup', label='Model loading status', lines=7, ) with gr.Group(): gr.Markdown('Free up disk space by deleting all models except the currently selected one') clear_embed_folder_btn = gr.Button('Clear folder') new_embed_model_repo_btn.click( fn=add_new_model_repo, inputs=[new_embed_model_repo, embed_model_repos], outputs=[curr_embed_model_repo, load_embed_model_log], ).success( fn=lambda: '', inputs=None, outputs=new_embed_model_repo, ) load_embed_model_btn.click( fn=load_embed_model, inputs=[curr_embed_model_repo], outputs=[embed_model, load_embed_model_log], ).success( fn=lambda log: log + get_memory_usage(), inputs=[load_embed_model_log], outputs=[load_embed_model_log], ) clear_embed_folder_btn.click( fn=clear_embed_folder, inputs=[curr_embed_model_repo], outputs=None, ).success( fn=lambda model_repo: f'Models other than {model_repo} removed', inputs=[curr_embed_model_repo], outputs=None, ) interface.launch(server_name='0.0.0.0', server_port=7860) # debug=True