|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import gradio as gr |
|
from dotenv import load_dotenv |
|
import indexing |
|
import retrieval |
|
import utils |
|
|
|
# Hugging Face Hub repo IDs of the chat models offered in the UI.
# NOTE: order matters — the Gradio Radio widget below uses type="index",
# so initialize_llm maps the selected index straight back into this list.
list_llm = [

    "mistralai/Mistral-7B-Instruct-v0.3",

    "microsoft/Phi-3.5-mini-instruct",

    "meta-llama/Llama-3.1-8B-Instruct",

    "meta-llama/Llama-3.2-3B-Instruct",

    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",

    "google/gemma-2-2b-it",

    "Qwen/Qwen2.5-3B-Instruct",

]

# Display names for the UI: basename() strips the "org/" prefix from each
# repo ID (repo IDs always use "/", which basename splits on every platform).
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
|
|
|
|
|
def retrieve_api():
    """Load environment variables and cache the Hugging Face API token.

    Loads a local ``.env`` file (if present) and stores the value of the
    ``HUGGINGFACE_API_KEY`` environment variable in the module-level
    global ``huggingfacehub_api_token``, which ``initialize_llm`` reads.

    Returns:
        str | None: the token, also kept in the module global for callers
        that rely on the original side-effect-only behavior.
    """
    load_dotenv()

    global huggingfacehub_api_token

    huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")

    if not huggingfacehub_api_token:
        # Previously a missing key was silently stored as None, which only
        # surfaced much later as a cryptic authentication error from the Hub.
        print("WARNING: HUGGINGFACE_API_KEY is not set; LLM initialization will fail.")

    return huggingfacehub_api_token
|
|
|
|
|
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    """Index uploaded documents into a vector DB and summarize them.

    Args:
        list_file_obj: Gradio file objects from the upload widget (may be
            None or contain None entries when nothing was uploaded).
        chunk_size: character length of each text chunk.
        chunk_overlap: character overlap between consecutive chunks.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        tuple: (vector_db, collection_name, summary, status_message) —
        matches the outputs wired in ``gradio_ui``.
    """
    # Guard against a missing/empty upload: the original code crashed with
    # an IndexError on list_file_path[0] (and a TypeError when list_file_obj
    # was None) before any status could be shown to the user.
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]

    if not list_file_path:
        return None, None, "", "No valid files uploaded!"

    # Collection is named after the first uploaded file.
    collection_name = indexing.create_collection_name(list_file_path[0])

    doc_splits, full_text = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)

    # Auto-summary shown in the UI alongside the indexing status.
    summary = utils.generate_summary(full_text)

    vector_db = indexing.create_db(doc_splits, collection_name)

    return vector_db, collection_name, summary, "Complete!"
|
|
|
|
|
def initialize_llm(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build the QA chain for the model chosen in the UI.

    ``llm_option`` is the index selected in the Radio widget (type="index"),
    which is resolved against ``list_llm`` to get the full repo ID. Returns
    the initialized chain plus a status string for the UI.
    """
    selected_model = list_llm[llm_option]

    chain = retrieval.initialize_llmchain(
        selected_model,
        huggingfacehub_api_token,
        llm_temperature,
        max_tokens,
        top_k,
        vector_db,
        progress,
    )

    return chain, "Complete!"
|
|
|
|
|
def conversation(qa_chain, message, history):
    """Handle one chat turn: query the QA chain and surface source snippets.

    Returns the (possibly updated) chain, a cleared message box, the new
    chat history, and the highlight snippets unpacked into the reference
    textboxes (assumes utils returns one snippet per wired textbox —
    three in the current UI).
    """
    updated_chain, updated_history, sources = retrieval.invoke_qa_chain(qa_chain, message, history)

    snippets = utils.extract_highlight_snippets(sources)

    # Clear the input box after each question.
    return (updated_chain, gr.update(value=""), updated_history, *snippets)
|
|
|
|
|
def challenge_me(qa_chain):
    """Generate comprehension questions from the indexed document.

    Thin wrapper around ``utils.generate_challenge_questions``; the result
    is unpacked into the three question textboxes in the UI.
    """
    return utils.generate_challenge_questions(qa_chain)
|
|
|
|
|
def evaluate_answers(qa_chain, questions, user_answers):
    """Score the user's answers against the generated questions.

    Delegates to ``utils.evaluate_responses`` and returns its feedback
    text for display in the Feedback textbox.
    """
    return utils.evaluate_responses(qa_chain, questions, user_answers)
|
|
|
|
|
def gradio_ui():
    """Build and launch the Gradio Blocks interface.

    Four tabs: document upload/indexing, QA-chain setup, free-form chat
    with source references, and a "challenge me" quiz with feedback.
    State objects (vector_db, qa_chain, collection_name) are threaded
    between the event handlers via gr.State.
    """
    with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none}") as demo:

        # Cross-tab session state.
        vector_db = gr.State()

        qa_chain = gr.State()

        collection_name = gr.State()

        gr.Markdown("""<h1 style='text-align:center;'>📚 GenAI Document Assistant</h1>

        <h3 style='text-align:center;color:gray;'>Smart, interactive reading of research papers, legal docs, and more.</h3>""")

        with gr.Tab("1️⃣ Upload Document"):

            document = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")

            slider_chunk_size = gr.Slider(100, 1000, value=600, step=20, label="Chunk Size")

            slider_chunk_overlap = gr.Slider(10, 200, value=40, step=10, label="Chunk Overlap")

            db_progress = gr.Textbox(label="Processing Status")

            summary_box = gr.Textbox(label="Auto Summary (≤ 150 words)", lines=5)

            db_btn = gr.Button("📥 Process Document")

        with gr.Tab("2️⃣ QA Chain Initialization"):

            # type="index" so the handler maps the choice back into list_llm.
            llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")

            slider_temperature = gr.Slider(0.01, 1.0, value=0.7, step=0.1, label="Temperature")

            slider_maxtokens = gr.Slider(224, 4096, value=1024, step=32, label="Max Tokens")

            slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")

            llm_progress = gr.Textbox(label="LLM Status")

            qachain_btn = gr.Button("⚙️ Initialize QA Chain")

        with gr.Tab("3️⃣ Ask Anything"):

            chatbot = gr.Chatbot(height=300)

            msg = gr.Textbox(placeholder="Ask a question from the document...")

            submit_btn = gr.Button("💬 Ask")

            clear_btn = gr.ClearButton([msg, chatbot])

            ref1 = gr.Textbox(label="Reference 1")

            ref2 = gr.Textbox(label="Reference 2")

            ref3 = gr.Textbox(label="Reference 3")

        with gr.Tab("4️⃣ Challenge Me"):

            challenge_btn = gr.Button("🎯 Generate Questions")

            q1 = gr.Textbox(label="Question 1")

            a1 = gr.Textbox(label="Your Answer 1")

            q2 = gr.Textbox(label="Question 2")

            a2 = gr.Textbox(label="Your Answer 2")

            q3 = gr.Textbox(label="Question 3")

            a3 = gr.Textbox(label="Your Answer 3")

            eval_btn = gr.Button("✅ Submit Answers")

            feedback = gr.Textbox(label="Feedback", lines=5)

        # Event wiring.
        db_btn.click(initialize_database, [document, slider_chunk_size, slider_chunk_overlap], [vector_db, collection_name, summary_box, db_progress])

        qachain_btn.click(initialize_llm, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])

        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, ref1, ref2, ref3])

        challenge_btn.click(challenge_me, [qa_chain], [q1, q2, q3])

        # BUG FIX: Gradio `inputs` must be a FLAT list of components; the
        # original nested lists ([q1, q2, q3], [a1, a2, a3]) are invalid and
        # break the event wiring. Flatten the inputs here and regroup them in
        # a wrapper so evaluate_answers keeps its 3-argument signature.
        eval_btn.click(
            lambda chain, qa, qb, qc, ans_a, ans_b, ans_c: evaluate_answers(
                chain, [qa, qb, qc], [ans_a, ans_b, ans_c]
            ),
            [qa_chain, q1, q2, q3, a1, a2, a3],
            [feedback],
        )

    demo.launch(debug=True)
|
|
|
if __name__ == "__main__": |
|
retrieve_api() |
|
gradio_ui() |
|
|