# backend/main.py
import os
import json
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form, Response
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline # NEW
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# -------- optional OpenAI imports (kept, but disabled) ----------
# from langchain.llms import OpenAI
# from langchain.embeddings import OpenAIEmbeddings
# ---------------------------------------------------------------
from ingest import Ingest
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
# ------------------------------------------------------------------
# 1. ENVIRONMENT
# ------------------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN not set in the environment.")
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Optional
# if OPENAI_API_KEY is None:
#     print("OpenAI key missing – OpenAI path disabled.")
# ------------------------------------------------------------------
# 2. LLM & EMBEDDINGS CONFIGURATION
# ------------------------------------------------------------------
DEFAULT_LLM = "google/gemma-3-4b-it" # change here if desired
EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
EMB_CZ = "Seznam/retromae-small-cs"
def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
    """
    Create a HuggingFacePipeline wrapped in LangChain's LLM interface.
    Works on CPU; uses half precision automatically when CUDA is available.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=HF_TOKEN,
        torch_dtype=dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    gen_pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        temperature=0.2,
        top_p=0.95,
    )
    return HuggingFacePipeline(pipeline=gen_pipe)
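# Quick smoke test (a sketch, not part of the app flow; the prompt text is
# illustrative and assumes the model weights have downloaded successfully):
#   build_hf_llm().invoke("State Ohm's law in one sentence.")
# returns the generated continuation as a plain string.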
HF_LLM = build_hf_llm() # Initialise once; reuse in every request
# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0) # optional
# ------------------------------------------------------------------
# 3. FASTAPI PLUMBING
# ------------------------------------------------------------------
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Embedding stores
CZECH_STORE = "stores/czech_512"
ENGLISH_STORE = "stores/english_512"
ingestor = Ingest(
    # openai_api_key=OPENAI_API_KEY,  # needed only if you ingest via OpenAI embeddings
    chunk=512,
    overlap=256,
    czech_store=CZECH_STORE,
    english_store=ENGLISH_STORE,
    czech_embedding_model=EMB_CZ,
    english_embedding_model=EMB_EN,
)
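# Each store directory is written by FAISS.save_local() during ingestion and is
# expected to contain two files (layout sketch):
#   stores/czech_512/index.faiss   # the raw FAISS index
#   stores/czech_512/index.pkl     # pickled docstore and id mapping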
# ------------------------------------------------------------------
# 4. PROMPTS
# ------------------------------------------------------------------
def prompt_en() -> PromptTemplate:
    tmpl = """You are an electrical engineer and you answer users' ###Question.
# Your answer must be helpful, relevant and closely related to the user's ###Question.
# Quote literally from the ###Context wherever possible.
# Use your own words only to connect or clarify. If you don't know, say so.
###Context: {context}
###Question: {question}
Helpful answer:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
def prompt_cz() -> PromptTemplate:
    # Czech counterpart of prompt_en(); the instructions mirror the English template.
    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
# Citujte co nejvíce doslovně z ###Kontextu.
# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
###Kontext: {context}
###Otázka: {question}
Užitečná odpověď:
"""
    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
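# Rendering a prompt with dummy values shows the exact string the "stuff" chain
# sends to the LLM (the placeholder values here are illustrative):
#   prompt_en().format(context="<retrieved chunks>", question="<user query>")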
# ------------------------------------------------------------------
# 5. ROUTES
# ------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/ingest_data")
async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
if language.lower() == "czech":
ingestor.data_czech = folderPath
ingestor.ingest_czech()
return {"message": "Czech data ingestion complete."}
ingestor.data_english = folderPath
ingestor.ingest_english()
return {"message": "English data ingestion complete."}
@app.post("/get_response")
async def get_response(query: str = Form(...), language: str = Form(...)):
is_czech = language.lower() == "czech"
prompt = prompt_cz() if is_czech else prompt_en()
store_path = CZECH_STORE if is_czech else ENGLISH_STORE
embed_name = EMB_CZ if is_czech else EMB_EN
embeddings = HuggingFaceEmbeddings(
model_name = embed_name,
model_kwargs = {"device": "cpu"},
encode_kwargs= {"normalize_embeddings": False}
)
vectordb = FAISS.load_local(store_path, embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 2})
qa_chain = RetrievalQA.from_chain_type(
llm = HF_LLM, # <- default open-source model
# llm = OPENAI_LLM, # <- optional paid model
chain_type = "stuff",
retriever = retriever,
return_source_documents= True,
chain_type_kwargs = {"prompt": prompt},
verbose = True,
)
result = qa_chain(query)
answer = result["result"]
src_doc = result["source_documents"][0].page_content
src_path = result["source_documents"][0].metadata["source"]
    # Serialize the answer plus its best-matching source chunk and file path.
    payload = json.dumps({
        "answer": answer,
        "source_document": src_doc,
        "doc": src_path,
    })
    return Response(content=payload, media_type="application/json")
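# Example call (sketch; host/port assume a default local uvicorn run):
#   curl -X POST http://localhost:8000/get_response \
#        -F "query=What is rated breaking capacity?" -F "language=english"

if __name__ == "__main__":
    # Minimal local runner; assumes uvicorn is installed alongside FastAPI.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)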