Update app.py
app.py
CHANGED
@@ -1,66 +1,118 @@
import streamlit as st
-import fitz
-from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

-# Function to load the PDF document
-@st.cache(allow_output_mutation=True)
-def load_pdf_document(file_path):
-    text = ""
-    with fitz.open(file_path) as doc:
-        for page in doc:
-            text += page.get_text()
-    return text
-
-# Function to load the model and tokenizer
-@st.cache(allow_output_mutation=True)
-def load_model_and_tokenizer(model_name):
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model

-# Function to generate an answer to a query, using the document text as context
-def generate_answer(context, query, tokenizer, model, max_context_length=4096):
-    if len(tokenizer.encode(context, add_special_tokens=False)) > max_context_length:
-        # If context is too long, truncate it from the beginning (simple approach)
-        start_index = len(tokenizer.encode(context, add_special_tokens=False)) - max_context_length
-        truncated_context = tokenizer.decode(tokenizer.encode(context, add_special_tokens=False)[start_index:])
-    else:
-        truncated_context = context
-
-    encoded_input = tokenizer.encode_plus(query, truncated_context, add_special_tokens=True, return_tensors="pt", truncation=True)
-    input_ids = encoded_input["input_ids"]
-    attention_mask = encoded_input["attention_mask"]
-
-    # Use max_new_tokens to control the length of the generated content
-    output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=150, num_return_sequences=1, temperature=0.7, top_p=0.9)
-    answer = tokenizer.decode(output[0], skip_special_tokens=True)
-    return answer
-
-
-# Streamlit UI
-st.title("Question Answering with LLaMA 2")
-document_path = "jeff_wo.pdf"
-document_text = load_pdf_document(document_path)
-
-# Optional: Display the document text or a portion of it
-st.text_area("Document Text (preview)", value=document_text[:1000], height=250, help="Preview of the document text.")
-
-# Load model and tokenizer
-model_name = "NousResearch/Llama-2-7b-chat-hf"
-tokenizer, model = load_model_and_tokenizer(model_name)
-
-# User input for the query
-query = st.text_input("Enter your question:", "")
-
-if st.button("Generate Answer"):
-    if query:
-        with st.spinner("Generating answer..."):
-            answer = generate_answer(document_text, query, tokenizer, model)
-            st.write(answer)
-    else:
-        st.error("Please enter a question to get an answer.")
import streamlit as st
+
+# Import transformer classes for generation
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+# Import torch for datatype attributes
import torch
+# Import the prompt wrapper...but for llama index
+from llama_index.prompts.prompts import SimpleInputPrompt
+# Import the llama index HF Wrapper
+from llama_index.llms import HuggingFaceLLM
+# Bring in embeddings wrapper
+from llama_index.embeddings import LangchainEmbedding
+# Bring in HF embeddings - need these to represent document chunks
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+# Bring in stuff to change service context
+from llama_index import set_global_service_context
+from llama_index import ServiceContext
+# Import deps to load documents
+from llama_index import VectorStoreIndex, download_loader
+from pathlib import Path
+
+# Define variable to hold llama2 weights naming
+name = "meta-llama/Llama-2-70b-chat-hf"
+# Set auth token variable from hugging face
+auth_token = "YOUR HUGGING FACE AUTH TOKEN HERE"
+
+
+@st.cache_resource
+def get_tokenizer_model():
+    # Create tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/', use_auth_token=auth_token)
+
+    # Create model
+    model = AutoModelForCausalLM.from_pretrained(name, cache_dir='./model/',
+                                                 use_auth_token=auth_token, torch_dtype=torch.float16,
+                                                 rope_scaling={"type": "dynamic", "factor": 2}, load_in_8bit=True)

    return tokenizer, model
+tokenizer, model = get_tokenizer_model()
+
+# Create a system prompt
+system_prompt = """<s>[INST] <<SYS>>
+You are a helpful, respectful and honest assistant. Always answer as
+helpfully as possible, while being safe. Your answers should not include
+any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
+Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain
+why instead of answering something not correct. If you don't know the answer
+to a question, please don't share false information.
+
+Your goal is to provide answers relating to the workout science and information in the document.<</SYS>>
+"""
+
+
+# Throw together the query wrapper
+query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")
+
+
+llm = HuggingFaceLLM(context_window=1024,
+                     max_new_tokens=128,
+                     system_prompt=system_prompt,
+                     query_wrapper_prompt=query_wrapper_prompt,
+                     model=model,
+                     tokenizer=tokenizer)
+
+# Create and dl embeddings instance
+embeddings = LangchainEmbedding(
+    HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+)
+
+
+# Create new service context instance
+service_context = ServiceContext.from_defaults(
+    chunk_size=1024,
+    llm=llm,
+    embed_model=embeddings
+)
+# And set the service context
+set_global_service_context(service_context)
+
+# Download PDF Loader
+PyMuPDFReader = download_loader("PyMuPDFReader")
+# Create PDF Loader
+loader = PyMuPDFReader()
+# Load documents
+documents = loader.load(file_path=Path('./data/annualreport.pdf'), metadata=True)
+
+# Download PDF Loader
+PyMuPDFReader = download_loader("PyMuPDFReader")
+# Create PDF Loader
+loader = PyMuPDFReader()
+# Load documents
+documents = loader.load(file_path=Path('jeff_wo.pdf'), metadata=True)
+
+# Create an index - we'll be able to query this in a sec
+index = VectorStoreIndex.from_documents(documents)
+# Setup index query engine using LLM
+query_engine = index.as_query_engine()
+
+
+# Create centered main title
+st.title('🦙 Llama Banker')
+# Create a text input box for the user
+prompt = st.text_input('Input your prompt here')
+
+# If the user hits enter
+if prompt:
+    response = query_engine.query(prompt)
+    # ...and write it out to the screen
+    st.write(response)

+    # Display raw response object
+    with st.expander('Response Object'):
+        st.write(response)
+    # Display source text
+    with st.expander('Source Text'):
+        st.write(response.get_formatted_sources())
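
Note on the new version: auth_token is committed as a hardcoded placeholder string. A minimal sketch of reading the token from the environment instead is below; the variable name HF_TOKEN is an assumption, not something defined by this commit.

import os

# Read the Hugging Face token from an environment variable instead of hardcoding it.
# "HF_TOKEN" is a placeholder name; use whatever the deployment environment actually sets.
auth_token = os.environ.get("HF_TOKEN", "")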