calmgoose committed
Commit e48d6fc · Parent: 3d3bae7

Use the Hugging Face Inference API instead of loading the pipeline locally, since that wasn't working.
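For context, the dropped approach built the LLM by pulling the model weights into the Space itself via langchain's HuggingFacePipeline. Below is a rough, stand-alone sketch of that removed call; the arguments are copied from the diff further down, while the memory comment is an assumption about why this path "wasn't working".

# Sketch of the removed path (legacy langchain HuggingFacePipeline wrapper).
# from_model_id downloads the checkpoint and runs it in-process, so a
# 20B-parameter model needs tens of GB of memory before it can answer anything.
from langchain.llms import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
    task="text-generation",
    model_kwargs={"temperature": 0.2, "max_length": 400},
)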

Files changed (1): app.py (+18, -5)
app.py CHANGED

@@ -9,7 +9,7 @@ from langchain.chains import VectorDBQA
 from huggingface_hub import snapshot_download
 from langchain import OpenAI
 from langchain import PromptTemplate
-from langchain.llms import HuggingFacePipeline
+from langchain.llms import HuggingFacePipeline, HuggingFaceHub
 
 
 BOOK_NAME = "1984"
@@ -84,11 +84,16 @@ def load_chain(model: Literal["openai", "GPT-NeoXT-Chat-Base-20B"] ="openai"):
         llm = OpenAI(temperature=0.2)
 
     if model=="GPT-NeoXT-Chat-Base-20B":
-        llm = HuggingFacePipeline.from_model_id(
-            model_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
+        # llm = HuggingFacePipeline.from_model_id(
+        #     model_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
+        #     task="text-generation",
+        #     model_kwargs={"temperature":0.2, "max_length":400}
+        # )
+        llm = HuggingFaceHub(
+            repo_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
             task="text-generation",
             model_kwargs={"temperature":0.2, "max_length":400}
-        )
+            )
 
     # load chain
     chain = VectorDBQA.from_chain_type(
@@ -153,6 +158,14 @@ with st.sidebar:
                                 )
         os.environ["OPENAI_API_KEY"] = api_key
 
+
+    if choice == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
+        api_key = st.text_input(label = "Paste your Hugging Face Hub API key here to get started",
+                                type = "password",
+                                help = "This isn't saved 🙈"
+                                )
+        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
+
     st.markdown("---")
 
     st.info("Based on [Talk2Book](https://github.com/batmanscode/Talk2Book)")
@@ -174,7 +187,7 @@ ask = col2.button("Ask")
 
 if ask:
 
-    if choice=="OpenAI" and api_key is "":
+    if api_key is "":
         st.write(f"**{BOOK_NAME}:** Whoops looks like you forgot your API key buddy")
         st.stop()
     else:
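For reference, here is a minimal stand-alone sketch of the new code path, using the legacy langchain HuggingFaceHub wrapper exactly as it appears in the diff; the token value and the sample prompt are placeholders rather than anything taken from the app.

import os

from langchain.llms import HuggingFaceHub

# In the app the token comes from the new sidebar text_input; a placeholder here.
os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_..."

# The request is served by the hosted Inference API, so the 20B model is never
# downloaded or loaded into the Space's own memory.
llm = HuggingFaceHub(
    repo_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
    task="text-generation",
    model_kwargs={"temperature": 0.2, "max_length": 400},
)

print(llm("Who is Big Brother?"))  # illustrative prompt only

Because generation now happens on Hugging Face's servers, the Space only needs the API token collected in the sidebar rather than enough memory to hold the checkpoint, which is what the commit message means by using the Inference API instead of loading the pipeline.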