"""Kubernetes RAG assistant.

Retrieves context from a persisted Chroma index, answers questions with
Flan-T5, and serves the app through a Gradio web UI.
"""

import os
from collections import OrderedDict

import chromadb
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain, LLMChain
from langchain.prompts import PromptTemplate

# Embedding model for query encoding; must match the model used to build the Chroma index.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Load the pre-built Chroma vector store (built offline and shipped with the repo).
CHROMA_PATH = "./chroma_db"
if not os.path.exists(CHROMA_PATH):
    raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")

chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
# Uses the default collection name; pass collection_name=... if the index was built under a different one.
db = Chroma(embedding_function=embeddings, client=chroma_client)
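
# Optional sanity check that the persisted index loads and is non-empty (illustrative query):
# print(db.similarity_search("kubernetes pod", k=1))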

# Flan-T5 Large as the answer generator (seq2seq).
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0,            # assumes a GPU is available; use device=-1 (or omit) to run on CPU
    max_length=512,      # upper bound on generated answer length
    min_length=50,       # encourage reasonably detailed answers
    do_sample=False,     # deterministic (greedy) decoding
    repetition_penalty=1.2
)

llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = db.as_retriever(search_kwargs={"k": 3})  # top-3 chunks per query
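
# Note: Flan-T5 encoders are typically run with at most 512 input tokens, which is why the
# helpers below trim the retrieved context to stay under that budget.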


def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Clean retrieved document context:
    - removes duplicates while preserving order
    - drops very short chunks (fewer than `min_length` words)
    - caps the total token count at `max_tokens`
    """
    # De-duplicate exact matches while preserving retrieval order.
    unique_texts = list(OrderedDict.fromkeys([doc.page_content.strip() for doc in context_list]))

    # Drop chunks that are too short to be useful (word-count threshold).
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
    # Drop chunks that are substrings of chunks already kept.
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        if not any(text in s for s in seen_texts):
            deduplicated_texts.append(text)
            seen_texts.add(text)
    # Accumulate chunks until the token budget is reached; partially truncate the
    # last chunk if enough budget remains for it to be worth keeping.
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)

        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break

        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."


prompt_template = PromptTemplate(
    template="""
You are a Kubernetes instructor. Answer the question based on the provided context.
If the context does not provide an answer, say "I don't have enough information."

Context:
{context}

Question:
{input}

Answer:
""",
    input_variables=["context", "input"]
)
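
# The {context} and {input} placeholders must match the keys passed to the chain at invoke time.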

llm_chain = LLMChain(llm=llm, prompt=prompt_template)
qa_chain = create_retrieval_chain(retriever, llm_chain)
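# Note: create_retrieval_chain runs the retriever itself and passes the retrieved documents
# to llm_chain as "context", overriding any context supplied at invoke() time. get_k8s_answer
# below therefore calls llm_chain directly so that the cleaned, token-budgeted context is what
# actually reaches the prompt; qa_chain is kept as the plain end-to-end alternative.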


def get_k8s_answer(query):
    # Retrieve the top-k chunks and clean/trim them to the token budget.
    retrieved_context = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)

    # Rough check against Flan-T5's 512-token input limit.
    input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
    total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

    if total_tokens > 512:
        # Re-trim with a tighter budget: leave room for the question plus ~50 tokens
        # of headroom for the prompt instructions.
        allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50
        cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)

        input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
        total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

        if total_tokens > 512:
            return "Error: Even after trimming, input is too large."

    # Call the prompt + LLM directly so the cleaned context is what the model sees
    # (qa_chain would re-run the retriever and override "context" with the raw documents).
    response = llm_chain.invoke({"context": cleaned_context, "input": query})
    return response


def get_k8s_answer_text(query):
    model_full_answer = get_k8s_answer(query)
    # llm_chain returns a dict containing its inputs plus the generated "text".
    if isinstance(model_full_answer, dict) and "text" in model_full_answer:
        return model_full_answer["text"]
    # get_k8s_answer returns a plain error string when the input cannot be trimmed enough.
    if isinstance(model_full_answer, str):
        return model_full_answer
    return "Error"


# Gradio UI: a single textbox in, answer text out.
demo = gr.Interface(
    fn=get_k8s_answer_text,
    inputs=gr.Textbox(label="Ask a Kubernetes Question"),
    outputs=gr.Textbox(label="Answer"),
    title="Kubernetes RAG Assistant",
    description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation."
)
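
# Quick sanity check without the UI (the question below is purely illustrative):
# print(get_k8s_answer_text("What does a Kubernetes Deployment do?"))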


if __name__ == "__main__":
    demo.launch()
|