elshehawy commited on
Commit
6a53f44
·
1 Parent(s): eee4c3e

code refactoring

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -10,9 +10,10 @@ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
10
  from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from pypdf import PdfReader, PdfWriter
12
  from pathlib import Path
 
13
 
14
 
15
- def build_rag_chain(pdf_paths):
16
  loaders = [PyPDFLoader(path) for path in pdf_paths]
17
 
18
  docs = []
@@ -21,8 +22,6 @@ def build_rag_chain(pdf_paths):
21
  loader.load()[0:] # skip first page
22
  )
23
 
24
- chunk_size = 1000
25
- chunk_overlap = 200
26
 
27
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
28
  chunk_overlap=chunk_overlap)
@@ -35,7 +34,7 @@ def build_rag_chain(pdf_paths):
35
 
36
  # model_name = 'gpt-3.5-turbo-0125'
37
  # model_name = 'gpt-4-1106-preview'
38
- model_name = 'gpt-4-0125-preview'
39
  llm = ChatOpenAI(model_name=model_name, temperature=0)
40
 
41
  def format_docs(docs):
@@ -51,9 +50,10 @@ def build_rag_chain(pdf_paths):
51
  return rag_chain
52
 
53
 
54
- def predict(query, pdf_file):
55
- print(type(pdf_file))
56
- if pdf_file:
 
57
  # pdf_path = Path(pdf_file)
58
  # pdf_reader = PdfReader(pdf_path)
59
  # pdf_writer = PdfWriter()
@@ -72,7 +72,7 @@ def predict(query, pdf_file):
72
  # os.system("ls data/pdf")
73
 
74
  # pdf_paths = load_pdf_paths(data_root)
75
- rag_chain = build_rag_chain([pdf_file])
76
  return rag_chain.invoke(query)
77
  return "Please upload PDF file"
78
 
 
10
  from langchain_text_splitters import RecursiveCharacterTextSplitter
11
  from pypdf import PdfReader, PdfWriter
12
  from pathlib import Path
13
+ from typing import List
14
 
15
 
16
+ def build_rag_chain(pdf_paths: List[str], chunk_size: int =1000, chunk_overlap: int =200, model_name: str ='gpt-4-0125-preview'):
17
  loaders = [PyPDFLoader(path) for path in pdf_paths]
18
 
19
  docs = []
 
22
  loader.load()[0:] # skip first page
23
  )
24
 
 
 
25
 
26
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
27
  chunk_overlap=chunk_overlap)
 
34
 
35
  # model_name = 'gpt-3.5-turbo-0125'
36
  # model_name = 'gpt-4-1106-preview'
37
+ # model_name = 'gpt-4-0125-preview'
38
  llm = ChatOpenAI(model_name=model_name, temperature=0)
39
 
40
  def format_docs(docs):
 
50
  return rag_chain
51
 
52
 
53
+ def predict(query: str, pdf_id: str =None, user_id: str = None):
54
+ print(type(pdf_id))
55
+ # print(user_id)
56
+ if pdf_id:
57
  # pdf_path = Path(pdf_file)
58
  # pdf_reader = PdfReader(pdf_path)
59
  # pdf_writer = PdfWriter()
 
72
  # os.system("ls data/pdf")
73
 
74
  # pdf_paths = load_pdf_paths(data_root)
75
+ rag_chain = build_rag_chain([pdf_id])
76
  return rag_chain.invoke(query)
77
  return "Please upload PDF file"
78