import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import chromadb
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import PromptTemplate
from collections import OrderedDict
import os
import zipfile

# 🚀 Step 1: Extract ChromaDB if not already done (only once)
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
        zip_ref.extractall("./chroma_db")

# 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
MODEL_NAME = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# 🚀 Step 3: Load Vector Store Efficiently
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db = Chroma(embedding_function=embeddings, client=chroma_client)

# 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
retriever = db.as_retriever(search_kwargs={"k": 10})

# 🚀 Step 5: Define Prompt for the LLM
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes expert.
**Answer the question using ONLY the provided context.**
If the context does NOT contain enough information, return:
`"I don't have enough information to answer this question."`
Always include YAML examples when relevant.

---

**Context:**
{context}

**Question:** {input}

---

**Answer:**
""",
    input_variables=["context", "input"],
)

# 🚀 Step 6: Build Retrieval Chain
# Fall back to CPU when no GPU is available instead of hard-coding device=0.
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    max_length=512,
    min_length=50,
    do_sample=True,
    temperature=0.4,
    top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=qa_pipeline)

# `create_retrieval_chain` expects a combine-documents chain, not a bare
# LLMChain, so stuff the retrieved documents into the prompt first.
combine_docs_chain = create_stuff_documents_chain(llm, prompt_template)
qa_chain = create_retrieval_chain(retriever, combine_docs_chain)

# 🚀 Step 7: Define Fast Answer Function
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Improves the retrieved document context:
    - Removes duplicates while preserving order
    - Filters out very short or unstructured text
    - Limits token count for better LLM performance
    """
    # Preserve order while removing exact duplicates
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))

    # Remove very short texts (e.g., headers, page numbers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Drop texts that are substrings of an already-kept text (near-duplicates)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        normalized_text = " ".join(text.split())  # Normalize spacing
        if not any(normalized_text in seen for seen in seen_texts):
            deduplicated_texts.append(normalized_text)
            seen_texts.add(normalized_text)

    # Limit context based on token count
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # Allow partial inclusion if meaningful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
# Note: this bypasses `qa_chain` and drives the retriever and model directly,
# so the retrieved context can be deduplicated and trimmed via clean_context().
def get_k8s_answer(query):
    retrieved_context = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)
    input_text = prompt_template.format(context=cleaned_context, input=query)
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(
        **inputs,
        max_length=512,
        min_length=50,
        do_sample=True,
        temperature=0.4,
        top_p=0.9,
    )
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response

# 🚀 Step 8: Optimize Gradio App with `Blocks()`
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ⚡ Kubernetes RAG")
    gr.Markdown("Ask any Kubernetes-related question!")

    with gr.Row():
        question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
        answer = gr.Textbox(label="Answer", interactive=False)

    submit_button = gr.Button("Get Answer")
    submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)

demo.launch()
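# Hypothetical smoke test (the query text is an assumption, not part of the
# original app): call the answer function directly in a Python shell before
# launching the UI, to verify retrieval and generation work end to end.
#   print(get_k8s_answer("How do I expose a Deployment with a Service?"))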