dotku commited on
Commit
14d715f
·
1 Parent(s): da6d4d7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +44 -1
main.py CHANGED
@@ -1,5 +1,23 @@
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  app = FastAPI()
5
 
@@ -14,4 +32,29 @@ def read_root():
14
 
15
  @app.get("/api/python")
16
  def hello_python():
17
- return {"message": "Hello Python"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
  from fastapi import FastAPI
4
  from fastapi.middleware.cors import CORSMiddleware
5
+ from langchain.vectorstores import Pinecone
6
+ from langchain.llms import OpenAI
7
+ from langchain.chains import RetrievalQA
8
+
9
+ PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
10
+ PINECONE_ENV = os.getenv('PINECONE_ENV')
11
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
12
+ PINECONE_INDEX_NAME = os.getenv('PINECONE_INDEX_NAME')
13
+
14
+ def parse_response(response):
15
+ result = response['result']
16
+ result += '\n\nSources:'
17
+ for source_name in response["source_documents"]:
18
+ result += ''.join((source_name.metadata['source'],
19
+ "page #:", str(source_name.metadata['page'])))
20
+ return result
21
 
22
  app = FastAPI()
23
 
 
32
 
33
  @app.get("/api/python")
34
  def hello_python():
35
+ return {"message": "Hello Python"}
36
+
37
+ @app.get("/prompt")
38
+ def read_root(p: str='According to HQ H303140, what is "Country of origin" means?'):
39
+ pinecone.init(
40
+ api_key=PINECONE_API_KEY,
41
+ environment=PINECONE_ENV
42
+ )
43
+ index = pinecone.Index(PINECONE_INDEX_NAME)
44
+ index.describe_index_stats()
45
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
46
+ docsearch = Pinecone.from_existing_index(PINECONE_INDEX_NAME, embeddings)
47
+ retriever = docsearch.as_retriever(
48
+ include_metadata=True,
49
+ metadata_key='source'
50
+ )
51
+ llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)
52
+ qa_chain = RetrievalQA.from_chain_type(llm=llm,
53
+ chain_type="stuff",
54
+ retriever=retriever,
55
+ return_source_documents=True)
56
+ response = qa_chain(p)
57
+ return {
58
+ "prompt": p,
59
+ "response": parse_response(response)
60
+ }