# ✅ Enhanced GenAI Assistant with:
# - PDF/TXT support
# - Ask Anything + Challenge Me modes
# - Auto Summary (<=150 words)
# - Memory handling
# - Reference highlighting
# - Stunning UI (Gradio upgraded)
# --- FILE: app.py ---
import os
import gradio as gr
from dotenv import load_dotenv
import indexing
import retrieval
import utils
# Hugging Face Hub repo IDs offered as QA-chain backends in the UI.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-2-2b-it",
    "Qwen/Qwen2.5-3B-Instruct",
]
# Short display names (repo basename only) shown in the model-selection radio;
# index positions mirror list_llm, which is why the radio uses type="index".
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
def retrieve_api():
    """Read the .env file and cache the Hugging Face token in a module global.

    Side effect: sets the module-level ``huggingfacehub_api_token`` that
    ``initialize_llm`` later reads; the value is ``None`` when the
    ``HUGGINGFACE_API_KEY`` variable is absent.
    """
    global huggingfacehub_api_token
    load_dotenv()
    huggingfacehub_api_token = os.environ.get("HUGGINGFACE_API_KEY")
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    """Index the uploaded documents and produce an auto summary.

    Args:
        list_file_obj: Gradio file objects (entries may be None for empty slots).
        chunk_size: Character length of each document split.
        chunk_overlap: Character overlap between consecutive splits.
        progress: Gradio progress tracker (injected by Gradio via the default).

    Returns:
        (vector_db, collection_name, summary, status) — arity matches the four
        output components wired up in gradio_ui.
    """
    # Gradio can deliver None entries (or no list at all); filter them out
    # before touching .name so we never crash on an empty upload.
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    if not list_file_path:
        # Keep the 4-tuple shape so the UI outputs still bind correctly.
        return None, None, "", "No valid files uploaded."
    # Collection is named after the first uploaded file.
    collection_name = indexing.create_collection_name(list_file_path[0])
    doc_splits, full_text = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
    summary = utils.generate_summary(full_text)
    vector_db = indexing.create_db(doc_splits, collection_name)
    return vector_db, collection_name, summary, "Complete!"
def initialize_llm(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Create the retrieval QA chain for the model chosen in the UI.

    ``llm_option`` is the radio index into ``list_llm`` (the radio uses
    type="index"). Reads the module-level ``huggingfacehub_api_token``
    populated by ``retrieve_api``.
    """
    selected_model = list_llm[llm_option]
    qa_chain = retrieval.initialize_llmchain(
        selected_model,
        huggingfacehub_api_token,
        llm_temperature,
        max_tokens,
        top_k,
        vector_db,
        progress,
    )
    return qa_chain, "Complete!"
def conversation(qa_chain, message, history):
    """Run one chat turn through the QA chain and surface reference snippets.

    Returns (qa_chain, cleared-textbox update, new history, ref1, ref2, ref3)
    to match the six output components wired in gradio_ui.
    """
    qa_chain, new_history, response_sources = retrieval.invoke_qa_chain(qa_chain, message, history)
    highlights = utils.extract_highlight_snippets(response_sources)
    # The UI binds exactly three reference textboxes; pad/trim so the output
    # count always matches no matter how many snippets were extracted.
    highlights = (list(highlights) + ["", "", ""])[:3]
    return qa_chain, gr.update(value=""), new_history, *highlights
def challenge_me(qa_chain):
    """Generate three comprehension questions from the indexed document."""
    return utils.generate_challenge_questions(qa_chain)
def evaluate_answers(qa_chain, questions, user_answers):
    """Grade the user's answers against the generated questions and return feedback text."""
    return utils.evaluate_responses(qa_chain, questions, user_answers)
def gradio_ui():
    """Build and launch the Gradio Blocks UI.

    Four tabs: document upload/indexing, QA-chain initialization, free-form
    Q&A with reference highlighting, and a "Challenge Me" quiz mode.
    """
    with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none}") as demo:
        # Cross-event state holders (persist between button clicks).
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown("""<h1 style='text-align:center;'>📚 GenAI Document Assistant</h1>
<h3 style='text-align:center;color:gray;'>Smart, interactive reading of research papers, legal docs, and more.</h3>""")
        with gr.Tab("1️⃣ Upload Document"):
            document = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")
            slider_chunk_size = gr.Slider(100, 1000, value=600, step=20, label="Chunk Size")
            slider_chunk_overlap = gr.Slider(10, 200, value=40, step=10, label="Chunk Overlap")
            db_progress = gr.Textbox(label="Processing Status")
            summary_box = gr.Textbox(label="Auto Summary (≤ 150 words)", lines=5)
            db_btn = gr.Button("📥 Process Document")
        with gr.Tab("2️⃣ QA Chain Initialization"):
            # type="index" → the handler receives the position in list_llm.
            llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")
            slider_temperature = gr.Slider(0.01, 1.0, value=0.7, step=0.1, label="Temperature")
            slider_maxtokens = gr.Slider(224, 4096, value=1024, step=32, label="Max Tokens")
            slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")
            llm_progress = gr.Textbox(label="LLM Status")
            qachain_btn = gr.Button("⚙️ Initialize QA Chain")
        with gr.Tab("3️⃣ Ask Anything"):
            chatbot = gr.Chatbot(height=300)
            msg = gr.Textbox(placeholder="Ask a question from the document...")
            submit_btn = gr.Button("💬 Ask")
            clear_btn = gr.ClearButton([msg, chatbot])
            ref1 = gr.Textbox(label="Reference 1")
            ref2 = gr.Textbox(label="Reference 2")
            ref3 = gr.Textbox(label="Reference 3")
        with gr.Tab("4️⃣ Challenge Me"):
            challenge_btn = gr.Button("🎯 Generate Questions")
            q1 = gr.Textbox(label="Question 1")
            a1 = gr.Textbox(label="Your Answer 1")
            q2 = gr.Textbox(label="Question 2")
            a2 = gr.Textbox(label="Your Answer 2")
            q3 = gr.Textbox(label="Question 3")
            a3 = gr.Textbox(label="Your Answer 3")
            eval_btn = gr.Button("✅ Submit Answers")
            feedback = gr.Textbox(label="Feedback", lines=5)

        # Event wiring.
        db_btn.click(initialize_database, [document, slider_chunk_size, slider_chunk_overlap], [vector_db, collection_name, summary_box, db_progress])
        qachain_btn.click(initialize_llm, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, ref1, ref2, ref3])
        challenge_btn.click(challenge_me, [qa_chain], [q1, q2, q3])

        # BUG FIX: Gradio `inputs` must be a FLAT list of components — the
        # previous nested form [qa_chain, [q1, q2, q3], [a1, a2, a3]] is
        # invalid. Flatten the inputs and regroup inside a local adapter so
        # evaluate_answers keeps its (chain, questions, answers) interface.
        def _evaluate(chain, qa, qb, qc, aa, ab, ac):
            return evaluate_answers(chain, [qa, qb, qc], [aa, ab, ac])

        eval_btn.click(_evaluate, [qa_chain, q1, q2, q3, a1, a2, a3], [feedback])
    demo.launch(debug=True)
if __name__ == "__main__":
    # Load the Hugging Face API token first: the QA-chain initialization
    # handler reads the module-level token global that retrieve_api() sets.
    retrieve_api()
    gradio_ui()