"""Gradio demo: retrieval-augmented QA over a Pinecone index using zephyr-7b-beta.

Environment variables read at import time:
    PINECONE_API_KEY, PINECONE_ENVIRONMENT -- Pinecone credentials.
    API_KEY -- Hugging Face token sent as a Bearer token to the Inference API.
"""
import os
import time  # noqa: F401  (only needed for the index-creation step noted below)

import gradio as gr
import pinecone
import sentence_transformers  # noqa: F401
import torch  # noqa: F401
from langchain.chains import RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate  # noqa: F401
from langchain.vectorstores import Pinecone

embed_model_id = 'sentence-transformers/all-MiniLM-L6-v2'

# Runs on the library's default device. To pin a device, pass
# model_kwargs={'device': device} and encode_kwargs={'device': device, 'batch_size': 32}.
embed_model = HuggingFaceEmbeddings(model_name=embed_model_id)

# get API key from app.pinecone.io and environment from console
pinecone.init(
    api_key=os.environ.get('PINECONE_API_KEY'),
    environment=os.environ.get('PINECONE_ENVIRONMENT'),
)

docs = [
    "this is one document",
    "and another document",
]
# Sample embeddings; len(embeddings[0]) is the vector dimension the index needs.
embeddings = embed_model.embed_documents(docs)

index_name = 'llama-rag'
# The index is expected to exist already. To create it once:
#     pinecone.create_index(index_name, dimension=len(embeddings[0]), metric='cosine')
# then poll pinecone.describe_index(index_name).status['ready'] (with time.sleep)
# until it is True.

index = pinecone.Index(index_name)
# Result is discarded; the request presumably serves as an early connectivity
# check against the index -- confirm against the Pinecone client behavior.
index.describe_index_stats()

text_field = 'text'  # field in metadata that contains text content

vectorstore = Pinecone(index, embed_model.embed_query, text_field)

API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
_hf_token = os.environ.get('API_KEY')
# Only attach Authorization when a token is configured; otherwise the request
# would carry the literal header "Bearer None".
headers = {"Authorization": f"Bearer {_hf_token}"} if _hf_token else {}

llm = HuggingFaceTextGenInference(
    inference_server_url=API_URL,
    max_new_tokens=1024,
    top_k=10,
    top_p=0.95,
    typical_p=0.95,
    temperature=0.01,
    repetition_penalty=1.03,
    # Forwarded to the underlying text_generation.Client. Previously `headers`
    # was built but never passed, so the token was never sent.
    server_kwargs={"headers": headers},
)

rag_pipeline = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=vectorstore.as_retriever(),
)


def question(question):
    """Answer *question* with the RAG pipeline and return the generated text.

    The chain returns a dict; only its 'result' entry is shown in the UI.
    No per-request state is stored at module level, so concurrent Gradio
    requests do not overwrite each other.
    """
    answer = rag_pipeline(question)
    return answer['result']


demo = gr.Interface(fn=question, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()