Abbeite committed
Commit 4d65b06 · verified · 1 Parent(s): f06f2b6

Update app.py

Files changed (1):
  1. app.py (+25 -20)
app.py CHANGED

@@ -1,8 +1,9 @@
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import fitz  # PyMuPDF
+import torch
 
-# Define a function to load the PDF document and cache the content with Streamlit's caching
+# Function to load the PDF document
 @st.cache(allow_output_mutation=True)
 def load_pdf_document(file_path):
     text = ""
@@ -11,40 +12,44 @@ def load_pdf_document(file_path):
         text += page.get_text()
     return text
 
-# Define a function to get answers from the model and cache the model and tokenizer
+# Function to load the model and tokenizer
 @st.cache(allow_output_mutation=True)
 def load_model_and_tokenizer(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
     return tokenizer, model
 
-# UI for the Streamlit app
-st.title("Question Answering with LLaMA 2")
+# Function to generate an answer from the model
+def generate_answer(context, query, tokenizer, model):
+    # Preprocess and truncate the context to fit within model limits
+    encoded_input = tokenizer.encode_plus(query, context, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length - 20)
+    input_ids = encoded_input["input_ids"]
+    attention_mask = encoded_input["attention_mask"]
+
+    # Generate an answer using max_new_tokens to limit output length
+    output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=150, num_return_sequences=1, temperature=0.7, top_p=0.9)
+    answer = tokenizer.decode(output[0], skip_special_tokens=True)
+    return answer
 
-# Load the document
-document_path = "jeff_wo.pdf"  # Change this to the path of your PDF document in the repository
+# Streamlit UI
+st.title("Question Answering with LLaMA 2")
+document_path = "jeff_wo.pdf"
 document_text = load_pdf_document(document_path)
-# Displaying the document text can be optional, based on your preference or usability considerations
-st.text_area("Document Text", value=document_text, height=300, help="Content of the PDF document.")
+
+# Optional: Display the document text or a portion of it
+st.text_area("Document Text (preview)", value=document_text[:1000], height=250, help="Preview of the document text.")
 
 # Load model and tokenizer
 model_name = "NousResearch/Llama-2-7b-chat-hf"
 tokenizer, model = load_model_and_tokenizer(model_name)
 
-# Sidebar for user input
-st.sidebar.header("Ask a Question")
-query = st.sidebar.text_input("Enter your question:", "")
+# User input for the query
+query = st.text_input("Enter your question:", "")
 
-if st.sidebar.button("Answer"):
+if st.button("Generate Answer"):
     if query:
         with st.spinner("Generating answer..."):
-            input_text = f"Context: {document_text}\nQ: {query}\nA:"
-            input_ids = tokenizer.encode(input_text, return_tensors="pt")
-            # Adjust the generation parameters as needed
-            output = model.generate(input_ids, max_length=512, num_return_sequences=1, temperature=0.7, top_p=0.9)
-            answer = tokenizer.decode(output[0], skip_special_tokens=True)
+            answer = generate_answer(document_text, query, tokenizer, model)
         st.write(answer)
     else:
-        st.sidebar.error("Please enter a question.")
-else:
-    st.write("Enter a question on the sidebar and click 'Answer' to get a response.")
+        st.error("Please enter a question to get an answer.")
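Note on the caching decorators: both loaders remain wrapped in st.cache(allow_output_mutation=True), which is deprecated in newer Streamlit releases in favour of st.cache_data and st.cache_resource. A minimal sketch of the same two loaders under that assumption (Streamlit >= 1.18); the PDF-iteration lines are hidden context in the diff, so the standard PyMuPDF pattern is assumed here:

import streamlit as st
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForCausalLM

@st.cache_data  # caches plain data: the extracted text
def load_pdf_document(file_path: str) -> str:
    text = ""
    with fitz.open(file_path) as doc:  # with/for lines not shown in the diff; standard PyMuPDF iteration assumed
        for page in doc:
            text += page.get_text()
    return text

@st.cache_resource  # caches heavyweight, unhashable objects: the model and tokenizer
def load_model_and_tokenizer(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model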
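Separately, the new generate_answer feeds the question and the document to tokenizer.encode_plus as a sentence pair, while Llama-2-7b-chat-hf is a causal chat model with no segment structure. Not part of this commit: a hypothetical variant that keeps the plain prompt-string format of the pre-commit code, still truncates the context to leave room for generation, and decodes only the newly generated tokens:

def generate_answer_prompt_style(context, query, tokenizer, model, max_new_tokens=150):
    # Build a single prompt string, as the pre-commit code did
    prompt = f"Context: {context}\nQ: {query}\nA:"
    # Truncate so that prompt + answer fit the context window
    # (if tokenizer.model_max_length is a placeholder sentinel, substitute e.g. 4096)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length - max_new_tokens,
    )
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,   # temperature/top_p only take effect when sampling
        temperature=0.7,
        top_p=0.9,
    )
    # Skip the echoed prompt and return only the generated continuation
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

The do_sample=True flag matters for both versions: model.generate ignores temperature and top_p and falls back to greedy decoding when sampling is not enabled.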