geetu040 committed
Commit ced6b34 · 1 Parent(s): eea7eec

upload files

Files changed (4)
  1. Dockerfile +16 -0
  2. app.py +65 -0
  3. orator.py +146 -0
  4. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,65 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from orator import Session, SQLDatabase, DocumentDatabase
+ from langchain.chat_models import init_chat_model
+ from fastapi.responses import StreamingResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import asyncio
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Orator Chat API")
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Use a specific origin in production
+     allow_credentials=True,
+     allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
+     allow_headers=["*"],  # Allow all headers
+ )
+
+ # Initialize LLM and databases
+ llm = init_chat_model("o3-mini", model_provider="openai")
+ chinook_db = SQLDatabase.from_uri("sqlite:////home/geetu/work/orator/data/chinook/Chinook.db")
+ pricegram_db = DocumentDatabase("/home/geetu/work/orator/data/pricegram/data.json", top_k=10)
+
+ # Initialize session
+ session = Session(llm=llm, datasources=[chinook_db, pricegram_db])
+
+ # Pydantic model for request
+ class QueryRequest(BaseModel):
+     query: str
+     source: int
+
+
+ @app.post("/query/")
+ async def get_response(request: QueryRequest):
+     """Process a query and return the response."""
+     try:
+         print("Got Request:", request)
+         response, logs = session.invoke(request.query, datasource=request.source)
+         response = {"response": response}
+         print("Sending Response:", response)
+         return response
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/query/stream/")
+ async def stream_response(request: QueryRequest):
+     """Stream responses for a given query."""
+     async def event_generator():
+         try:
+             events = session.stream(request.query, stream_mode="updates")  # "updates" yields {node: update} events
+             for event in events:
+                 for person, quote in event.items():
+                     yield f"{person}: {quote['messages'][-1].content}\n"
+                     await asyncio.sleep(0.1)  # Simulate streaming delay
+         except Exception as e:
+             yield f"Error: {str(e)}"
+
+     return StreamingResponse(event_generator(), media_type="text/plain")
+
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to the Orator Chat API"}
orator.py ADDED
@@ -0,0 +1,146 @@
+ import json
+ from abc import abstractmethod, ABC
+
+ from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
+ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
+
+ from langchain_core.vectorstores import InMemoryVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.language_models.chat_models import BaseChatModel
+
+ from langchain_huggingface import HuggingFaceEmbeddings
+
+ from langchain import hub
+ from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+
+ from langgraph.prebuilt import create_react_agent
+
+ class Database(ABC):
+     @abstractmethod
+     def create_agent(self, llm):
+         raise NotImplementedError
+
+ class Session:
+     def __init__(self, llm: BaseChatModel, datasources=None):
+         self.llm = llm
+         self.datasources = datasources
+         self._datasources = []
+         self._dataagents = []
+
+         if self.datasources is not None:
+             for datasource in self.datasources:
+                 self.add_datasource(datasource)
+
+     def add_datasource(self, database: Database):
+         agent = database.create_agent(self.llm)
+         self._datasources.append(database)
+         self._dataagents.append(agent)
+
+     def get_relevant_source(self, message, datasource=None):
+         if datasource is not None:
+             return self._datasources[datasource], self._dataagents[datasource]
+         return self._datasources[0], self._dataagents[0]
+
+     def invoke(self, message, datasource=None):
+         db, agent = self.get_relevant_source(message, datasource)
+         processed_message = db.process_message(message)
+         response = agent.invoke(processed_message)
+         processed_response = db.postprocess(response)
+         return processed_response, response
+
+     def stream(self, message, stream_mode=None):
+         db, agent = self.get_relevant_source(message)
+         return agent.stream(
+             {"messages": [("user", message)]},
+             stream_mode=stream_mode,
+         )
+
+ class SQLDatabase(Database):
+     def __init__(self, db):
+         self.db = db
+
+     def create_agent(self, llm):
+         toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
+         prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
+         system_message = prompt_template.format(dialect="SQLite", top_k=5)
+         agent = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)
+         return agent
+
+     def process_message(self, message):
+         return {"messages": [("user", message)]}
+
+     def postprocess(self, response):
+         return response['messages'][-1].content
+
+     @classmethod
+     def from_uri(cls, database_uri, engine_args=None, **kwargs):
+         db = LangchainSQLDatabase.from_uri(database_uri, engine_args, **kwargs)
+         return cls(db)
+
+
+ class DocumentDatabase(Database):
+     def __init__(
+         self,
+         path: str,
+         model_name: str = "sentence-transformers/all-mpnet-base-v2",
+         top_k: int = 3,
+         model_kwargs=None,
+         encode_kwargs=None,
+     ):
+         self.path = path
+         self.model_name = model_name
+         self.top_k = top_k
+         self.model_kwargs = {"device": "cpu"} if model_kwargs is None else model_kwargs
+         self.encode_kwargs = {"batch_size": 8} if encode_kwargs is None else encode_kwargs
+
+         embeddings = HuggingFaceEmbeddings(
+             model_name=self.model_name,
+             model_kwargs=self.model_kwargs,
+             encode_kwargs=self.encode_kwargs,
+             show_progress=False,
+         )
+         self.vector_store = InMemoryVectorStore(embeddings)
+         # Load a pre-computed store (documents plus their embeddings) dumped
+         # in InMemoryVectorStore's internal dict format.
+         with open(path, 'rb') as f:
+             self.vector_store.store = json.load(f)
+
+     def create_agent(self, llm):
+         # Step 1: Retrieve relevant documents from the vector store
+         retrieve_docs = RunnableLambda(lambda message: (message, self.vector_store.similarity_search(message, k=self.top_k)))
+
+         # Step 2: Format the retrieved docs into a prompt
+         def format_prompt(inputs):
+             message, docs = inputs
+             # Join outside the f-string: backslashes inside f-string
+             # expressions are a SyntaxError before Python 3.12
+             context = "\n\n".join(doc.page_content for doc in docs)
+             prompt = [
+                 SystemMessage(
+                     "You are an assistant for question-answering tasks. "
+                     "Use the following pieces of retrieved context to answer "
+                     "the question. If you don't know the answer, say that you "
+                     "don't know. Use three sentences maximum and keep the "
+                     "answer concise."
+                     "\n\n"
+                     f"{context}"
+                 ),
+                 HumanMessage(message)
+             ]
+             return prompt
+
+         format_prompt_node = RunnableLambda(format_prompt)
+
+         # Step 3: Invoke LLM with the formatted prompt
+         invoke_llm = llm
+
+         # Step 4: Chain everything together
+         agent_pipeline = RunnablePassthrough() | retrieve_docs | format_prompt_node | invoke_llm
+
+         return agent_pipeline
+
+     def process_message(self, message):
+         return message
+
+     def postprocess(self, response):
+         return response.content
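
Outside the API, Session can also be driven directly; a minimal sketch, assuming an OpenAI key is configured for init_chat_model and using a placeholder SQLite path:

from langchain.chat_models import init_chat_model
from orator import Session, SQLDatabase

llm = init_chat_model("o3-mini", model_provider="openai")
db = SQLDatabase.from_uri("sqlite:///Chinook.db")  # placeholder path

session = Session(llm=llm, datasources=[db])

# invoke() returns (postprocessed answer, raw agent response)
answer, raw = session.invoke("How many employees are there?", datasource=0)
print(answer)

# stream() forwards LangGraph events; "updates" mode yields {node: update}
# mappings, which is the shape app.py's streaming endpoint consumes
for event in session.stream("Which artist has the most albums?", stream_mode="updates"):
    for node, payload in event.items():
        print(node, payload["messages"][-1].content)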
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ uvicorn[standard]
+ langchain
+ langgraph
+ langchain-core
+ langchain-community
+ langchain-huggingface
+ # required by init_chat_model(..., model_provider="openai") in app.py
+ langchain-openai