File size: 5,329 Bytes
dfc699a
 
 
 
 
 
 
 
 
a70773f
 
 
 
 
 
dfc699a
a70773f
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc699a
a70773f
 
 
 
dfc699a
a70773f
 
dfc699a
 
a70773f
dfc699a
a70773f
 
dfc699a
a70773f
 
 
 
 
 
 
 
dfc699a
 
 
a70773f
 
dfc699a
 
 
a70773f
 
dfc699a
 
 
a70773f
 
 
dfc699a
a70773f
 
 
 
dfc699a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a70773f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# ✅ Enhanced GenAI Assistant with:
# - PDF/TXT support
# - Ask Anything + Challenge Me modes
# - Auto Summary (<=150 words)
# - Memory handling
# - Reference highlighting
# - Stunning UI (Gradio upgraded)

# --- FILE: app.py ---

import os
import gradio as gr
from dotenv import load_dotenv
import indexing
import retrieval
import utils

# Full Hugging Face model IDs offered in the UI, roughly ordered by capability.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    "microsoft/Phi-3.5-mini-instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-2-2b-it",
    "Qwen/Qwen2.5-3B-Instruct",
]
# Short display names (repo name without the org prefix) for the radio widget.
list_llm_simple = [os.path.basename(model_id) for model_id in list_llm]


def retrieve_api():
    """Load the .env file and stash the Hugging Face token in a module global.

    Sets ``huggingfacehub_api_token`` from the ``HUGGINGFACE_API_KEY``
    environment variable; the value is ``None`` when the variable is unset.
    """
    global huggingfacehub_api_token
    load_dotenv()
    huggingfacehub_api_token = os.getenv("HUGGINGFACE_API_KEY")


def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    """Index the uploaded documents into a vector DB and auto-summarize them.

    Args:
        list_file_obj: Gradio file objects from the upload widget (entries
            may be None; those are skipped).
        chunk_size: Character length of each text chunk.
        chunk_overlap: Character overlap between consecutive chunks.
        progress: Gradio progress tracker injected by the UI (unused here;
            kept for Gradio's progress-bar plumbing).

    Returns:
        Tuple of (vector_db, collection_name, summary, status_message).

    Raises:
        gr.Error: If no usable file was uploaded (previously this crashed
            with an opaque IndexError on ``list_file_path[0]``).
    """
    list_file_path = [x.name for x in list_file_obj if x is not None]
    if not list_file_path:
        # Guard: the collection naming below indexes the first file.
        raise gr.Error("Please upload at least one PDF or TXT file.")
    # Collection is named after the first uploaded file.
    collection_name = indexing.create_collection_name(list_file_path[0])
    doc_splits, full_text = indexing.load_doc(list_file_path, chunk_size, chunk_overlap)
    summary = utils.generate_summary(full_text)
    vector_db = indexing.create_db(doc_splits, collection_name)
    return vector_db, collection_name, summary, "Complete!"


def initialize_llm(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build the retrieval QA chain for the selected model.

    Args:
        llm_option: Index into ``list_llm`` (radio widget uses type="index").
        llm_temperature: Sampling temperature for generation.
        max_tokens: Maximum tokens to generate.
        top_k: Number of chunks to retrieve per query.
        vector_db: Vector store produced by ``initialize_database``.
        progress: Gradio progress tracker, forwarded to the retrieval layer.

    Returns:
        Tuple of (qa_chain, status_message).
    """
    selected_model = list_llm[llm_option]
    # Token comes from the module global set by retrieve_api().
    qa_chain = retrieval.initialize_llmchain(
        selected_model,
        huggingfacehub_api_token,
        llm_temperature,
        max_tokens,
        top_k,
        vector_db,
        progress,
    )
    return qa_chain, "Complete!"


def conversation(qa_chain, message, history):
    """Run one QA turn and surface highlighted source snippets.

    Returns the (possibly updated) chain, an update clearing the input
    textbox, the new chat history, and the reference snippets unpacked
    into the individual reference textboxes.
    """
    qa_chain, updated_history, source_docs = retrieval.invoke_qa_chain(qa_chain, message, history)
    snippets = utils.extract_highlight_snippets(source_docs)
    # NOTE(review): assumes extract_highlight_snippets yields exactly three
    # items to match the three reference outputs wired in gradio_ui — confirm
    # against utils.
    return (qa_chain, gr.update(value=""), updated_history, *snippets)


def challenge_me(qa_chain):
    """Generate comprehension questions from the indexed document."""
    return utils.generate_challenge_questions(qa_chain)


def evaluate_answers(qa_chain, questions, user_answers):
    """Grade the user's answers against the generated questions."""
    return utils.evaluate_responses(qa_chain, questions, user_answers)


def gradio_ui():
    """Build and launch the Gradio Blocks interface.

    Tabs: (1) document upload + indexing with auto summary, (2) QA-chain
    initialization, (3) free-form Q&A with reference highlighting, and
    (4) a "Challenge Me" quiz mode with feedback.
    """
    with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none}") as demo:
        # Cross-tab state shared between event handlers.
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown("""<h1 style='text-align:center;'>📚 GenAI Document Assistant</h1>
        <h3 style='text-align:center;color:gray;'>Smart, interactive reading of research papers, legal docs, and more.</h3>""")

        with gr.Tab("1️⃣ Upload Document"):
            document = gr.File(label="Upload PDF or TXT", file_types=[".pdf", ".txt"], file_count="multiple")
            slider_chunk_size = gr.Slider(100, 1000, value=600, step=20, label="Chunk Size")
            slider_chunk_overlap = gr.Slider(10, 200, value=40, step=10, label="Chunk Overlap")
            db_progress = gr.Textbox(label="Processing Status")
            summary_box = gr.Textbox(label="Auto Summary (≤ 150 words)", lines=5)
            db_btn = gr.Button("📥 Process Document")

        with gr.Tab("2️⃣ QA Chain Initialization"):
            # type="index" makes the handler receive the position in list_llm.
            llm_btn = gr.Radio(list_llm_simple, label="Select LLM", value=list_llm_simple[0], type="index")
            slider_temperature = gr.Slider(0.01, 1.0, value=0.7, step=0.1, label="Temperature")
            slider_maxtokens = gr.Slider(224, 4096, value=1024, step=32, label="Max Tokens")
            slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")
            llm_progress = gr.Textbox(label="LLM Status")
            qachain_btn = gr.Button("⚙️ Initialize QA Chain")

        with gr.Tab("3️⃣ Ask Anything"):
            chatbot = gr.Chatbot(height=300)
            msg = gr.Textbox(placeholder="Ask a question from the document...")
            submit_btn = gr.Button("💬 Ask")
            clear_btn = gr.ClearButton([msg, chatbot])
            ref1 = gr.Textbox(label="Reference 1")
            ref2 = gr.Textbox(label="Reference 2")
            ref3 = gr.Textbox(label="Reference 3")

        with gr.Tab("4️⃣ Challenge Me"):
            challenge_btn = gr.Button("🎯 Generate Questions")
            q1 = gr.Textbox(label="Question 1")
            a1 = gr.Textbox(label="Your Answer 1")
            q2 = gr.Textbox(label="Question 2")
            a2 = gr.Textbox(label="Your Answer 2")
            q3 = gr.Textbox(label="Question 3")
            a3 = gr.Textbox(label="Your Answer 3")
            eval_btn = gr.Button("✅ Submit Answers")
            feedback = gr.Textbox(label="Feedback", lines=5)

        # Event wiring.
        db_btn.click(initialize_database, [document, slider_chunk_size, slider_chunk_overlap], [vector_db, collection_name, summary_box, db_progress])
        qachain_btn.click(initialize_llm, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot, ref1, ref2, ref3])
        challenge_btn.click(challenge_me, [qa_chain], [q1, q2, q3])
        # Gradio `inputs` must be a flat list of components — nested lists
        # like [qa_chain, [q1, q2, q3], [a1, a2, a3]] are not unpacked and
        # break the event handler. Flatten the inputs and regroup the values
        # into the (questions, answers) lists evaluate_answers expects.
        eval_btn.click(
            lambda chain, qa, qb, qc, aa, ab, ac: evaluate_answers(
                chain, [qa, qb, qc], [aa, ab, ac]
            ),
            [qa_chain, q1, q2, q3, a1, a2, a3],
            [feedback],
        )

    demo.launch(debug=True)

def _main():
    """Script entry point: load API credentials, then launch the UI."""
    retrieve_api()
    gradio_ui()


if __name__ == "__main__":
    _main()