Spaces:
Sleeping
Sleeping
upload files
Browse files- Dockerfile +16 -0
- app.py +65 -0
- orator.py +146 -0
- requirements.txt +7 -0
Dockerfile
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.9
|
5 |
+
|
6 |
+
RUN useradd -m -u 1000 user
|
7 |
+
USER user
|
8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
9 |
+
|
10 |
+
WORKDIR /app
|
11 |
+
|
12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
14 |
+
|
15 |
+
COPY --chown=user . /app
|
16 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from orator import Session, SQLDatabase, DocumentDatabase
|
4 |
+
from langchain.chat_models import init_chat_model
|
5 |
+
from fastapi.responses import StreamingResponse
|
6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
7 |
+
import asyncio
|
8 |
+
|
9 |
+
# Initialize FastAPI app
|
10 |
+
app = FastAPI(title="Orator Chat API")
|
11 |
+
|
12 |
+
app.add_middleware(
|
13 |
+
CORSMiddleware,
|
14 |
+
allow_origins=["*"], # Use a specific origin in production
|
15 |
+
allow_credentials=True,
|
16 |
+
allow_methods=["*"], # Allow all methods (GET, POST, etc.)
|
17 |
+
allow_headers=["*"], # Allow all headers
|
18 |
+
)
|
19 |
+
|
20 |
+
# Initialize LLM and databases
|
21 |
+
llm = init_chat_model("o3-mini", model_provider="openai")
|
22 |
+
chinook_db = SQLDatabase.from_uri("sqlite:////home/geetu/work/orator/data/chinook/Chinook.db")
|
23 |
+
pricegram_db = DocumentDatabase("/home/geetu/work/orator/data/pricegram/data.json", top_k=10)
|
24 |
+
|
25 |
+
# Initialize session
|
26 |
+
session = Session(llm=llm, datasources=[chinook_db, pricegram_db])
|
27 |
+
|
28 |
+
# Pydantic model for request
|
29 |
+
class QueryRequest(BaseModel):
|
30 |
+
query: str
|
31 |
+
source: int
|
32 |
+
|
33 |
+
|
34 |
+
@app.post("/query/")
|
35 |
+
async def get_response(request: QueryRequest):
|
36 |
+
"""Process a query and return the response."""
|
37 |
+
try:
|
38 |
+
print("Got Request:", request)
|
39 |
+
response, logs = session.invoke(request.query, datasource=request.source)
|
40 |
+
response = {"response": response}
|
41 |
+
print("Sending Respose:", response)
|
42 |
+
return response
|
43 |
+
except Exception as e:
|
44 |
+
raise HTTPException(status_code=500, detail=str(e))
|
45 |
+
|
46 |
+
|
47 |
+
@app.post("/query/stream/")
|
48 |
+
async def stream_response(request: QueryRequest):
|
49 |
+
"""Stream responses for a given query."""
|
50 |
+
async def event_generator():
|
51 |
+
try:
|
52 |
+
events = session.stream(request.query)
|
53 |
+
for event in events:
|
54 |
+
for person, quote in event.items():
|
55 |
+
yield f"{person}: {quote['messages'][-1].text}\n"
|
56 |
+
await asyncio.sleep(0.1) # Simulate streaming delay
|
57 |
+
except Exception as e:
|
58 |
+
yield f"Error: {str(e)}"
|
59 |
+
|
60 |
+
return StreamingResponse(event_generator(), media_type="text/plain")
|
61 |
+
|
62 |
+
|
63 |
+
@app.get("/")
|
64 |
+
async def root():
|
65 |
+
return {"message": "Welcome to the Orator Chat API"}
|
orator.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from abc import abstractmethod, ABC
|
3 |
+
|
4 |
+
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
|
5 |
+
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
6 |
+
from langchain_community.utilities.sql_database import SQLDatabase as LangchainSQLDatabase
|
7 |
+
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
8 |
+
|
9 |
+
from langchain_core.vectorstores import InMemoryVectorStore
|
10 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
11 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
12 |
+
|
13 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
14 |
+
|
15 |
+
from langchain import hub
|
16 |
+
from langchain.agents import create_react_agent
|
17 |
+
from langchain.schema import SystemMessage
|
18 |
+
from langchain.schema import SystemMessage, HumanMessage
|
19 |
+
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
|
20 |
+
|
21 |
+
from langgraph.prebuilt import create_react_agent
|
22 |
+
|
23 |
+
class Database(ABC):
|
24 |
+
@abstractmethod
|
25 |
+
def create_agent(self, llm):
|
26 |
+
raise NotImplementedError
|
27 |
+
|
28 |
+
class Session:
|
29 |
+
def __init__(self, llm: BaseChatModel, datasources=None):
|
30 |
+
self.llm = llm
|
31 |
+
self.datasources = datasources
|
32 |
+
self._datasources = []
|
33 |
+
self._dataagents = []
|
34 |
+
|
35 |
+
if self.datasources is not None:
|
36 |
+
for datasource in self.datasources:
|
37 |
+
self.add_datasource(datasource)
|
38 |
+
|
39 |
+
def add_datasource(self, database: Database):
|
40 |
+
agent = database.create_agent(self.llm)
|
41 |
+
self._datasources.append(database)
|
42 |
+
self._dataagents.append(agent)
|
43 |
+
|
44 |
+
def get_relevant_source(self, message, datasource):
|
45 |
+
if datasource is not None:
|
46 |
+
return self._datasources[datasource], self._dataagents[datasource]
|
47 |
+
return self._datasources[0], self._dataagents[0]
|
48 |
+
|
49 |
+
def invoke(self, message, datasource=None):
|
50 |
+
db, agent = self.get_relevant_source(message, datasource)
|
51 |
+
processed_message = db.process_message(message)
|
52 |
+
response = agent.invoke(processed_message)
|
53 |
+
processed_response = db.postprocess(response)
|
54 |
+
return processed_response, response
|
55 |
+
|
56 |
+
def stream(self, message, stream_mode=None):
|
57 |
+
db, agent = self.get_relevant_source(message)
|
58 |
+
return agent.stream(
|
59 |
+
{"messages": [("user", message)]},
|
60 |
+
stream_mode=stream_mode,
|
61 |
+
)
|
62 |
+
|
63 |
+
class SQLDatabase(Database):
|
64 |
+
def __init__(self, db):
|
65 |
+
self.db = db
|
66 |
+
|
67 |
+
def create_agent(self, llm):
|
68 |
+
toolkit = SQLDatabaseToolkit(db=self.db, llm=llm)
|
69 |
+
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
70 |
+
system_message = prompt_template.format(dialect="SQLite", top_k=5)
|
71 |
+
agent = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)
|
72 |
+
return agent
|
73 |
+
|
74 |
+
def process_message(self, message):
|
75 |
+
return {"messages": [("user", message)]}
|
76 |
+
|
77 |
+
def postprocess(self, response):
|
78 |
+
return response['messages'][-1].content
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def from_uri(cls, database_uri, engine_args=None, **kwargs):
|
82 |
+
db = LangchainSQLDatabase.from_uri(database_uri, engine_args, **kwargs)
|
83 |
+
return cls(db)
|
84 |
+
|
85 |
+
|
86 |
+
class DocumentDatabase(Database):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
path: str,
|
90 |
+
model_name: str = "sentence-transformers/all-mpnet-base-v2",
|
91 |
+
top_k: int = 3,
|
92 |
+
model_kwargs = None,
|
93 |
+
encode_kwargs = None,
|
94 |
+
):
|
95 |
+
self.path = path
|
96 |
+
self.model_name = model_name
|
97 |
+
self.top_k = top_k
|
98 |
+
self.model_kwargs = {"device": "cpu"} if model_kwargs is None else model_kwargs
|
99 |
+
self.encode_kwargs = {"batch_size": 8} if encode_kwargs is None else encode_kwargs
|
100 |
+
|
101 |
+
embeddings = HuggingFaceEmbeddings(
|
102 |
+
model_name=self.model_name,
|
103 |
+
model_kwargs=self.model_kwargs,
|
104 |
+
encode_kwargs=self.encode_kwargs,
|
105 |
+
show_progress=False,
|
106 |
+
)
|
107 |
+
self.vector_store = InMemoryVectorStore(embeddings)
|
108 |
+
with open(path, 'rb') as f:
|
109 |
+
self.vector_store.store = json.load(f)
|
110 |
+
|
111 |
+
def create_agent(self, llm):
|
112 |
+
# Step 1: Retrieve relevant documents from the vector store
|
113 |
+
retrieve_docs = RunnableLambda(lambda message: (message, self.vector_store.similarity_search(message, k=self.top_k)))
|
114 |
+
|
115 |
+
# Step 2: Format the retrieved docs into a prompt
|
116 |
+
def format_prompt(inputs):
|
117 |
+
message, docs = inputs
|
118 |
+
prompt = [
|
119 |
+
SystemMessage(
|
120 |
+
"You are an assistant for question-answering tasks. "
|
121 |
+
"Use the following pieces of retrieved context to answer "
|
122 |
+
"the question. If you don't know the answer, say that you "
|
123 |
+
"don't know. Use three sentences maximum and keep the "
|
124 |
+
"answer concise."
|
125 |
+
"\n\n"
|
126 |
+
f"{'\n\n'.join(doc.page_content for doc in docs)}"
|
127 |
+
),
|
128 |
+
HumanMessage(message)
|
129 |
+
]
|
130 |
+
return prompt
|
131 |
+
|
132 |
+
format_prompt_node = RunnableLambda(format_prompt)
|
133 |
+
|
134 |
+
# Step 3: Invoke LLM with the formatted prompt
|
135 |
+
invoke_llm = llm
|
136 |
+
|
137 |
+
# Step 4: Chain everything together
|
138 |
+
agent_pipeline = RunnablePassthrough() | retrieve_docs | format_prompt_node | invoke_llm
|
139 |
+
|
140 |
+
return agent_pipeline
|
141 |
+
|
142 |
+
def process_message(self, message):
|
143 |
+
return message
|
144 |
+
|
145 |
+
def postprocess(self, response):
|
146 |
+
return response.content
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
uvicorn[standard]
|
3 |
+
langchain
|
4 |
+
langgraph
|
5 |
+
langchain-core
|
6 |
+
langchain-community
|
7 |
+
langchain-huggingface
|