# K8sPilot / app.py

import os
from collections import OrderedDict

import chromadb
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# Load embeddings model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Load Chroma database (Avoid reprocessing documents)
CHROMA_PATH = "./chroma_db"
if not os.path.exists(CHROMA_PATH):
    raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
db = Chroma(embedding_function=embeddings, client=chroma_client)
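
# For reference: a minimal sketch of how a persisted index like ./chroma_db could be
# (re)built offline from PDF documentation. This is an assumption about the original
# indexing step, not part of the running app; the "./docs" folder and chunk sizes are
# hypothetical, and the function is never called at serving time.
def build_chroma_index(pdf_dir="./docs", persist_path=CHROMA_PATH):
    documents = []
    for file_name in os.listdir(pdf_dir):
        if file_name.lower().endswith(".pdf"):
            # PyPDFLoader returns one Document per page
            documents.extend(PyPDFLoader(os.path.join(pdf_dir, file_name)).load())
    # Split pages into overlapping chunks so each chunk embeds and retrieves well
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(documents)
    # Embed the chunks and persist them to disk for later retrieval
    return Chroma.from_documents(chunks, embedding=embeddings, persist_directory=persist_path)
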
# Load the model
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Create pipeline
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # use the GPU when available, otherwise fall back to CPU
    max_length=512,
    min_length=50,
    do_sample=False,
    repetition_penalty=1.2,
)

# Wrap pipeline in LangChain
llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = db.as_retriever(search_kwargs={"k": 3})
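
# Quick, optional sanity check for the retriever (illustrative only; the question text is
# made up). `retriever.invoke(...)` returns the k=3 most similar chunks as Documents:
#
#   for doc in retriever.invoke("How does a Deployment roll out new Pods?"):
#       print(doc.page_content[:120])
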
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Cleans retrieved document context:
    - Removes duplicates while preserving order
    - Limits total token count
    - Ensures useful, non-repetitive context
    """
    # Preserve order while removing exact duplicates
    unique_texts = list(OrderedDict.fromkeys([doc.page_content.strip() for doc in context_list]))

    # Drop very short texts (e.g., headers) that carry little information
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Drop near-duplicates (texts fully contained in an already kept text)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        if not any(text in s for s in seen_texts):
            deduplicated_texts.append(text)
            seen_texts.add(text)

    # Limit the combined context to max_tokens, trimming the last chunk if needed
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # only keep a partial chunk if it is still meaningful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
# Define prompt
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes instructor. Answer the question based on the provided context.
If the context does not provide an answer, say "I don't have enough information."

Context:
{context}

Question:
{input}

Answer:
""",
    input_variables=["context", "input"]
)

# The context is cleaned and trimmed manually below, so the prompt is driven by a plain
# LLMChain; create_retrieval_chain would re-run the retriever and overwrite that context.
llm_chain = LLMChain(llm=llm, prompt=prompt_template)

# Query function
def get_k8s_answer(query):
    retrieved_context = retriever.invoke(query)  # fetch the k=3 most relevant chunks
    cleaned_context = clean_context(retrieved_context, max_tokens=350)  # keep the context within budget

    # Ensure the total prompt stays under flan-t5's 512-token input limit
    input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
    total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

    if total_tokens > 512:
        # Trim the context further, reserving room for the question plus some slack
        allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50
        cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)

        # Recalculate the total token count
        input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
        total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

        if total_tokens > 512:
            return "Error: Even after trimming, input is too large."

    response = llm_chain.invoke({"input": query, "context": cleaned_context})
    return response

def get_k8s_answer_text(query):
    model_full_answer = get_k8s_answer(query)
    if isinstance(model_full_answer, str):
        # get_k8s_answer already produced an error message
        return model_full_answer
    if "text" in model_full_answer:
        return model_full_answer["text"]
    return "Error"
# Gradio Interface
demo = gr.Interface(
    fn=get_k8s_answer_text,
    inputs=gr.Textbox(label="Ask a Kubernetes Question"),
    outputs=gr.Textbox(label="Answer"),
    title="Kubernetes RAG Assistant",
    description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation."
)

if __name__ == "__main__":
    demo.launch()