import os
import time

import gradio as gr
import pinecone
import requests
from gradio_client import Client  # used by the commented-out retrieval Space client below
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.vectorstores import Pinecone

# Hugging Face Inference API endpoint for the Zephyr-7B-beta model
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
headers = {"Authorization": f"Bearer {os.environ.get('API_KEY')}"}
# retrieval = Client("https://ishaan-mital-ncert-helper-vector-db.hf.space/--replicas/149bl5mjn/")

embed_model_id = 'sentence-transformers/all-MiniLM-L6-v2'
# device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
embed_model = HuggingFaceEmbeddings(
    model_name=embed_model_id,
    # model_kwargs={'device': device},
    # encode_kwargs={'device': device, 'batch_size': 32}
)
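# Optional sanity check: all-MiniLM-L6-v2 produces 384-dimensional embeddings,
# which must match the Pinecone index dimension created below.
# assert len(embed_model.embed_query("test")) == 384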
pinecone.init(
    api_key=os.environ.get('PINECONE_API_KEY'),
    environment=os.environ.get('PINECONE_ENVIRONMENT')
)

index_name = 'llama-rag'
text_field = 'text'  # field in metadata that contains text content
# embed two throwaway documents just to discover the embedding dimension
# needed when creating the index
docs = [
    "this is one document",
    "and another document"
]
embeddings = embed_model.embed_documents(docs)
if index_name not in pinecone.list_indexes():
    pinecone.create_index(
        index_name,
        dimension=len(embeddings[0]),
        metric='cosine'
    )
    # wait for index to finish initialization
    while not pinecone.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect only after the index is guaranteed to exist
index = pinecone.Index(index_name)
vectorstore = Pinecone(
    index, embed_model.embed_query, text_field
)
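# Retrieval usage sketch (hypothetical query; assumes the index already holds
# documents whose metadata carries the 'text' field):
#   matches = vectorstore.similarity_search("What is photosynthesis?", k=3)
#   print([m.page_content for m in matches])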
def call_llm_api(input_text, context):
    payload = {"inputs": f'question: {input_text}, context: {context}'}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()  # adjust as needed for the API response format

# RetrievalQA requires a LangChain LLM instance, not a bare function, so wrap
# the Inference API call in a minimal custom LLM (text-generation responses
# come back as [{"generated_text": "..."}]).
class ZephyrLLM(LLM):
    @property
    def _llm_type(self) -> str:
        return "zephyr-inference-api"

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
        response.raise_for_status()
        return response.json()[0]["generated_text"]

rag_pipeline = RetrievalQA.from_chain_type(
    llm=ZephyrLLM(), chain_type='stuff',
    retriever=vectorstore.as_retriever()
)
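# Pipeline usage sketch (hypothetical question; in this LangChain version the
# chain returns a dict with 'query' and 'result' keys):
#   out = rag_pipeline("State Newton's third law.")
#   print(out['result'])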
def main(question):
    # the chain returns {'query': ..., 'result': ...}; surface only the answer
    return rag_pipeline(question)['result']
    # global chatbot
    # context = retrieval.predict(question, api_name = "/predict")
    # answer = call_llm_api(question, context)
    # # chatbot = answer[1]
    # return answer[0]
demo = gr.Interface(main, inputs="text", outputs="text")
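# For local testing, passing share=True to launch() yields a temporary public
# URL (share is a standard Gradio launch flag):
#   demo.launch(share=True)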
if __name__ == "__main__":
    demo.launch()