import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import chromadb
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.prompts import PromptTemplate
import os
import shutil
import zipfile
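
# Restore the pre-built Chroma vector store on first run (expects a chroma.zip
# archive with the persisted index next to this script).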
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
        zip_ref.extractall("./chroma_db")
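
# Load the generator model. flan-t5-xl (~3B parameters) needs several GB of
# RAM/VRAM; a smaller checkpoint such as flan-t5-base can be substituted for
# local testing.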
MODEL_NAME = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
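
# Embed queries with MiniLM and attach to the persisted Chroma collection.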
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db = Chroma(embedding_function=embeddings, client=chroma_client)
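
# Retrieve the 10 most similar chunks for each question.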
retriever = db.as_retriever(search_kwargs={"k": 10})
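
# Prompt that constrains the model to the retrieved context and asks for YAML
# examples where relevant.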
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes expert.
**Answer the question using ONLY the provided context.**
If the context does NOT contain enough information, return:
`"I don't have enough information to answer this question."`
Always include YAML examples when relevant.

---
**Context:**
{context}

**Question:**
{input}

---
**Answer:**
""",
    input_variables=["context", "input"]
)
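
# Wrap the model in a text2text-generation pipeline and expose it to LangChain.
# device=0 assumes a CUDA GPU is available; use device=-1 to run on CPU.
# qa_chain provides an end-to-end retrieval pipeline, while the Gradio handler
# below builds the prompt manually so the context can be cleaned and trimmed first.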
qa_pipeline = pipeline(
    "text2text-generation", model=model, tokenizer=tokenizer, device=0,
    max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9,
)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
combine_docs_chain = create_stuff_documents_chain(llm, prompt_template)
qa_chain = create_retrieval_chain(retriever, combine_docs_chain)

def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Improves the retrieved document context:
    - Removes duplicates while preserving order
    - Filters out very short or unstructured text
    - Limits token count for better LLM performance
    """
    from collections import OrderedDict

    # Drop exact duplicates while keeping the original retrieval order.
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))

    # Discard fragments that are too short to be useful.
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Drop chunks that are substrings of chunks already kept.
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        normalized_text = " ".join(text.split())
        if not any(normalized_text in seen for seen in seen_texts):
            deduplicated_texts.append(normalized_text)
            seen_texts.add(normalized_text)

    # Keep adding chunks until the token budget is reached.
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)

        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break

        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
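
# Retrieve, clean, and trim context for a question, then generate the answer
# directly with the underlying model.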
def get_k8s_answer(query):
    retrieved_context = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)

    input_text = prompt_template.format(context=cleaned_context, input=query)

    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(**inputs, max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)

    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
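
# Simple Gradio front end: one textbox in, one textbox out.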
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ⚡ Kubernetes RAG")
    gr.Markdown("Ask any Kubernetes-related question!")

    with gr.Row():
        question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
        answer = gr.Textbox(label="Answer", interactive=False)

    submit_button = gr.Button("Get Answer")

    submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)

demo.launch()