Spaces:
Runtime error
Update fast_app.py
fast_app.py (+132 -131)
CHANGED
@@ -1,6 +1,7 @@
-…
+# backend/main.py
 import os
 import json
+from dotenv import load_dotenv
 from fastapi import FastAPI, Request, Form, Response
 from fastapi.responses import HTMLResponse
 from fastapi.templating import Jinja2Templates
@@ -8,163 +9,163 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.encoders import jsonable_encoder

 from langchain_community.vectorstores import FAISS
+from langchain_community.llms import HuggingFacePipeline  # NEW
+from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-…
 from langchain.chains import RetrievalQA
-…
-from langchain.llms import OpenAI
 from langchain import PromptTemplate
-from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
-
-from ingest import Ingest
-
-# setx OPENAI_API_KEY "your_openai_api_key_here"
-#
-#
-#
-#

+# -------- optional OpenAI imports (kept, but disabled) ----------
+# from langchain.llms import OpenAI
+# from langchain.embeddings import OpenAIEmbeddings
+# ---------------------------------------------------------------
+
+from ingest import Ingest
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+
+# ------------------------------------------------------------------
+# 1. ENVIRONMENT
+# ------------------------------------------------------------------
-…
+load_dotenv()
+HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+if HF_TOKEN is None:
+    raise ValueError("HUGGINGFACE_TOKEN not set in the environment.")
+
+# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Optional
+# if OPENAI_API_KEY is None:
+#     print("OpenAI key missing – OpenAI path disabled.")
+
+# ------------------------------------------------------------------
+# 2. LLM & EMBEDDINGS CONFIGURATION
+# ------------------------------------------------------------------
+DEFAULT_LLM = "google/gemma-3-4b-it"  # change here if desired
+EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
+EMB_CZ = "Seznam/retromae-small-cs"
+
+def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
+    """
+    Creates a HuggingFacePipeline wrapped inside LangChain's LLM interface.
+    Works on CPU; uses half precision automatically when CUDA is available.
+    """
+    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        token=HF_TOKEN,
+        torch_dtype=dtype,
+        device_map="auto",
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+    gen_pipe = pipeline(
+        task="text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=512,
+        temperature=0.2,
+        top_p=0.95,
+    )
+    return HuggingFacePipeline(pipeline=gen_pipe)

+HF_LLM = build_hf_llm()  # Initialise once; reuse in every request
+# OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0)  # optional

+# ------------------------------------------------------------------
+# 3. FASTAPI PLUMBING
+# ------------------------------------------------------------------
 app = FastAPI()
 templates = Jinja2Templates(directory="templates")
 app.mount("/static", StaticFiles(directory="static"), name="static")
-english_embedding_model = "text-embedding-3-large"
-czech_embedding_model = "Seznam/retromae-small-cs"
-…

+# Embedding stores
+CZECH_STORE = "stores/czech_512"
+ENGLISH_STORE = "stores/english_512"

 ingestor = Ingest(
-    openai_api_key=…,
-    chunk=512,
-    overlap=256,
-    czech_store=…,
-    english_store=…,
-    czech_embedding_model=…,
-    english_embedding_model=…,
+    # openai_api_key=OPENAI_API_KEY,  # still needed only if you ingest via OpenAI embeds
+    chunk=512,
+    overlap=256,
+    czech_store=CZECH_STORE,
+    english_store=ENGLISH_STORE,
+    czech_embedding_model=EMB_CZ,
+    english_embedding_model=EMB_EN,
 )

-…
-#Pokud odpověď neznáte, prostě řekněte, že to nevíte, nepokoušejte se vymýšlet odpověď.
-
-###Kontext: {context}
-###Otázka: {question}
-
-Níže vraťte pouze užitečnou odpověď a nic jiného.
-Užitečná odpověď:
-"""
-    prompt_cz = PromptTemplate(
-        template=prompt_template_cz, input_variables=["context", "question"]
-    )
-    print("\n Prompt ready... \n\n")
-    return prompt_cz
-
+# ------------------------------------------------------------------
+# 4. PROMPTS
+# ------------------------------------------------------------------
+def prompt_en() -> PromptTemplate:
+    tmpl = """You are an electrical engineer and you answer users' ###Question.
+# Your answer must be helpful, relevant and closely related to the user's ###Question.
+# Quote literally from the ###Context wherever possible.
+# Use your own words only to connect or clarify. If you don't know, say so.
+###Context: {context}
+###Question: {question}
+Helpful answer:
+"""
+    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
+
+def prompt_cz() -> PromptTemplate:
+    tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
+# Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
+# Citujte co nejvíce doslovně z ###Kontextu.
+# Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
+###Kontext: {context}
+###Otázka: {question}
+Užitečná odpověď:
+"""
+    return PromptTemplate(template=tmpl, input_variables=["context", "question"])
+
+# ------------------------------------------------------------------
+# 5. ROUTES
+# ------------------------------------------------------------------
 @app.get("/", response_class=HTMLResponse)
-def …
+def home(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})

 @app.post("/ingest_data")
 async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
-    if language == "czech":
-        print("\n Czech language selected....\n\n")
+    if language.lower() == "czech":
         ingestor.data_czech = folderPath
         ingestor.ingest_czech()
-        message …
-…
-        ingestor.ingest_english()
-        message = "English data ingestion complete."
-    return {"message": message}
+        return {"message": "Czech data ingestion complete."}
+    ingestor.data_english = folderPath
+    ingestor.ingest_english()
+    return {"message": "English data ingestion complete."}

 @app.post("/get_response")
 async def get_response(query: str = Form(...), language: str = Form(...)):
-    print(language)
-    if language == "czech":
-        prompt = prompt_cz()
-        print("\n Czech language selected....\n\n")
-        embedding_model = czech_embedding_model
-        persist_directory = czech_store
-        model_name = embedding_model
-        model_kwargs = {"device": "cpu"}
-        encode_kwargs = {"normalize_embeddings": False}
-        embedding = HuggingFaceEmbeddings(
-            model_name=model_name,
-            model_kwargs=model_kwargs,
-            encode_kwargs=encode_kwargs,
-        )
-    else:
-        prompt = prompt_en()
-        print("\n English language selected....\n\n")
-        embedding_model = english_embedding_model  # Default to English
-        persist_directory = english_store
-        embedding = OpenAIEmbeddings(
-            openai_api_key=openai_api_key,
-            model=embedding_model,
-        )
-
-    vectordb = FAISS.load_local(persist_directory, embedding)
-    retriever = vectordb.as_retriever(search_kwargs={"k": 2})
-…
-        retriever=retriever,
-        return_source_documents=True,
-        chain_type_kwargs=chain_type_kwargs,
-        verbose=True,
-    )
-    response = qa_chain(query)
-…
+    is_czech = language.lower() == "czech"
+    prompt = prompt_cz() if is_czech else prompt_en()
+    store_path = CZECH_STORE if is_czech else ENGLISH_STORE
+    embed_name = EMB_CZ if is_czech else EMB_EN
+
+    embeddings = HuggingFaceEmbeddings(
+        model_name=embed_name,
+        model_kwargs={"device": "cpu"},
+        encode_kwargs={"normalize_embeddings": False},
+    )
+    vectordb = FAISS.load_local(store_path, embeddings)
+    retriever = vectordb.as_retriever(search_kwargs={"k": 2})
+
+    qa_chain = RetrievalQA.from_chain_type(
+        llm=HF_LLM,        # <- default open-source model
+        # llm=OPENAI_LLM,  # <- optional paid model
+        chain_type="stuff",
+        retriever=retriever,
+        return_source_documents=True,
+        chain_type_kwargs={"prompt": prompt},
+        verbose=True,
+    )
+
+    result = qa_chain(query)
+    answer = result["result"]
+    src_doc = result["source_documents"][0].page_content
+    src_path = result["source_documents"][0].metadata["source"]
+
+    payload = jsonable_encoder(json.dumps({
+        "answer": answer,
+        "source_document": src_doc,
+        "doc": src_path,
+    }))
+    return Response(payload)
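
Two things are worth checking while the Space sits at "Runtime error". First, startup aborts by design when the token is missing: `load_dotenv()` only reads a local `.env`, while on Spaces `HUGGINGFACE_TOKEN` has to be set as a repository secret so it shows up as an environment variable. A tiny local sanity check; the token line shown in the comment is a placeholder, not a real value:

# check_env.py - minimal sketch; assumes a .env next to fast_app.py
# containing a line like:  HUGGINGFACE_TOKEN=hf_xxx   (placeholder)
import os

from dotenv import load_dotenv

load_dotenv()
print("HUGGINGFACE_TOKEN set:", os.getenv("HUGGINGFACE_TOKEN") is not None)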
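
Second, recent langchain_community releases refuse to unpickle a FAISS index unless the caller opts in, so `FAISS.load_local(store_path, embeddings)` can raise at request time even after a successful ingest. A hedged variant of the load inside get_response, assuming a release new enough to accept the keyword (older ones do not); only enable it for index files you created yourself:

# Drop-in for the FAISS load in get_response; store_path and embeddings
# are the names defined just above it in fast_app.py.
vectordb = FAISS.load_local(
    store_path,
    embeddings,
    allow_dangerous_deserialization=True,  # opts in to pickle loading
)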
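
Once the app is being served (locally, something like `uvicorn fast_app:app --port 8000`), note that both endpoints take form fields rather than JSON, and `/get_response` returns a JSON-encoded string in the response body. A minimal `requests` client for manual testing; the base URL, folder path, and example question are assumptions:

# smoke_test.py - minimal sketch for exercising the two endpoints.
import json

import requests

BASE = "http://127.0.0.1:8000"  # assumption: local dev server

# 1) Ingest a folder of Czech documents (folder path is a placeholder).
r = requests.post(
    f"{BASE}/ingest_data",
    data={"folderPath": "data/czech", "language": "czech"},
)
print(r.json()["message"])

# 2) Ask a question; the body is a JSON string, so decode it manually.
r = requests.post(
    f"{BASE}/get_response",
    data={"query": "Jak zapojit proudový chránič?", "language": "czech"},
)
payload = json.loads(r.text)
print(payload["answer"])
print("source:", payload["doc"])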