# ncert-helper / app.py
import os

import gradio as gr
import requests
from gradio_client import Client

# from langchain.chains import RetrievalQA
# import pinecone
# from langchain.vectorstores import Pinecone
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# import time
# Hosted Inference API endpoint for the answer-generating LLM.
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
headers = {"Authorization": f"Bearer {os.environ.get('API_KEY')}"}

# Companion Space that serves retrieval over the NCERT vector DB; gradio_client
# exposes its /predict endpoint as a callable.
retrieval = Client("https://ishaan-mital-ncert-helper-vector-db.hf.space/--replicas/149bl5mjn/")
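# Example (hypothetical question, same endpoint main() uses below):
#   context = retrieval.predict("What is photosynthesis?", api_name="/predict")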
# embed_model_id = 'sentence-transformers/all-MiniLM-L6-v2'
# # device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# embed_model = HuggingFaceEmbeddings(
# model_name=embed_model_id,
# # model_kwargs={'device': device},
# # encode_kwargs={'device': device, 'batch_size': 32}
# )
# pinecone.init(
# api_key=os.environ.get('PINECONE_API_KEY'),
# environment=os.environ.get('PINECONE_ENVIRONMENT')
# )
# index_name = 'llama-rag'
# index = pinecone.Index(index_name)
# text_field = 'text' # field in metadata that contains text content
# docs = [
# "this is one document",
# "and another document"
# ]
# embeddings = embed_model.embed_documents(docs)
# if index_name not in pinecone.list_indexes():
# pinecone.create_index(
# index_name,
# dimension=len(embeddings[0]),
# metric='cosine'
# )
# # wait for index to finish initialization
# while not pinecone.describe_index(index_name).status['ready']:
# time.sleep(1)
# vectorstore = Pinecone(
# index, embed_model.embed_query, text_field
# )
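# NOTE: the commented block above sketches an in-process retriever (HuggingFace
# embeddings + a Pinecone index wrapped as a LangChain vectorstore). The live
# app skips it and delegates retrieval to the vector-DB Space instead.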
def call_llm_api(input_text, context):
    # Pack the question and the retrieved context into a single prompt string.
    payload = {
        "inputs": f"question: {input_text}, context: {context}"
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    response.raise_for_status()  # fail loudly on HTTP errors
    return response.json()  # adjust as needed based on the API response format
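# For a text-generation model the Inference API typically responds with a list
# such as [{"generated_text": "..."}]; main() below relies on that shape.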
# rag_pipeline = RetrievalQA.from_chain_type(
# llm=call_llm_api, chain_type='stuff',
# retriever=vectorstore.as_retriever()
# )
def main(question):
    # return rag_pipeline(question)
    # Fetch the most relevant context from the retriever Space.
    context = retrieval.predict(question, api_name="/predict")
    answer = call_llm_api(question, context)
    # The API returns a list of generations; use the first one.
    return answer[0]["generated_text"]
demo = gr.Interface(fn=main, inputs="text", outputs="text")
if __name__ == "__main__":
demo.launch()
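# To run locally (assuming the Space's dependencies are installed):
#   export API_KEY=<your Hugging Face access token>
#   python app.py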