import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import chromadb
from langchain.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain, LLMChain
from langchain.prompts import PromptTemplate
import os
import shutil
import zipfile
# Step 1: Extract ChromaDB if not already done (only once)
if not os.path.exists("./chroma_db"):
    with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
        zip_ref.extractall("./chroma_db")
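# Assumption: chroma.zip contains the Chroma persistence directory produced when the
# documentation was indexed offline; nothing is re-embedded at startup.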
# Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
MODEL_NAME = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# Step 3: Load Vector Store Efficiently
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
chroma_client = chromadb.PersistentClient(path="./chroma_db")
db = Chroma(embedding_function=embeddings, client=chroma_client)
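# The embedding model must be the same one used to build the index, otherwise
# query vectors and stored vectors will not be comparable.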
# Step 4: Optimize Retriever (Lower `k` for Speed)
retriever = db.as_retriever(search_kwargs={"k": 10})
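# k=10 favors recall; clean_context() below deduplicates and trims the merged
# passages to ~350 tokens so the prompt stays within the model's input window.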
# Step 5: Define Prompt for the LLM
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes expert.
**Answer the question using ONLY the provided context.**
If the context does NOT contain enough information, return:
`"I don't have enough information to answer this question."`
Always include YAML examples when relevant.
---
**Context:**
{context}
**Question:**
{input}
---
**Answer:**
""",
    input_variables=["context", "input"],
)
# Step 6: Build Retrieval Chain
qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer,
                       device=0 if torch.cuda.is_available() else -1,  # fall back to CPU when no GPU is available
                       max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
llm_chain = LLMChain(llm=llm, prompt=prompt_template)
qa_chain = create_retrieval_chain(retriever, llm_chain)
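# Note: qa_chain is not used by the Gradio handler below, which calls the retriever
# and model directly so the retrieved context can be deduplicated and trimmed first.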
# Step 7: Define Fast Answer Function
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Improves the retrieved document context:
    - Removes duplicates while preserving order
    - Filters out very short or unstructured text
    - Limits token count for better LLM performance
    """
    from collections import OrderedDict
    # Preserve order while removing exact duplicates
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))
    # Remove very short texts (e.g., headers, page numbers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
    # Avoid near-duplicates: skip a passage that is already contained in a kept one
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        normalized_text = " ".join(text.split())  # Normalize spacing
        if not any(normalized_text in seen for seen in seen_texts):
            deduplicated_texts.append(normalized_text)
            seen_texts.add(normalized_text)
    # Limit context based on token count
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # Allow partial inclusion if meaningful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count
    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
def get_k8s_answer(query):
    retrieved_context = retriever.get_relevant_documents(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)
    input_text = prompt_template.format(context=cleaned_context, input=query)
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(**inputs, max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response
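# Quick local sanity check (hypothetical query):
#   print(get_k8s_answer("How do I create a Deployment with three replicas?"))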
# Step 8: Optimize Gradio App with `Blocks()`
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# ⚡ Kubernetes RAG")
    gr.Markdown("Ask any Kubernetes-related question!")
    with gr.Row():
        question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
        answer = gr.Textbox(label="Answer", interactive=False)
    submit_button = gr.Button("Get Answer")
    submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)
demo.launch()