import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import chromadb
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import PromptTemplate
import os
import zipfile

# 🚀 Step 1: Extract ChromaDB if not already done (only once)
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
        zip_ref.extractall("./chroma_db")
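# Assumption about the archive layout: chroma.zip is expected to contain the ChromaDB files at
# its root. If the archive wraps them in a top-level folder instead, point PersistentClient
# below at that subfolder rather than "./chroma_db".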

# 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
MODEL_NAME = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# 🚀 Step 3: Load Vector Store Efficiently
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db = Chroma(embedding_function=embeddings, client=chroma_client)

# 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
retriever = db.as_retriever(search_kwargs={"k": 10})
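# If the top-k chunks come back highly redundant, MMR re-ranking is a drop-in alternative
# supported by the Chroma retriever (a sketch, not what this app uses):
#   retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})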

# 🚀 Step 5: Define Prompt for the LLM
prompt_template = PromptTemplate(
    template="""
    You are a Kubernetes expert.
    **Answer the question using ONLY the provided context.**
    If the context does NOT contain enough information, return:
    `"I don't have enough information to answer this question."`
    Always include YAML examples when relevant.

    ---
    **Context:**
    {context}

    **Question:**
    {input}

    ---
    **Answer:**
    """,
    input_variables=["context", "input"]
)

# 🚀 Step 6: Build Retrieval Chain
# device=0 assumes a GPU is available; set device=-1 to run on CPU instead.
qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0,
                       max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
combine_docs_chain = create_stuff_documents_chain(llm, prompt_template)
qa_chain = create_retrieval_chain(retriever, combine_docs_chain)
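# The Gradio handler below retrieves and generates manually so it can run clean_context()
# on the documents first; qa_chain is kept as a one-call alternative. A minimal sketch,
# assuming the default output keys of create_retrieval_chain:
#   result = qa_chain.invoke({"input": "How do I expose a Deployment with a Service?"})
#   result["answer"]   # generated answer
#   result["context"]  # retrieved Document objects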

# 🚀 Step 7: Define Fast Answer Function
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Improves the retrieved document context:
    - Removes duplicates while preserving order
    - Filters out very short or unstructured text
    - Limits token count for better LLM performance
    """
    from collections import OrderedDict

    # Preserve order while removing exact duplicates
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))

    # Remove very short texts (e.g., headers, page numbers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Skip chunks that are fully contained in an already-kept chunk (substring near-duplicates)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        normalized_text = " ".join(text.split())  # Normalize whitespace
        if not any(normalized_text in seen for seen in seen_texts):
            deduplicated_texts.append(normalized_text)
            seen_texts.add(normalized_text)

    # Limit context based on token count
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)

        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # Allow partial inclusion if meaningful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break

        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."


def get_k8s_answer(query):
    retrieved_context = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)

    input_text = prompt_template.format(context=cleaned_context, input=query)

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(**inputs, max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response

# 🚀 Step 8: Optimize Gradio App with `Blocks()`
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ⚡ Kubernetes RAG")
    gr.Markdown("Ask any Kubernetes-related question!")

    with gr.Row():
        question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
        answer = gr.Textbox(label="Answer", interactive=False)

    submit_button = gr.Button("Get Answer")

    submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)

demo.launch()
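# For multiple concurrent users, Gradio's built-in request queue can be enabled instead,
# e.g. demo.queue().launch() -- queue() returns the Blocks object, so the calls chain.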