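# Gradio chat app "Franc": retrieval-augmented question answering over a
# Pinecone index, with answers generated by a Llama-2-7B chat model served
# from a Hugging Face inference endpoint.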
import os
import gradio as gr
import pinecone
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.vectorstores import Pinecone
from torch import cuda
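
# Configuration is read from environment variables (set as Space secrets):
# the inference endpoint URL, the Hub token, and the Pinecone credentials.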
LLAMA_2_7B_CHAT_HF_FRANC_V0_9 = os.environ.get("LLAMA_2_7B_CHAT_HF_FRANC_V0_9")
HUGGING_FACE_HUB_TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")

# Set up the Pinecone vector store
pinecone.init(
    api_key=PINECONE_API_KEY,
    environment=PINECONE_ENVIRONMENT,
)
index_name = "stadion-6237"
index = pinecone.Index(index_name)
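
# Query embeddings come from a sentence-transformers model, on GPU when available.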
embedding_model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
device = f"cuda:{cuda.current_device()}" if cuda.is_available() else "cpu"
embedding_model = HuggingFaceEmbeddings(
    model_name=embedding_model_id,
    model_kwargs={"device": device},
    encode_kwargs={"device": device, "batch_size": 32},
)
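
# Wrap the existing index as a LangChain vector store; the raw document text
# is stored under the "text" metadata key.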
text_key = "text"
vector_store = Pinecone(index, embedding_model.embed_query, text_key)
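
# Llama-2 chat prompt delimiters for the instruction and system blocks.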
B_INST, E_INST = "[INST] ", " [/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def get_prompt_template(instruction, system_prompt):
    system_prompt = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + system_prompt + instruction + E_INST
    return prompt_template
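
# The assembled template has the standard Llama-2 chat shape:
#   [INST] <<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {instruction} [/INST]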
template = get_prompt_template(
    """Use the following context to answer the question at the end.
Context:
{context}
Question: {question}""",
    """Reply in 10 sentences or less.
Do not use emotes.""",
)
endpoint_url = LLAMA_2_7B_CHAT_HF_FRANC_V0_9
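
# Low temperature keeps answers close to the retrieved context; a mild
# repetition penalty discourages the model from looping.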
llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    huggingfacehub_api_token=HUGGING_FACE_HUB_TOKEN,
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,
        "temperature": 0.1,
        "repetition_penalty": 1.1,
        "return_full_text": True,
    },
)
prompt = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)
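
# Sliding-window memory over the last 3 turns. Currently unused: the
# "memory" entry in chain_type_kwargs below is commented out.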
memory = ConversationBufferWindowMemory(
    k=3,
    memory_key="history",
    input_key="question",
    ai_prefix="Franc",
    human_prefix="Runner",
)
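
# "Stuff" chain: the top 4 retrieved chunks are inserted verbatim into the
# prompt's {context} slot before the question is asked.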
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 4}),
    chain_type_kwargs={
        "prompt": prompt,
        # "memory": memory,
    },
)


def generate(message, history):
    # `history` is required by the gr.ChatInterface signature but unused here;
    # each question is answered independently by the RAG chain.
    reply = rag_chain(message)
    return reply["result"].strip()
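
# Chat UI; queue() enables handling of concurrent requests.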
gr.ChatInterface(
    generate,
    title="Franc v1.0",
    theme=gr.themes.Soft(),
    submit_btn="Ask Franc",
    retry_btn="Do better, Franc!",
    autofocus=True,
).queue().launch()