# K8sPilot / app.py
import os
import zipfile
from collections import OrderedDict

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import chromadb
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import PromptTemplate
# 🚀 Step 1: Extract ChromaDB if not already done (only once)
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
        zip_ref.extractall("./chroma_db")
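# Hedged sanity check (an assumption: recent Chroma persistent stores keep a
# chroma.sqlite3 file at the top level; if chroma.zip was packed with a parent
# folder, the index lands one level deeper and this warning will fire).
if not os.path.exists("./chroma_db/chroma.sqlite3"):
    print("Warning: ./chroma_db/chroma.sqlite3 not found; check the layout of chroma.zip")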
# 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
MODEL_NAME = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
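# If GPU memory is tight, a smaller checkpoint is a drop-in swap; a sketch, not
# a tuned choice (any flan-t5 size uses the same tokenizer/model classes), e.g.:
# model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")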
# 🚀 Step 3: Load Vector Store Efficiently
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db = Chroma(embedding_function=embeddings, client=chroma_client)
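# Quick sanity check, assuming the LangChain wrapper's default collection name
# ("langchain"); adjust if the index was built under a different collection.
try:
    print(f"Loaded {chroma_client.get_collection('langchain').count()} chunks from ./chroma_db")
except Exception as e:
    print(f"Could not inspect collection: {e}")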
# 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
retriever = db.as_retriever(search_kwargs={"k": 10})
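# For more varied context, MMR retrieval is a one-line swap (a sketch, not tuned):
# retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10, "fetch_k": 30})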
# 🚀 Step 5: Define Prompt for the LLM
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes expert.
**Answer the question using ONLY the provided context.**
If the context does NOT contain enough information, return:
`"I don't have enough information to answer this question."`
Always include YAML examples when relevant.
---
**Context:**
{context}
**Question:**
{input}
---
**Answer:**
""",
    input_variables=["context", "input"],
)
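# Handy for debugging: the template also renders with plain strings, e.g.
# print(prompt_template.format(context="<retrieved docs>", input="What is a Pod?"))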
# 🚀 Step 6: Build Retrieval Chain
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # fall back to CPU when no GPU is present
    max_length=512,
    min_length=50,
    do_sample=True,
    temperature=0.4,
    top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
# create_retrieval_chain expects a chain that combines the retrieved documents,
# not a bare LLMChain; create_stuff_documents_chain stuffs the retrieved docs
# into the prompt's {context} variable.
combine_docs_chain = create_stuff_documents_chain(llm, prompt_template)
qa_chain = create_retrieval_chain(retriever, combine_docs_chain)
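# Example invocation (kept for reference; get_k8s_answer below bypasses the
# chain to keep tighter control over context trimming):
# result = qa_chain.invoke({"input": "How do I expose a Deployment as a Service?"})
# print(result["answer"])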
# 🚀 Step 7: Define Fast Answer Function
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Improves the retrieved document context:
    - Removes duplicates while preserving order
    - Filters out very short or unstructured text
    - Limits token count for better LLM performance
    """
    # Preserve order while removing exact duplicates
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))

    # Remove very short texts (e.g., headers, page numbers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Skip texts that are substrings of an already-kept text (cheap near-duplicate filter)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        normalized_text = " ".join(text.split())  # Normalize whitespace
        if not any(normalized_text in seen for seen in seen_texts):
            deduplicated_texts.append(normalized_text)
            seen_texts.add(normalized_text)

    # Trim to the token budget so the prompt fits the model's context window
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # Include a partial chunk only if it is still meaningful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
def get_k8s_answer(query):
    retrieved_context = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)
    # Format the prompt directly so the trimmed context string is used verbatim
    input_text = prompt_template.format(context=cleaned_context, input=query)
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(**inputs, max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
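# A variant that also surfaces sources (a sketch; assumes ingestion stored a
# "source" key in each chunk's metadata; adjust to your own schema):
def get_k8s_answer_with_sources(query):
    docs = retriever.invoke(query)
    sources = {doc.metadata.get("source", "unknown") for doc in docs}
    return f"{get_k8s_answer(query)}\n\nSources: {', '.join(sorted(sources))}"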
# 🚀 Step 8: Optimize Gradio App with `Blocks()`
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ⚡ Kubernetes RAG")
    gr.Markdown("Ask any Kubernetes-related question!")

    with gr.Row():
        question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
        answer = gr.Textbox(label="Answer", interactive=False)

    submit_button = gr.Button("Get Answer")
    submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)

demo.launch()
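# Under concurrent users, Gradio's request queue prevents overlapping generate()
# calls; a minimal sketch: demo.queue(max_size=16).launch()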