calmgoose commited on
Commit
6e48340
Β·
1 Parent(s): eeefa5d

revert: remove option to use `togethercomputer/GPT-NeoXT-Chat-Base-20B`

Browse files
Files changed (1) hide show
  1. app.py +15 -56
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- from typing import Literal
3
- import logging
4
  import streamlit as st
5
 
6
  from langchain.embeddings import HuggingFaceInstructEmbeddings
@@ -9,7 +7,6 @@ from langchain.chains import VectorDBQA
9
  from huggingface_hub import snapshot_download
10
  from langchain import OpenAI
11
  from langchain import PromptTemplate
12
- from langchain.llms import HuggingFacePipeline, HuggingFaceHub
13
 
14
 
15
  BOOK_NAME = "1984"
@@ -76,26 +73,10 @@ def load_prompt(book_name, author_name):
76
  return PROMPT
77
 
78
 
79
- @st.experimental_singleton(show_spinner=False, max_entries=1)
80
- def load_chain(model: Literal["openai", "togethercomputer/GPT-NeoXT-Chat-Base-20B"] ="openai"):
81
-
82
- # choose model
83
- if model=="openai":
84
- llm = OpenAI(temperature=0.2)
85
-
86
- if model=="togethercomputer/GPT-NeoXT-Chat-Base-20B":
87
- # llm = HuggingFacePipeline.from_model_id(
88
- # model_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
89
- # task="text-generation",
90
- # model_kwargs={"temperature":0.2, "max_length":400}
91
- # )
92
- llm = HuggingFaceHub(
93
- repo_id="togethercomputer/GPT-NeoXT-Chat-Base-20B",
94
- task="text-generation",
95
- model_kwargs={"temperature":0.2, "max_length":400}
96
- )
97
-
98
- # load chain
99
  chain = VectorDBQA.from_chain_type(
100
  chain_type_kwargs = {"prompt": load_prompt(book_name=BOOK_NAME, author_name=AUTHOR_NAME)},
101
  llm=llm,
@@ -104,14 +85,12 @@ def load_chain(model: Literal["openai", "togethercomputer/GPT-NeoXT-Chat-Base-20
104
  k=8,
105
  return_source_documents=True,
106
  )
107
-
108
- logging.info(f"Loaded chain with {model}.")
109
 
110
  return chain
111
 
112
 
113
- def get_answer(question, model="openai"):
114
- chain = load_chain(model=model)
115
  result = chain({"query": question})
116
 
117
  answer = result["result"]
@@ -145,26 +124,11 @@ def get_answer(question, model="openai"):
145
 
146
  ##### sidebar ####
147
  with st.sidebar:
148
-
149
- choice= st.radio("Choose your API:",
150
- ["OpenAI", "togethercomputer/GPT-NeoXT-Chat-Base-20B"],
151
- help="GPT-NeoXT-Chat-Base-20B doesn't need an API Key"
152
- )
153
-
154
- if choice == "OpenAI":
155
- api_key = st.text_input(label = "Paste your OpenAI API key here to get started",
156
- type = "password",
157
- help = "This isn't saved πŸ™ˆ"
158
- )
159
- os.environ["OPENAI_API_KEY"] = api_key
160
-
161
-
162
- if choice == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
163
- api_key = st.text_input(label = "Paste your Hugging Face Hub API key here to get started",
164
- type = "password",
165
- help = "This isn't saved πŸ™ˆ"
166
- )
167
- os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
168
 
169
  st.markdown("---")
170
 
@@ -187,21 +151,16 @@ ask = col2.button("Ask")
187
 
188
  if ask:
189
 
190
- if (choice=="openai" and os.environ["OPENAI_API_KEY"]=="") or (choice=="togethercomputer/GPT-NeoXT-Chat-Base-20B" and os.environ["HUGGINGFACEHUB_API_TOKEN"]==""):
191
  st.write(f"**{BOOK_NAME}:** Whoops looks like you forgot your API key buddy")
192
  st.stop()
193
  else:
194
  with st.spinner("Um... excuse me but... this can take about a minute for your first question because some stuff have to be downloaded πŸ₯ΊπŸ‘‰πŸ»πŸ‘ˆπŸ»"):
195
  try:
196
- answer, pages, extract = get_answer(question=user_input, model=choice)
197
- logging.info(f"Answer successfully generated using {choice}.")
198
  except:
199
- if choice=="togethercomputer/GPT-NeoXT-Chat-Base-20B":
200
- st.write("The model probably timed out :(")
201
- st.stop()
202
- else:
203
- st.write(f"**{BOOK_NAME}:** What\'s going on? That's not the right API key")
204
- st.stop()
205
 
206
  st.write(f"**{BOOK_NAME}:** {answer}")
207
 
 
1
  import os
 
 
2
  import streamlit as st
3
 
4
  from langchain.embeddings import HuggingFaceInstructEmbeddings
 
7
  from huggingface_hub import snapshot_download
8
  from langchain import OpenAI
9
  from langchain import PromptTemplate
 
10
 
11
 
12
  BOOK_NAME = "1984"
 
73
  return PROMPT
74
 
75
 
76
+ @st.experimental_singleton(show_spinner=False)
77
+ def load_chain():
78
+ llm = OpenAI(temperature=0.2)
79
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  chain = VectorDBQA.from_chain_type(
81
  chain_type_kwargs = {"prompt": load_prompt(book_name=BOOK_NAME, author_name=AUTHOR_NAME)},
82
  llm=llm,
 
85
  k=8,
86
  return_source_documents=True,
87
  )
 
 
88
 
89
  return chain
90
 
91
 
92
+ def get_answer(question):
93
+ chain = load_chain()
94
  result = chain({"query": question})
95
 
96
  answer = result["result"]
 
124
 
125
  ##### sidebar ####
126
  with st.sidebar:
127
+ api_key = st.text_input(label = "Paste your OpenAI API key here to get started",
128
+ type = "password",
129
+ help = "This isn't saved πŸ™ˆ"
130
+ )
131
+ os.environ["OPENAI_API_KEY"] = api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  st.markdown("---")
134
 
 
151
 
152
  if ask:
153
 
154
+ if api_key is "":
155
  st.write(f"**{BOOK_NAME}:** Whoops looks like you forgot your API key buddy")
156
  st.stop()
157
  else:
158
  with st.spinner("Um... excuse me but... this can take about a minute for your first question because some stuff have to be downloaded πŸ₯ΊπŸ‘‰πŸ»πŸ‘ˆπŸ»"):
159
  try:
160
+ answer, pages, extract = get_answer(question=user_input)
 
161
  except:
162
+ st.write(f"**{BOOK_NAME}:** What\'s going on? That's not the right API key")
163
+ st.stop()
 
 
 
 
164
 
165
  st.write(f"**{BOOK_NAME}:** {answer}")
166