Illia56 commited on
Commit
a0737f2
·
1 Parent(s): 2de191f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import logging
3
  from typing import Any, List, Mapping, Optional
4
-
5
  from gradio_client import Client
6
  from langchain.schema import Document
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -13,7 +13,7 @@ from langchain.chains import RetrievalQA
13
  from langchain.prompts import PromptTemplate
14
  import streamlit as st
15
  from pytube import YouTube
16
- import replicate
17
 
18
 
19
 
@@ -89,7 +89,6 @@ def sidebar():
89
  # RepetitionpenaltySide = st.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.05)
90
 
91
 
92
-
93
  sidebar()
94
  initialize_session_state()
95
 
@@ -120,9 +119,10 @@ if st.session_state.youtube_url and not st.session_state.setup_done and "REPLICA
120
  retriever.search_kwargs['distance_metric'] = 'cos'
121
  retriever.search_kwargs['k'] = 4
122
  with st.status("Running RetrievalQA..."):
123
- llama_instance = replicate.load(
124
- model="meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3",
125
- model_kwargs={"temperature": 0.75, "max_length": 4096, "top_p": 1},
 
126
  )
127
  st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
128
 
 
1
  import os
2
  import logging
3
  from typing import Any, List, Mapping, Optional
4
+ from langchain.llms import HuggingFaceHub
5
  from gradio_client import Client
6
  from langchain.schema import Document
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
13
  from langchain.prompts import PromptTemplate
14
  import streamlit as st
15
  from pytube import YouTube
16
+ # import replicate
17
 
18
 
19
 
 
89
  # RepetitionpenaltySide = st.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.05)
90
 
91
 
 
92
  sidebar()
93
  initialize_session_state()
94
 
 
119
  retriever.search_kwargs['distance_metric'] = 'cos'
120
  retriever.search_kwargs['k'] = 4
121
  with st.status("Running RetrievalQA..."):
122
+ llama_instance = HuggingFaceHub(
123
+ model_kwargs={"max_length": 4096},
124
+ repo_id="meta-llama/Llama-2-70b-chat-hf",
125
+ huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
126
  )
127
  st.session_state.qa = RetrievalQA.from_chain_type(llm=llama_instance, chain_type="stuff", retriever=retriever,chain_type_kwargs={"prompt": prompt})
128