import os
import torch
import chromadb
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain.chains import create_retrieval_chain, LLMChain
from langchain.prompts import PromptTemplate
from collections import OrderedDict

# Load embeddings model
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Load the pre-built Chroma database (avoids reprocessing the source documents)
CHROMA_PATH = "./chroma_db"
if not os.path.exists(CHROMA_PATH):
    raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")

chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
db = Chroma(embedding_function=embeddings, client=chroma_client)
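# No collection_name is given, so the default Chroma collection is used; the persisted
# chroma_db folder is assumed to have been built with that same default.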

# Load the model
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Create pipeline
qa_pipeline = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,  # use the GPU when available, otherwise fall back to CPU
    max_length=512,
    min_length=50,
    do_sample=False,
    repetition_penalty=1.2
)
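# Note: the query helpers below budget 512 tokens for the model input, so the
# retrieved context is trimmed (to ~350 tokens by default) before prompting.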

# Wrap pipeline in LangChain
llm = HuggingFacePipeline(pipeline=qa_pipeline)
retriever = db.as_retriever(search_kwargs={"k": 3})


def clean_context(context_list, max_tokens=350, min_length=50):
    """
    Cleans retrieved document context:
    - Removes duplicates while preserving order
    - Limits total token count
    - Ensures useful, non-repetitive context
    """

    # Preserve order while removing duplicates
    unique_texts = list(OrderedDict.fromkeys([doc.page_content.strip() for doc in context_list]))

    # Drop very short chunks (min_length words or fewer, e.g. headers)
    filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]

    # Drop texts that are contained in an already-kept text (near-duplicates)
    deduplicated_texts = []
    seen_texts = set()
    for text in filtered_texts:
        if not any(text in s for s in seen_texts):
            deduplicated_texts.append(text)
            seen_texts.add(text)

    # Limit context based on token count
    trimmed_context = []
    total_tokens = 0
    for text in deduplicated_texts:
        tokenized_text = tokenizer.encode(text, add_special_tokens=False)
        token_count = len(tokenized_text)

        if total_tokens + token_count > max_tokens:
            remaining_tokens = max_tokens - total_tokens
            if remaining_tokens > 20:
                trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
            break

        trimmed_context.append(text)
        total_tokens += token_count

    return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."

# Define prompt
prompt_template = PromptTemplate(
    template="""
    You are a Kubernetes instructor. Answer the question based on the provided context.
    If the context does not provide an answer, say "I don't have enough information."

    Context:
    {context}

    Question:
    {input}

    Answer:
    """,
    input_variables=["context", "input"]
)

llm_chain = LLMChain(llm=llm, prompt=prompt_template)
qa_chain = create_retrieval_chain(retriever, llm_chain)
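# The retrieval chain returns a dict with 'input', 'context' (the retrieved documents)
# and 'answer'; because the combine step is an LLMChain, the generated text ends up
# nested at response['answer']['text'] (see get_k8s_answer_text below).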

# Query function
def get_k8s_answer(query):
    retrieved_context = retriever.get_relevant_documents(query)
    cleaned_context = clean_context(retrieved_context, max_tokens=350)  # Ensure context size is within limits

    # Ensure total input tokens < 512 before passing to model
    input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
    total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

    if total_tokens > 512:
        # Trim the context further so the whole prompt fits within the 512-token budget
        # (reserve room for the query plus a ~50-token safety margin for the template text)
        allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50
        cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)

        # Recalculate total tokens
        input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
        total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))

        if total_tokens > 512:
            return "Error: Even after trimming, input is too large."

    response = qa_chain.invoke({"input": query, "context": cleaned_context})
    return response

def get_k8s_answer_text(query):
    model_full_answer = get_k8s_answer(query)
    # get_k8s_answer returns a plain error string when the input cannot be trimmed enough
    if isinstance(model_full_answer, str):
        return model_full_answer
    if 'answer' in model_full_answer and isinstance(model_full_answer['answer'], dict):
        if 'text' in model_full_answer['answer']:
            return model_full_answer['answer']['text']
    return "Error"

# Gradio Interface
demo = gr.Interface(
    fn=get_k8s_answer_text,
    inputs=gr.Textbox(label="Ask a Kubernetes Question"),
    outputs=gr.Textbox(label="Answer"),
    title="Kubernetes RAG Assistant",
    description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation."
)

if __name__ == "__main__":
    demo.launch()