import os
from collections import OrderedDict

import chromadb
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline

# Load the embeddings model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Load the prebuilt Chroma database (avoids reprocessing the documents on startup)
CHROMA_PATH = "./chroma_db"
if not os.path.exists(CHROMA_PATH):
    raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")

chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
db = Chroma(embedding_function=embeddings, client=chroma_client)

# Load the model
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Create the generation pipeline; fall back to CPU when no GPU is available
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    max_length=512,
    min_length=50,
    do_sample=False,
    repetition_penalty=1.2,
)

# Wrap the pipeline in LangChain and build the retriever
llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = db.as_retriever(search_kwargs={"k": 3})


def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Cleans the retrieved document context:
    - Removes duplicates while preserving order
    - Drops very short fragments (e.g., headers)
    - Limits the total token count so the prompt fits the model's 512-token input window
    """
    # Remove exact duplicates while preserving retrieval order
    unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))

    # Drop very short texts (e.g., headers), which add noise rather than signal
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Drop near-duplicates: chunks already contained in a kept chunk
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        if not any(text in seen for seen in seen_texts):
            deduplicated_texts.append(text)
            seen_texts.add(text)

    # Trim the context to the token budget, truncating the last chunk if needed
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:  # keep a truncated chunk only if it is long enough to be useful
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
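# The app assumes ./chroma_db was built ahead of time. Below is a minimal sketch
# of that one-off ingestion step, assuming the source PDFs live in ./docs; the
# directory name, chunk sizes, and helper name are illustrative, not part of
# the running app.
def build_chroma_db(pdf_dir="./docs", persist_path=CHROMA_PATH):
    import glob

    from langchain_community.document_loaders import PyPDFLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Load every PDF in the directory into LangChain documents
    docs = []
    for pdf_path in glob.glob(os.path.join(pdf_dir, "*.pdf")):
        docs.extend(PyPDFLoader(pdf_path).load())

    # Split into overlapping chunks sized for retrieval
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(docs)

    # Embed the chunks and persist them to disk for the app to load
    Chroma.from_documents(chunks, embedding=embeddings, persist_directory=persist_path)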
# Define the prompt
prompt_template = PromptTemplate(
    template="""You are a Kubernetes instructor. Answer the question based on the provided context.
If the context does not provide an answer, say "I don't have enough information."

Context:
{context}

Question: {input}

Answer:""",
    input_variables=["context", "input"],
)

# Compose the prompt and the model into a chain. Retrieval is done manually in
# get_k8s_answer so the context can be cleaned and trimmed before generation;
# a retrieval chain built on the retriever would silently re-fetch raw documents
# and discard the cleaned context.
qa_chain = prompt_template | llm


# Query function
def get_k8s_answer(query):
    retrieved_docs = retriever.invoke(query)
    cleaned_context = clean_context(retrieved_docs, max_tokens=350)

    # flan-t5's encoder accepts at most 512 input tokens, so measure the full prompt
    prompt_text = prompt_template.format(context=cleaned_context, input=query)
    total_tokens = len(tokenizer.encode(prompt_text, add_special_tokens=True))

    if total_tokens > 512:
        # Trim the context further; the 50-token margin covers the template's fixed wording
        allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50
        cleaned_context = clean_context(retrieved_docs, max_tokens=allowed_tokens)

        # Recheck the total size after trimming
        prompt_text = prompt_template.format(context=cleaned_context, input=query)
        total_tokens = len(tokenizer.encode(prompt_text, add_special_tokens=True))
        if total_tokens > 512:
            return "Error: Even after trimming, input is too large."

    # The chain returns the generated answer as a plain string
    return qa_chain.invoke({"input": query, "context": cleaned_context})


# Gradio interface
demo = gr.Interface(
    fn=get_k8s_answer,
    inputs=gr.Textbox(label="Ask a Kubernetes Question"),
    outputs=gr.Textbox(label="Answer"),
    title="Kubernetes RAG Assistant",
    description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation.",
)

if __name__ == "__main__":
    demo.launch()
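# Example programmatic use, bypassing the Gradio UI (the question is illustrative):
#   answer = get_k8s_answer("How do I expose a Deployment with a Service?")
#   print(answer)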