# backend/main.py
import os
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles

from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline        # NEW
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

# -------- optional OpenAI imports (kept, but disabled) ----------
# from langchain.llms import OpenAI
# from langchain.embeddings import OpenAIEmbeddings
# ---------------------------------------------------------------

from ingest import Ingest
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
load_dotenv()
HF_TOKEN       = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN not set in the environment.")

# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")        # Optional
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")

# ------------------------------------------------------------------
# 2. LLM & EMBEDDINGS CONFIGURATION
# ------------------------------------------------------------------
DEFAULT_LLM = "google/gemma-3-4b-it"                    # change here if desired
EMB_EN      = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ      = "Seznam/retromae-small-cs"

def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Creates a HuggingFacePipeline wrapped inside LangChain's LLM interface.
    Works on CPU; uses half precision automatically when CUDA is available.
    """
    dtype   = torch.float16 if torch.cuda.is_available() else torch.float32
    model   = AutoModelForCausalLM.from_pretrained(
        model_id,
        token            = HF_TOKEN,
        torch_dtype      = dtype,
        device_map       = "auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    gen_pipe  = pipeline(
        task              = "text-generation",
        model             = model,
        tokenizer         = tokenizer,
        max_new_tokens    = 512,
        do_sample         = True,    # sampling must be on for temperature/top_p to take effect
        temperature       = 0.2,
        top_p             = 0.95,
        return_full_text  = False,   # return only the completion, not the echoed prompt
    )
    return HuggingFacePipeline(pipeline=gen_pipe)

HF_LLM = build_hf_llm()           # Initialise once; reuse in every request
# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)   # optional
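
# Optional smoke test from a Python shell (illustrative only; assumes the
# model weights downloaded successfully and fit in memory):
#   >>> from main import HF_LLM
#   >>> print(HF_LLM.invoke("State Ohm's law in one sentence."))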

# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# Embedding stores
CZECH_STORE   = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"

ingestor = Ingest(
    # openai_api_key = OPENAI_API_KEY,   # still needed only if you ingest via OpenAI embeds
    chunk                   = 512,
    overlap                 = 256,
    czech_store             = CZECH_STORE,
    english_store           = ENGLISH_STORE,
    czech_embedding_model   = EMB_CZ,
    english_embedding_model = EMB_EN,
)
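
# The FAISS stores above are written on first ingestion; /get_response
# assumes they already exist on disk before a query arrives.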

# ------------------------------------------------------------------
# 4. PROMPTS
# ------------------------------------------------------------------
def prompt_en() -> PromptTemplate:
    tmpl = """You are an electrical engineer and you answer users' ###Question.
# Your answer must be helpful, relevant and closely related to the user's ###Question.
# Quote literally from the ###Context wherever possible.
# Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])

def prompt_cz() -> PromptTemplate:
    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
# Citujte co nejvíce doslovně z ###Kontextu.
# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
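
# Illustrative render (made-up values) of what the "stuff" chain will feed
# the LLM:
#   >>> print(prompt_en().format(context="Ohm's law: U = R * I",
#   ...                          question="What is Ohm's law?"))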

# ------------------------------------------------------------------
# 5. ROUTES
# ------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/ingest_data")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
    if language.lower() == "czech":
        ingestor.data_czech = folderPath
        ingestor.ingest_czech()
        return {"message": "Czech data ingestion complete."}
    ingestor.data_english = folderPath
    ingestor.ingest_english()
    return {"message": "English data ingestion complete."}
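
# Example requests (hypothetical folder paths; assumes uvicorn's default port):
#   curl -X POST -F "folderPath=data/czech"   -F "language=czech"   http://localhost:8000/ingest_data
#   curl -X POST -F "folderPath=data/english" -F "language=english" http://localhost:8000/ingest_data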

@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):

    is_czech   = language.lower() == "czech"
    prompt     = prompt_cz() if is_czech else prompt_en()
    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
    embed_name = EMB_CZ if is_czech else EMB_EN

    embeddings = HuggingFaceEmbeddings(
        model_name   = embed_name,
        model_kwargs = {"device": "cpu"},
        encode_kwargs= {"normalize_embeddings": False}
    )
    # Recent langchain_community releases require opting in to pickle
    # deserialization when loading a local FAISS index.
    vectordb  = FAISS.load_local(
        store_path, embeddings, allow_dangerous_deserialization=True
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 2})

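    # "stuff" concatenates the k retrieved chunks straight into the prompt's
    # {context} slot; fine for k=2 with 512-character chunks, but the model's
    # context window becomes the limit if k or the chunk size grows.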
    qa_chain = RetrievalQA.from_chain_type(
        llm                    = HF_LLM,           # <- default open-source model
        # llm                  = OPENAI_LLM,       # <- optional paid model
        chain_type             = "stuff",
        retriever              = retriever,
        return_source_documents= True,
        chain_type_kwargs      = {"prompt": prompt},
        verbose                = True,
    )

    result   = qa_chain.invoke({"query": query})
    answer   = result["result"]
    src_doc  = result["source_documents"][0].page_content
    src_path = result["source_documents"][0].metadata["source"]

    return JSONResponse({
        "answer"          : answer,
        "source_document" : src_doc,
        "doc"             : src_path,
    })
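
# ------------------------------------------------------------------
# 6. ENTRY POINT
# ------------------------------------------------------------------
# Minimal sketch for running locally; assumes uvicorn is installed and
# port 8000 is free:
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)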