import os

import gradio as gr
import pinecone
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Pinecone
from torch import cuda
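
# Endpoint URL and API credentials are read from environment variables
# (on a Hugging Face Space these are typically configured as repository secrets).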
LLAMA_2_7B_CHAT_HF_FRANC_V0_9 = os.environ.get("LLAMA_2_7B_CHAT_HF_FRANC_V0_9")
HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
PINECONE_ENVIRONMENT = os.environ.get('PINECONE_ENVIRONMENT')
# Set up Pinecone vector store
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENVIRONMENT
)

index_name = 'stadion-6237'
index = pinecone.Index(index_name)
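
# Embedding model used to encode queries for retrieval; runs on GPU when one is available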
embedding_model_id = 'sentence-transformers/paraphrase-mpnet-base-v2'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_id,
    model_kwargs={'device': device},
    encode_kwargs={'device': device, 'batch_size': 32}
)

text_key = 'text'
vector_store = Pinecone(
    index, embedding_model.embed_query, text_key
)
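
# Llama 2 chat expects the instruction wrapped in [INST] ... [/INST], with an
# optional <<SYS>> ... <</SYS>> system prompt at the start of the instruction.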
B_INST, E_INST = "[INST] ", " [/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_prompt_template(instruction, system_prompt):
    system_prompt = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + system_prompt + instruction + E_INST
    return prompt_template


template = get_prompt_template(
    """Use the following context to answer the question at the end.
Context:
{context}
Question: {question}""",
    """Reply in 10 sentences or less.
Do not use emotes."""
)
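
# For reference, the assembled template string looks like:
#
#   [INST] <<SYS>>
#   Reply in 10 sentences or less.
#   Do not use emotes.
#   <</SYS>>
#
#   Use the following context to answer the question at the end.
#   Context:
#   {context}
#   Question: {question} [/INST]

# Text-generation LLM served from a dedicated Hugging Face Inference Endpoint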
endpoint_url = LLAMA_2_7B_CHAT_HF_FRANC_V0_9
llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    huggingfacehub_api_token=HUGGING_FACE_HUB_TOKEN,
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "return_full_text": True,
    },
)
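
# Prompt template taking the retrieved context and the user question as inputs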
prompt = PromptTemplate(
    template=template,
    input_variables=["context", "question"]
)
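
# Short conversational memory (last 3 turns). Note that it is currently not
# wired into the chain: the "memory" entry in chain_type_kwargs below is commented out.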
memory = ConversationBufferWindowMemory(
    k=3,
    memory_key="history",
    input_key="question",
    ai_prefix="Franc",
    human_prefix="Runner",
)
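
# Retrieval-augmented QA chain: "stuff" the top 4 retrieved chunks into the prompt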
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=vector_store.as_retriever(search_kwargs={'k': 4}),
    chain_type_kwargs={
        "prompt": prompt,
        # "memory": memory,
    },
)
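
# Gradio chat handler: run the chain on the incoming message and return the answer text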
def generate(message, history):
    reply = rag_chain(message)
    return reply['result'].strip()


gr.ChatInterface(
    generate,
    title="Franc v1.0",
    theme=gr.themes.Soft(),
    submit_btn="Ask Franc",
    retry_btn="Do better, Franc!",
    autofocus=True,
).queue().launch()