import os
import chromadb
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain, LLMChain
from langchain.prompts import PromptTemplate
from collections import OrderedDict
# Load embeddings model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Load Chroma database (Avoid reprocessing documents)
CHROMA_PATH = "./chroma_db"
if not os.path.exists(CHROMA_PATH):
    raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
db = Chroma(embedding_function=embeddings, client=chroma_client)
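# The embedding model above must match the one used when ./chroma_db was built;
# otherwise query vectors will not line up with the stored index. If the collection
# was created under a non-default name, it may also need to be passed explicitly
# (hypothetical name, adjust to your index):
# db = Chroma(embedding_function=embeddings, client=chroma_client, collection_name="k8s_docs")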
# Load the model
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Create pipeline
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0,
    max_length=512,
    min_length=50,
    do_sample=False,
    repetition_penalty=1.2
)
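# device=0 assumes a CUDA GPU is available; on a CPU-only host, device=-1 (or omitting
# the argument) is the usual fallback. The 512-token input budget enforced in
# get_k8s_answer below corresponds to flan-t5's default tokenizer limit.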
# Wrap pipeline in LangChain
llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = db.as_retriever(search_kwargs={"k": 3})
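# k=3: fetch the three most similar chunks per query; raising k adds context but
# makes it harder to stay within the 512-token input budget enforced below.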
def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Cleans retrieved document context:
    - Removes duplicates while preserving order
    - Limits total token count
    - Ensures useful, non-repetitive context
    """
    # Preserve order while removing duplicates
    unique_texts = list(OrderedDict.fromkeys([doc.page_content.strip() for doc in context_list]))

    # Remove very short texts (e.g., headers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Avoid near-duplicate entries (texts already contained in a kept chunk)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        if not any(text in s for s in seen_texts):
            deduplicated_texts.append(text)
            seen_texts.add(text)

    # Limit context based on token count
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)
        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break
        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
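# Illustrative use (hypothetical query), assuming the store holds Kubernetes docs:
#   docs = retriever.get_relevant_documents("What is a Pod?")
#   context = clean_context(docs, max_tokens=350)
# `context` is then a deduplicated, token-limited string ready to drop into the prompt.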
# Define prompt
prompt_template = PromptTemplate(
    template="""
You are a Kubernetes instructor. Answer the question based on the provided context.
If the context does not provide an answer, say "I don't have enough information."

Context:
{context}

Question:
{input}

Answer:
""",
    input_variables=["context", "input"]
)
llm_chain = LLMChain(llm=llm, prompt=prompt_template)
qa_chain = create_retrieval_chain(retriever, llm_chain)
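# Assumed output shape: create_retrieval_chain is expected to return a dict with
# "input", "context" and "answer" keys, where "answer" carries the LLMChain result
# (itself a dict containing "text"); hence the unpacking in get_k8s_answer_text below.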
# Query function
def get_k8s_answer(query):
    retrieved_context = retriever.get_relevant_documents(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)  # Ensure context size is within limits

    # Ensure total input tokens < 512 before passing to model
    input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
    total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

    if total_tokens > 512:
        # Trim context further to fit within the limit
        allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50  # 50 tokens for the model's response
        cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)

        # Recalculate total tokens
        input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
        total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

        if total_tokens > 512:
            return "Error: Even after trimming, input is too large."

    response = qa_chain.invoke({"input": query, "context": cleaned_context})
    return response
def get_k8s_answer_text(query):
    model_full_answer = get_k8s_answer(query)
    # get_k8s_answer returns a plain string when trimming fails, so guard before indexing
    if isinstance(model_full_answer, str):
        return model_full_answer
    if isinstance(model_full_answer, dict) and 'answer' in model_full_answer:
        answer = model_full_answer['answer']
        if isinstance(answer, dict) and 'text' in answer:
            return answer['text']
        if isinstance(answer, str):
            return answer
    return "Error"
# Gradio Interface
demo = gr.Interface(
    fn=get_k8s_answer_text,
    inputs=gr.Textbox(label="Ask a Kubernetes Question"),
    outputs=gr.Textbox(label="Answer"),
    title="Kubernetes RAG Assistant",
    description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation."
)
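# demo.launch() serves on a local port by default; launch(share=True) would also
# create a temporary public Gradio link if the app needs to be reached externally.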
if __name__ == "__main__":
    demo.launch()