Teapack1 commited on
Commit
25d1cc3
·
verified ·
1 Parent(s): 0bdb69e

Update fast_app.py

Browse files
Files changed (1) hide show
  1. fast_app.py +132 -131
fast_app.py CHANGED
@@ -1,6 +1,7 @@
1
- from dotenv import load_dotenv
2
  import os
3
  import json
 
4
  from fastapi import FastAPI, Request, Form, Response
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.templating import Jinja2Templates
@@ -8,163 +9,163 @@ from fastapi.staticfiles import StaticFiles
8
  from fastapi.encoders import jsonable_encoder
9
 
10
  from langchain_community.vectorstores import FAISS
 
 
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
-
13
  from langchain.chains import RetrievalQA
14
-
15
- from langchain.llms import OpenAI
16
  from langchain import PromptTemplate
17
- from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
18
-
19
- from ingest import Ingest
20
-
21
- # setx OPENAI_API_KEY "your_openai_api_key_here"
22
 
23
- # Access the Hugging Face API token from an environment variable
24
- # huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
25
- # if huggingface_token is None:
26
- # raise ValueError("Hugging Face token is not set in environment variables.")
27
 
28
- openai_api_key = os.getenv("OPENAI_API_KEY")
29
- if openai_api_key is None:
30
- raise ValueError("OAI token is not set in environment variables.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
32
 
 
 
 
33
  app = FastAPI()
34
  templates = Jinja2Templates(directory="templates")
35
  app.mount("/static", StaticFiles(directory="static"), name="static")
36
- english_embedding_model = "text-embedding-3-large"
37
- czech_embedding_model = "Seznam/retromae-small-cs"
38
 
39
- czech_store = "stores/czech_512"
40
- english_store = "stores/english_512"
 
41
 
42
  ingestor = Ingest(
43
- openai_api_key=openai_api_key,
44
- chunk=512,
45
- overlap=256,
46
- czech_store=czech_store,
47
- english_store=english_store,
48
- czech_embedding_model=czech_embedding_model,
49
- english_embedding_model=english_embedding_model,
50
  )
51
 
52
-
53
- def prompt_en():
54
- prompt_template_en = """You are electrical engineer and you answer users ###Question.
55
-
56
- #Your answer has to be helpful, relevant and closely related to the user's ###Question.
57
- #Provide as much literal information and transcription from the #Context as possible.
58
- #Only use your own words to connect, clarify or explain the information!
59
- #If you don't know the answer, just say that you don't know, don't try to make up an answer.
60
-
61
- ###Context: {context}
62
- ###Question: {question}
63
-
64
- Only return the helpful answer below and nothing else.
65
- Helpful answer:
66
- """
67
- prompt_en = PromptTemplate(
68
- template=prompt_template_en, input_variables=["context", "question"]
69
- )
70
- print("\n Prompt ready... \n\n")
71
- return prompt_en
72
-
73
-
74
- def prompt_cz():
75
- prompt_template_cz = """Jste elektroinženýr a odpovídáte uživatelům na ###Otázku.
76
-
77
- #Vaše odpověď musí být užitečná, relevantní a úzce souviset s uživatelovou ###Otázkou.
78
- #Poskytněte co nejvíce doslovných informací a přepisů z #Kontextu.
79
- #Použijte vlastní slova pouze pro spojení, objasnění nebo vysvětlení informací!
80
- #Pokud odpověď neznáte, prostě řekněte, že to nevíte, nepokoušejte se vymýšlet odpověď.
81
-
82
- ###Kontext: {context}
83
- ###Otázka: {question}
84
-
85
- Níže vraťte pouze užitečnou odpověď a nic jiného.
86
- Užitečná odpověď:
87
- """
88
- prompt_cz = PromptTemplate(
89
- template=prompt_template_cz, input_variables=["context", "question"]
90
- )
91
- print("\n Prompt ready... \n\n")
92
- return prompt_cz
93
-
94
-
95
  @app.get("/", response_class=HTMLResponse)
96
- def read_item(request: Request):
97
  return templates.TemplateResponse("index.html", {"request": request})
98
 
99
-
100
  @app.post("/ingest_data")
101
  async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
102
- # Determine the correct data path and store based on the language
103
- if language == "czech":
104
- print("\n Czech language selected....\n\n")
105
  ingestor.data_czech = folderPath
106
  ingestor.ingest_czech()
107
- message = "Czech data ingestion complete."
108
- else:
109
- print("\n English language selected....\n\n")
110
- ingestor.data_english = folderPath
111
- ingestor.ingest_english()
112
- message = "English data ingestion complete."
113
-
114
- return {"message": message}
115
-
116
 
117
  @app.post("/get_response")
118
  async def get_response(query: str = Form(...), language: str = Form(...)):
119
- print(language)
120
- if language == "czech":
121
- prompt = prompt_cz()
122
- print("\n Czech language selected....\n\n")
123
- embedding_model = czech_embedding_model
124
- persist_directory = czech_store
125
- model_name = embedding_model
126
- model_kwargs = {"device": "cpu"}
127
- encode_kwargs = {"normalize_embeddings": False}
128
- embedding = HuggingFaceEmbeddings(
129
- model_name=model_name,
130
- model_kwargs=model_kwargs,
131
- encode_kwargs=encode_kwargs,
132
- )
133
- else:
134
- prompt = prompt_en()
135
- print("\n English language selected....\n\n")
136
- embedding_model = english_embedding_model # Default to English
137
- persist_directory = english_store
138
- embedding = OpenAIEmbeddings(
139
- openai_api_key=openai_api_key,
140
- model=embedding_model,
141
- )
142
-
143
- vectordb = FAISS.load_local(persist_directory, embedding)
144
- retriever = vectordb.as_retriever(search_kwargs={"k": 2})
145
 
146
- chain_type_kwargs = {"prompt": prompt}
147
- qa_chain = RetrievalQA.from_chain_type(
148
- llm=OpenAI(openai_api_key=openai_api_key),
149
- chain_type="stuff",
150
- retriever=retriever,
151
- return_source_documents=True,
152
- chain_type_kwargs=chain_type_kwargs,
153
- verbose=True,
154
- )
155
- response = qa_chain(query)
156
 
157
- for i in response["source_documents"]:
158
- print(f"\n{i}\n\n")
159
-
160
- print(response)
 
 
 
161
 
162
- answer = response["result"]
163
- source_document = response["source_documents"][0].page_content
164
- doc = response["source_documents"][0].metadata["source"]
165
- response_data = jsonable_encoder(
166
- json.dumps({"answer": answer, "source_document": source_document, "doc": doc})
 
 
 
167
  )
168
 
169
- res = Response(response_data)
170
- return res
 
 
 
 
 
 
 
 
 
 
1
+ # backend/main.py
2
  import os
3
  import json
4
+ from dotenv import load_dotenv
5
  from fastapi import FastAPI, Request, Form, Response
6
  from fastapi.responses import HTMLResponse
7
  from fastapi.templating import Jinja2Templates
 
9
  from fastapi.encoders import jsonable_encoder
10
 
11
  from langchain_community.vectorstores import FAISS
12
+ from langchain_community.llms import HuggingFacePipeline # NEW
13
+ from langchain.embeddings import HuggingFaceEmbeddings
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
15
  from langchain.chains import RetrievalQA
 
 
16
  from langchain import PromptTemplate
 
 
 
 
 
17
 
18
+ # -------- optional OpenAI imports (kept, but disabled) ----------
19
+ # from langchain.llms import OpenAI
20
+ # from langchain.embeddings import OpenAIEmbeddings
21
+ # ---------------------------------------------------------------
22
 
23
+ from ingest import Ingest
24
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
25
+ import torch
26
+
27
+ # ------------------------------------------------------------------
28
+ # 1. ENVIRONMENT
29
+ # ------------------------------------------------------------------
30
+ load_dotenv()
31
+ HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
32
+ if HF_TOKEN is None:
33
+ raise ValueError("HUGGINGFACE_TOKEN not set in the environment.")
34
+
35
+ # OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # Optional
36
+ # if OPENAI_API_KEY is None:
37
+ # print("OpenAI key missing – OpenAI path disabled.")
38
+
39
+ # ------------------------------------------------------------------
40
+ # 2. LLM & EMBEDDINGS CONFIGURATION
41
+ # ------------------------------------------------------------------
42
+ DEFAULT_LLM = "google/gemma-3-4b-it" # change here if desired
43
+ EMB_EN = "sentence-transformers/all-MiniLM-L6-v2"
44
+ EMB_CZ = "Seznam/retromae-small-cs"
45
+
46
+ def build_hf_llm(model_id: str = DEFAULT_LLM) -> HuggingFacePipeline:
47
+ """
48
+ Creates a HuggingFacePipeline wrapped inside LangChain's LLM interface.
49
+ Works on CPU; uses half precision automatically when CUDA is available.
50
+ """
51
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
52
+ model = AutoModelForCausalLM.from_pretrained(
53
+ model_id,
54
+ token = HF_TOKEN,
55
+ torch_dtype = dtype,
56
+ device_map = "auto"
57
+ )
58
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
59
+ gen_pipe = pipeline(
60
+ task = "text-generation",
61
+ model = model,
62
+ tokenizer = tokenizer,
63
+ max_new_tokens = 512,
64
+ temperature = 0.2,
65
+ top_p = 0.95,
66
+ )
67
+ return HuggingFacePipeline(pipeline=gen_pipe)
68
 
69
+ HF_LLM = build_hf_llm() # Initialise once; reuse in every request
70
+ # OPENAI_LLM = OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0) # optional
71
 
72
+ # ------------------------------------------------------------------
73
+ # 3. FASTAPI PLUMBING
74
+ # ------------------------------------------------------------------
75
  app = FastAPI()
76
  templates = Jinja2Templates(directory="templates")
77
  app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
78
 
79
+ # Embedding stores
80
+ CZECH_STORE = "stores/czech_512"
81
+ ENGLISH_STORE = "stores/english_512"
82
 
83
  ingestor = Ingest(
84
+ # openai_api_key = OPENAI_API_KEY, # still needed only if you ingest via OpenAI embeds
85
+ chunk = 512,
86
+ overlap = 256,
87
+ czech_store = CZECH_STORE,
88
+ english_store = ENGLISH_STORE,
89
+ czech_embedding_model = EMB_CZ,
90
+ english_embedding_model = EMB_EN,
91
  )
92
 
93
+ # ------------------------------------------------------------------
94
+ # 4. PROMPTS
95
+ # ------------------------------------------------------------------
96
+ def prompt_en() -> PromptTemplate:
97
+ tmpl = """You are an electrical engineer and you answer users' ###Question.
98
+ # Your answer must be helpful, relevant and closely related to the user's ###Question.
99
+ # Quote literally from the ###Context wherever possible.
100
+ # Use your own words only to connect or clarify. If you don't know, say so.
101
+ ###Context: {context}
102
+ ###Question: {question}
103
+ Helpful answer:
104
+ """
105
+ return PromptTemplate(template=tmpl, input_variables=["context", "question"])
106
+
107
+ def prompt_cz() -> PromptTemplate:
108
+ tmpl = """Jste elektroinženýr a odpovídáte na ###Otázku.
109
+ # Odpověď musí být užitečná, relevantní a úzce souviset s ###Otázkou.
110
+ # Citujte co nejvíce doslovně z ###Kontextu.
111
+ # Vlastními slovy pouze propojujte nebo vysvětlujte. Nevíte-li, řekněte to.
112
+ ###Kontext: {context}
113
+ ###Otázka: {question}
114
+ Užitečná odpověď:
115
+ """
116
+ return PromptTemplate(template=tmpl, input_variables=["context", "question"])
117
+
118
+ # ------------------------------------------------------------------
119
+ # 5. ROUTES
120
+ # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  @app.get("/", response_class=HTMLResponse)
122
+ def home(request: Request):
123
  return templates.TemplateResponse("index.html", {"request": request})
124
 
 
125
  @app.post("/ingest_data")
126
  async def ingest_data(folderPath: str = Form(...), language: str = Form(...)):
127
+ if language.lower() == "czech":
 
 
128
  ingestor.data_czech = folderPath
129
  ingestor.ingest_czech()
130
+ return {"message": "Czech data ingestion complete."}
131
+ ingestor.data_english = folderPath
132
+ ingestor.ingest_english()
133
+ return {"message": "English data ingestion complete."}
 
 
 
 
 
134
 
135
  @app.post("/get_response")
136
  async def get_response(query: str = Form(...), language: str = Form(...)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ is_czech = language.lower() == "czech"
139
+ prompt = prompt_cz() if is_czech else prompt_en()
140
+ store_path = CZECH_STORE if is_czech else ENGLISH_STORE
141
+ embed_name = EMB_CZ if is_czech else EMB_EN
 
 
 
 
 
 
142
 
143
+ embeddings = HuggingFaceEmbeddings(
144
+ model_name = embed_name,
145
+ model_kwargs = {"device": "cpu"},
146
+ encode_kwargs= {"normalize_embeddings": False}
147
+ )
148
+ vectordb = FAISS.load_local(store_path, embeddings)
149
+ retriever = vectordb.as_retriever(search_kwargs={"k": 2})
150
 
151
+ qa_chain = RetrievalQA.from_chain_type(
152
+ llm = HF_LLM, # <- default open-source model
153
+ # llm = OPENAI_LLM, # <- optional paid model
154
+ chain_type = "stuff",
155
+ retriever = retriever,
156
+ return_source_documents= True,
157
+ chain_type_kwargs = {"prompt": prompt},
158
+ verbose = True,
159
  )
160
 
161
+ result = qa_chain(query)
162
+ answer = result["result"]
163
+ src_doc = result["source_documents"][0].page_content
164
+ src_path = result["source_documents"][0].metadata["source"]
165
+
166
+ payload = jsonable_encoder(json.dumps({
167
+ "answer" : answer,
168
+ "source_document" : src_doc,
169
+ "doc" : src_path
170
+ }))
171
+ return Response(payload)