Abbeite committed on
Commit 8be321d · verified · 1 Parent(s): 1dc34bb

Update app.py

Files changed (1)
  1. app.py +113 -61
app.py CHANGED
@@ -1,66 +1,118 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import fitz  # PyMuPDF
  import torch

- # Function to load the PDF document
- @st.cache(allow_output_mutation=True)
- def load_pdf_document(file_path):
-     text = ""
-     with fitz.open(file_path) as doc:
-         for page in doc:
-             text += page.get_text()
-     return text
-
- # Function to load the model and tokenizer
- @st.cache(allow_output_mutation=True)
- def load_model_and_tokenizer(model_name):
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
      return tokenizer, model

- # Function to generate an answer from the model
- def generate_answer(context, query, tokenizer, model):
-     # Aim to use a chunk of the context that keeps us within the max model input size
-     # This is a simplified approach: in practice, you'd want to find a more intelligent way to select the relevant part of the context
-     max_context_length = tokenizer.model_max_length - len(tokenizer.encode(query, add_special_tokens=True)) - 50  # Adjust buffer as needed
-
-     if len(tokenizer.encode(context, add_special_tokens=False)) > max_context_length:
-         # If context is too long, truncate it from the beginning (simple approach)
-         start_index = len(tokenizer.encode(context, add_special_tokens=False)) - max_context_length
-         truncated_context = tokenizer.decode(tokenizer.encode(context, add_special_tokens=False)[start_index:])
-     else:
-         truncated_context = context
-
-     encoded_input = tokenizer.encode_plus(query, truncated_context, add_special_tokens=True, return_tensors="pt", truncation=True)
-     input_ids = encoded_input["input_ids"]
-     attention_mask = encoded_input["attention_mask"]
-
-     # Use max_new_tokens to control the length of the generated content
-     output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=150, num_return_sequences=1, temperature=0.7, top_p=0.9)
-     answer = tokenizer.decode(output[0], skip_special_tokens=True)
-     return answer
-
-
- # Streamlit UI
- st.title("Question Answering with LLaMA 2")
- document_path = "jeff_wo.pdf"
- document_text = load_pdf_document(document_path)
-
- # Optional: Display the document text or a portion of it
- st.text_area("Document Text (preview)", value=document_text[:1000], height=250, help="Preview of the document text.")
-
- # Load model and tokenizer
- model_name = "NousResearch/Llama-2-7b-chat-hf"
- tokenizer, model = load_model_and_tokenizer(model_name)
-
- # User input for the query
- query = st.text_input("Enter your question:", "")
-
- if st.button("Generate Answer"):
-     if query:
-         with st.spinner("Generating answer..."):
-             answer = generate_answer(document_text, query, tokenizer, model)
-             st.write(answer)
-     else:
-         st.error("Please enter a question to get an answer.")
 
  import streamlit as st
+
+ # Import transformer classes for generation
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
+ # Import torch for datatype attributes
  import torch
+ # Import the prompt wrapper...but for llama index
+ from llama_index.prompts.prompts import SimpleInputPrompt
+ # Import the llama index HF Wrapper
+ from llama_index.llms import HuggingFaceLLM
+ # Bring in embeddings wrapper
+ from llama_index.embeddings import LangchainEmbedding
+ # Bring in HF embeddings - need these to represent document chunks
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+ # Bring in stuff to change service context
+ from llama_index import set_global_service_context
+ from llama_index import ServiceContext
+ # Import deps to load documents
+ from llama_index import VectorStoreIndex, download_loader
+ from pathlib import Path
+
+ # Define variable to hold llama2 weights naming
+ name = "meta-llama/Llama-2-70b-chat-hf"
+ # Set auth token variable from hugging face
+ auth_token = "YOUR HUGGING FACE AUTH TOKEN HERE"
+
+
+ @st.cache_resource
+ def get_tokenizer_model():
+     # Create tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(name, cache_dir='./model/', use_auth_token=auth_token)
+
+     # Create model
+     model = AutoModelForCausalLM.from_pretrained(name, cache_dir='./model/',
+                                                  use_auth_token=auth_token, torch_dtype=torch.float16,
+                                                  rope_scaling={"type": "dynamic", "factor": 2}, load_in_8bit=True)

      return tokenizer, model
+ tokenizer, model = get_tokenizer_model()
+
+ # Create a system prompt
+ system_prompt = """<s>[INST] <<SYS>>
+ You are a helpful, respectful and honest assistant. Always answer as
+ helpfully as possible, while being safe. Your answers should not include
+ any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
+ Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain
+ why instead of answering something not correct. If you don't know the answer
+ to a question, please don't share false information.
+
+ Your goal is to provide answers relating to the workout science and information in the document.<</SYS>>
+ """
+
+
+ # Throw together the query wrapper
+ query_wrapper_prompt = SimpleInputPrompt("{query_str} [/INST]")
+
+
+ llm = HuggingFaceLLM(context_window=1024,
+                      max_new_tokens=128,
+                      system_prompt=system_prompt,
+                      query_wrapper_prompt=query_wrapper_prompt,
+                      model=model,
+                      tokenizer=tokenizer)
+
+ # Create and dl embeddings instance
+ embeddings = LangchainEmbedding(
+     HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+ )
+
+
+ # Create new service context instance
+ service_context = ServiceContext.from_defaults(
+     chunk_size=1024,
+     llm=llm,
+     embed_model=embeddings
+ )
+ # And set the service context
+ set_global_service_context(service_context)
+
+ # Download PDF Loader
+ PyMuPDFReader = download_loader("PyMuPDFReader")
+ # Create PDF Loader
+ loader = PyMuPDFReader()
+ # Load documents
+ documents = loader.load(file_path=Path('jeff_wo.pdf'), metadata=True)
+
+ # Create an index - we'll be able to query this in a sec
+ index = VectorStoreIndex.from_documents(documents)
+ # Setup index query engine using LLM
+ query_engine = index.as_query_engine()
+
+
+ # Create centered main title
+ st.title('🦙 Llama Banker')
+ # Create a text input box for the user
+ prompt = st.text_input('Input your prompt here')
+
+ # If the user hits enter
+ if prompt:
+     response = query_engine.query(prompt)
+     # ...and write it out to the screen
+     st.write(response)

+     # Display raw response object
+     with st.expander('Response Object'):
+         st.write(response)
+     # Display source text
+     with st.expander('Source Text'):
+         st.write(response.get_formatted_sources())
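Note: the committed file still hardcodes a placeholder token (auth_token = "YOUR HUGGING FACE AUTH TOKEN HERE"), which is needed to download the gated Llama 2 weights. A minimal sketch of one way to supply it without committing a secret, assuming an environment variable named HF_AUTH_TOKEN (the variable name is not part of this commit):

import os

# Hypothetical alternative: read the gated-model token from the environment
# instead of hardcoding it in app.py.
auth_token = os.environ.get("HF_AUTH_TOKEN")
if auth_token is None:
    raise RuntimeError(
        "Set HF_AUTH_TOKEN before launching, e.g. HF_AUTH_TOKEN=hf_xxx streamlit run app.py"
    )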