halyn committed on
Commit
ddf266a
·
1 Parent(s): f085c10

code update for streamlit

Browse files
Files changed (2) hide show
  1. app.py +49 -107
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import os
2
  import io
3
  import requests
 
4
  from dotenv import load_dotenv
5
- from fastapi import FastAPI, HTTPException, UploadFile, File
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import BaseModel
8
  from PyPDF2 import PdfReader
9
  from langchain.text_splitter import CharacterTextSplitter
10
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
@@ -12,27 +10,14 @@ from langchain.vectorstores import FAISS
12
  from langchain.chains.question_answering import load_qa_chain
13
  from langchain.llms import HuggingFacePipeline
14
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
15
- import streamlit as st
16
 
17
  # Disable WANDB
18
  os.environ['WANDB_DISABLED'] = "true"
19
 
20
  # Constants
21
  MODEL_PATH = "/home/lab/halyn/gemma/halyn/paper/models/gemma-2-9b-it"
22
- FASTAPI_URL = "http://203.249.64.50:8080" # 서버 주소
23
-
24
- app = FastAPI()
25
 
26
- # CORS 설정
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"], # 모든 출처 허용
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- # Global variables to store the knowledge base and QA chain
36
  knowledge_base = None
37
  qa_chain = None
38
 
@@ -40,7 +25,7 @@ def load_pdf(pdf_file):
40
  """
41
  Load and extract text from a PDF.
42
  Args:
43
- pdf_file (str) : The PDF file.
44
  Returns:
45
  str: Extracted text from the PDF.
46
  """
@@ -52,9 +37,9 @@ def split_text(text):
52
  """
53
  Split the extracted text into chunks.
54
  Args:
55
- text (str) : The full text extracted from the PDF.
56
  Returns:
57
- list : A list of text chunks
58
  """
59
  text_splitter = CharacterTextSplitter(
60
  separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
@@ -65,9 +50,9 @@ def create_knowledge_base(chunks):
65
  """
66
  Create a FAISS knowledge base from text chunks.
67
  Args:
68
- chunks (list) : A list of text chunks.
69
  Returns:
70
- FAISS: A FAISS knowledge base object
71
  """
72
  embeddings = HuggingFaceEmbeddings()
73
  return FAISS.from_texts(chunks, embeddings)
@@ -76,7 +61,7 @@ def load_model(model_path):
76
  """
77
  Load the HuggingFace model and tokenizer, and create a text-generation pipeline.
78
  Args:
79
- model_path (str) : The path to the pre-trained model.
80
  Returns:
81
  pipeline: A HuggingFace pipeline for text generation.
82
  """
@@ -84,56 +69,14 @@ def load_model(model_path):
84
  model = AutoModelForCausalLM.from_pretrained(model_path)
85
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
86
 
87
- @app.on_event("startup")
88
- async def startup_event():
89
- """ Start function to run the PDF question-answering system. """
 
90
  global qa_chain
91
- load_dotenv()
92
-
93
- # Load the language model
94
- try:
95
- pipe = load_model(MODEL_PATH)
96
- llm = HuggingFacePipeline(pipeline=pipe)
97
- qa_chain = load_qa_chain(llm, chain_type="stuff")
98
- except Exception as e:
99
- print(f"Error loading model: {e}")
100
- raise HTTPException(status_code=500, detail="Failed to load the language model")
101
-
102
- @app.post("/upload_pdf")
103
- async def upload_pdf(file: UploadFile = File(...)):
104
- global knowledge_base
105
- try:
106
- contents = await file.read()
107
- pdf_file = io.BytesIO(contents)
108
- text = load_pdf(pdf_file)
109
- chunks = split_text(text)
110
- knowledge_base = create_knowledge_base(chunks)
111
- return {"message": "PDF uploaded and processed successfully"}
112
- except Exception as e:
113
- raise HTTPException(status_code=400, detail=f"Failed to process PDF: {str(e)}")
114
-
115
- class Question(BaseModel):
116
- text: str
117
-
118
- @app.post("/ask")
119
- async def ask_question(question: Question):
120
- global knowledge_base, qa_chain
121
- if not knowledge_base:
122
- raise HTTPException(status_code=400, detail="No PDF has been uploaded yet")
123
- if not qa_chain:
124
- raise HTTPException(status_code=500, detail="QA chain is not initialized")
125
-
126
- try:
127
- docs = knowledge_base.similarity_search(question.text)
128
- response = qa_chain.run(input_documents=docs, question=question.text)
129
-
130
- if "Helpful Answer:" in response:
131
- response = response.split("Helpful Answer:")[1].strip()
132
-
133
- return {"response": response}
134
- except Exception as e:
135
- raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
136
-
137
 
138
  # Streamlit UI
139
  def main_page():
@@ -146,23 +89,24 @@ def main_page():
146
  st.write("Please click the button below.")
147
 
148
  if st.button("Click Here :)"):
149
- # FastAPI 서버에 PDF 파일 전송
150
  try:
151
- files = {"file": (paper.name, paper, "application/pdf")}
152
- response = requests.post(f"{FASTAPI_URL}/upload_pdf", files=files)
153
- if response.status_code == 200:
154
- st.success("PDF successfully uploaded to the model! Please click the button again")
155
- st.session_state.messages = []
156
- st.session_state.paper_name = paper.name[:-4]
157
- st.session_state.page = "chat"
158
- else:
159
- st.error(f"Failed to upload PDF to the model. Error: {response.text}")
160
- except requests.RequestException as e:
161
- st.error(f"Error connecting to the server: {str(e)}")
 
 
 
162
 
163
  def chat_page():
164
- st.title(f"Welcome to GemmaPaperQA")
165
- st.subheader(f"Ask anything about {st.session_state.paper_name}")
166
 
167
  if "messages" not in st.session_state:
168
  st.session_state.messages = []
@@ -170,37 +114,40 @@ def chat_page():
170
  for message in st.session_state.messages:
171
  with st.chat_message(message["role"]):
172
  st.markdown(message["content"])
173
-
174
- if prompt := st.chat_input("Chat here !"):
175
- # Add user message to chat history
176
  st.session_state.messages.append({"role": "user", "content": prompt})
177
 
178
- # Display user message in chat message container
179
  with st.chat_message("user"):
180
  st.markdown(prompt)
181
 
182
- # Get response from FastAPI server
183
- response = get_response_from_fastapi(prompt)
184
 
185
- # Display assistant response in chat message container
186
  with st.chat_message("assistant"):
187
  st.markdown(response)
188
 
189
- # Add assistant response to chat history
190
  st.session_state.messages.append({"role": "assistant", "content": response})
191
 
192
  if st.button("Go back to main page"):
193
  st.session_state.page = "main"
194
 
195
- def get_response_from_fastapi(prompt):
196
  try:
197
- response = requests.post(f"{FASTAPI_URL}/ask", json={"text": prompt})
198
- if response.status_code == 200:
199
- return response.json()["response"]
200
- else:
201
- return f"Sorry, I couldn't generate a response. Error: {response.text}"
202
- except requests.RequestException as e:
203
- return f"Sorry, there was an error connecting to the server: {str(e)}"
 
 
 
 
 
 
 
 
204
 
205
  # Streamlit - initial page setup
206
  if "page" not in st.session_state:
@@ -215,8 +162,3 @@ if st.session_state.page == "main":
215
  main_page()
216
  elif st.session_state.page == "chat":
217
  chat_page()
218
-
219
- # FastAPI 앱 실행을 위한 코드
220
- if __name__ == "__main__":
221
- import uvicorn
222
- uvicorn.run(app, host="0.0.0.0", port=8050)
 
1
  import os
2
  import io
3
  import requests
4
+ import streamlit as st
5
  from dotenv import load_dotenv
 
 
 
6
  from PyPDF2 import PdfReader
7
  from langchain.text_splitter import CharacterTextSplitter
8
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 
10
  from langchain.chains.question_answering import load_qa_chain
11
  from langchain.llms import HuggingFacePipeline
12
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
13
 
14
  # Disable WANDB
15
  os.environ['WANDB_DISABLED'] = "true"
16
 
17
  # Constants
18
  MODEL_PATH = "/home/lab/halyn/gemma/halyn/paper/models/gemma-2-9b-it"
 
 
 
19
 
20
+ # Global variables
 
 
 
 
 
 
 
 
 
21
  knowledge_base = None
22
  qa_chain = None
23
 
 
25
  """
26
  Load and extract text from a PDF.
27
  Args:
28
+ pdf_file (file-like object): The PDF file, e.g. an io.BytesIO stream.
29
  Returns:
30
  str: Extracted text from the PDF.
31
  """
 
37
  """
38
  Split the extracted text into chunks.
39
  Args:
40
+ text (str): The full text extracted from the PDF.
41
  Returns:
42
+ list: A list of text chunks.
43
  """
44
  text_splitter = CharacterTextSplitter(
45
  separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
 
50
  """
51
  Create a FAISS knowledge base from text chunks.
52
  Args:
53
+ chunks (list): A list of text chunks.
54
  Returns:
55
+ FAISS: A FAISS knowledge base object.
56
  """
57
  embeddings = HuggingFaceEmbeddings()
58
  return FAISS.from_texts(chunks, embeddings)
 
61
  """
62
  Load the HuggingFace model and tokenizer, and create a text-generation pipeline.
63
  Args:
64
+ model_path (str): The path to the pre-trained model.
65
  Returns:
66
  pipeline: A HuggingFace pipeline for text generation.
67
  """
 
69
  model = AutoModelForCausalLM.from_pretrained(model_path)
70
  return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.1)
71
 
72
def setup_qa_chain():
    """
    Set up the question-answering chain.

    Builds the text-generation pipeline from MODEL_PATH via load_model,
    wraps it in a HuggingFacePipeline, and stores the resulting "stuff"
    QA chain in the module-level ``qa_chain``.

    Returns:
        None. The chain is published through the ``qa_chain`` global.
    """
    global qa_chain
    generator = load_model(MODEL_PATH)
    qa_chain = load_qa_chain(
        HuggingFacePipeline(pipeline=generator),
        chain_type="stuff",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # Streamlit UI
82
  def main_page():
 
89
  st.write("Please click the button below.")
90
 
91
  if st.button("Click Here :)"):
 
92
  try:
93
+ # Process the PDF file
94
+ contents = paper.read()
95
+ pdf_file = io.BytesIO(contents)
96
+ text = load_pdf(pdf_file)
97
+ chunks = split_text(text)
98
+ global knowledge_base
99
+ knowledge_base = create_knowledge_base(chunks)
100
+
101
+ st.success("PDF successfully processed! You can now ask questions.")
102
+ st.session_state.paper_name = paper.name[:-4]
103
+ st.session_state.page = "chat"
104
+ setup_qa_chain()
105
+ except Exception as e:
106
+ st.error(f"Failed to process the PDF: {str(e)}")
107
 
108
  def chat_page():
109
+ st.title(f"Ask anything about {st.session_state.paper_name}")
 
110
 
111
  if "messages" not in st.session_state:
112
  st.session_state.messages = []
 
114
  for message in st.session_state.messages:
115
  with st.chat_message(message["role"]):
116
  st.markdown(message["content"])
117
+
118
+ if prompt := st.chat_input("Chat here!"):
 
119
  st.session_state.messages.append({"role": "user", "content": prompt})
120
 
 
121
  with st.chat_message("user"):
122
  st.markdown(prompt)
123
 
124
+ response = get_response_from_model(prompt)
 
125
 
 
126
  with st.chat_message("assistant"):
127
  st.markdown(response)
128
 
 
129
  st.session_state.messages.append({"role": "assistant", "content": response})
130
 
131
  if st.button("Go back to main page"):
132
  st.session_state.page = "main"
133
 
134
def get_response_from_model(prompt):
    """
    Answer a user question from the uploaded PDF.

    Args:
        prompt (str): The question typed by the user.
    Returns:
        str: The model's answer (the text following "Helpful Answer:" when
        that marker is present), a fixed notice when the knowledge base or
        QA chain is missing, or "Error: ..." if anything raises.
    """
    # NOTE(review): knowledge_base and qa_chain are module-level globals.
    # Streamlit re-executes the script on each interaction, so they may be
    # reset to None between the upload and the first question -- confirm,
    # and consider keeping them in st.session_state instead.
    global knowledge_base, qa_chain

    marker = "Helpful Answer:"
    try:
        if not knowledge_base:
            return "No PDF has been uploaded yet."
        if not qa_chain:
            return "QA chain is not initialized."

        # Retrieve the chunks most similar to the question, then let the
        # chain generate an answer over them.
        matches = knowledge_base.similarity_search(prompt)
        answer = qa_chain.run(input_documents=matches, question=prompt)

        # Keep only the segment right after the first marker, if any.
        if marker in answer:
            answer = answer.split(marker)[1].strip()

        return answer
    except Exception as exc:
        # Best effort: show the failure in the chat instead of crashing.
        return f"Error: {exc}"
151
 
152
  # Streamlit - initial page setup
153
  if "page" not in st.session_state:
 
162
  main_page()
163
  elif st.session_state.page == "chat":
164
  chat_page()
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  streamlit
2
  requests
3
  PyPDF2
 
 
 
 
1
  streamlit
2
  requests
3
  PyPDF2
4
+ python-dotenv
5
+ langchain
6
+ transformers