import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { Embeddings } from "@langchain/core/embeddings";
import { CHAT_MODELS, EMBEDDING_MODELS, IConfig, MODALITIES, PROVIDERS } from "@/lib/config/types";
import { ConfigManager } from "@/lib/config/manager";
import { ChatOllama, OllamaEmbeddings } from "@langchain/ollama";
import { ChatOpenAI, OpenAIEmbeddings } from "@langchain/openai";
import { ChatAnthropic } from "@langchain/anthropic";
import { ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings } from "@langchain/google-genai";
import { IDocument } from "../document/types";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { Tool } from "langchain/tools";
import { Calculator } from "@langchain/community/tools/calculator";
import { AgentExecutor, createToolCallingAgent } from "langchain/agents";
import { RunnableWithMessageHistory } from "@langchain/core/runnables";
import { DexieChatMemory } from "./memory";
import { DocumentManager } from "@/lib/document/manager";
import { Document } from "@langchain/core/documents";
import { HumanMessage, ToolMessage } from "@langchain/core/messages";
import { IChatSession } from "./types";
import { ChatHFInference } from "./chat-hf";
// `Buffer` is used below for base64 encoding; import it explicitly so browser
// bundlers that do not polyfill Node globals can still resolve it.
import { Buffer } from "buffer";
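
// Narrow unknown thrown values so their `message` can be read safely.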
interface ErrorWithMessage {
  message: string;
}

function isErrorWithMessage(error: unknown): error is ErrorWithMessage {
  return (
    typeof error === 'object' &&
    error !== null &&
    'message' in error &&
    typeof (error as Record<string, unknown>).message === 'string'
  );
}

function toErrorWithMessage(error: unknown): ErrorWithMessage {
  if (isErrorWithMessage(error)) return error;

  try {
    return new Error(String(error));
  } catch {
    // String(error) can itself throw for exotic values (e.g. a poisoned toString).
    return new Error('Unknown error');
  }
}

function getErrorMessage(error: unknown): string {
  return toErrorWithMessage(error).message;
}
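
/**
 * Singleton that owns chat/embedding model construction, tool-calling agents,
 * attachment formatting, and streamed chat turns backed by Dexie history.
 */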
export class ChatManager {
  model!: BaseChatModel;
  embeddings!: Embeddings | null;
  controller: AbortController;
  configManager: ConfigManager;
  config!: IConfig;
  documentManager: DocumentManager;
  private static instance: ChatManager | null = null;

  constructor() {
    this.controller = new AbortController();
    this.configManager = ConfigManager.getInstance();
    this.documentManager = DocumentManager.getInstance();
    // Fire-and-forget: the async methods below re-check `this.config` before use.
    void this.initializeConfig();
  }
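
  /** Return the shared instance, creating it on first use. */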
  public static getInstance(): ChatManager {
    if (!ChatManager.instance) {
      ChatManager.instance = new ChatManager();
    }
    return ChatManager.instance;
  }
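
  /** Load the persisted configuration into `this.config`. */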
  private async initializeConfig() {
    this.config = await this.configManager.getConfig();
  }
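
  /** Replace the AbortController after a cancelled run so the next request gets a fresh signal. */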
  public resetController() {
    this.controller = new AbortController();
  }
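
  /** Instantiate the provider-specific LangChain chat model for `modelName`. */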
  private async getChatModel(modelName: string): Promise<BaseChatModel> {
    if (!this.config) {
      await this.initializeConfig();
    }

    const model = CHAT_MODELS.find(m => m.model === modelName);
    if (!model) {
      throw new Error(`Chat model ${modelName} not found`);
    }

    try {
      switch (model.provider) {
        case PROVIDERS.ollama:
          return new ChatOllama({
            baseUrl: this.config.ollama_base_url,
            model: model.model,
          });

        case PROVIDERS.openai:
          // Prefer the user-configured model name and base URL when set.
          return new ChatOpenAI({
            modelName: this.config.openai_model?.trim() ? this.config.openai_model : model.model,
            apiKey: this.config.openai_api_key,
            configuration: {
              baseURL: this.config.openai_base_url?.trim() ? this.config.openai_base_url : undefined,
            },
          });

        case PROVIDERS.anthropic:
          return new ChatAnthropic({
            modelName: model.model,
            apiKey: this.config.anthropic_api_key,
          });

        case PROVIDERS.gemini:
          return new ChatGoogleGenerativeAI({
            modelName: model.model,
            apiKey: this.config.gemini_api_key,
          });

        case PROVIDERS.huggingface:
          // Local wrapper around the Hugging Face Inference API; invoked as a
          // factory function rather than a class, per its call site.
          return ChatHFInference({
            modelName: model.model,
            apiKey: this.config.hf_token,
          });

        default:
          throw new Error(`Provider ${model.provider} not implemented yet for chat models`);
      }
    } catch (error: unknown) {
      console.error(`Error creating chat model ${modelName}:`, error);
      throw new Error(`Failed to initialize chat model ${modelName}: ${getErrorMessage(error)}`);
    }
  }
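
  /** Instantiate the provider-specific embeddings model, validating Ollama connectivity first. */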
  private async getEmbeddingModel(modelName: string): Promise<Embeddings> {
    if (!this.config) {
      await this.initializeConfig();
    }

    if (!modelName) {
      throw new Error("No embedding model specified");
    }

    const model = EMBEDDING_MODELS.find(m => m.model === modelName);
    if (!model) {
      throw new Error(`Embedding model ${modelName} not found`);
    }

    // Ollama runs locally, so fail fast with actionable errors before
    // constructing the client.
    if (model.provider === PROVIDERS.ollama) {
      if (!this.config.ollama_base_url || this.config.ollama_base_url.trim() === '') {
        throw new Error("Ollama base URL is not configured. Please set a valid URL in the settings.");
      }
      if (!this.config.ollama_available) {
        throw new Error(`Ollama server is not available. Please check your connection to ${this.config.ollama_base_url}`);
      }
    }

    try {
      switch (model.provider) {
        case PROVIDERS.ollama:
          return new OllamaEmbeddings({
            baseUrl: this.config.ollama_base_url,
            model: model.model,
          });

        case PROVIDERS.openai:
          return new OpenAIEmbeddings({
            modelName: model.model,
            apiKey: this.config.openai_api_key,
          });

        case PROVIDERS.gemini:
          return new GoogleGenerativeAIEmbeddings({
            modelName: model.model,
            apiKey: this.config.gemini_api_key,
          });

        default:
          throw new Error(`Provider ${model.provider} not implemented yet for embedding models`);
      }
    } catch (error: unknown) {
      console.error(`Error creating embedding model ${modelName}:`, error);
      throw new Error(`Failed to initialize embedding model ${modelName}: ${getErrorMessage(error)}`);
    }
  }
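
  /** Build a tool-calling agent over the current model, wrapped with per-session Dexie message history. */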
  private async getAgent(enabledTools: string[] = []) {
    const prompt = ChatPromptTemplate.fromMessages([
      ["system", "You are a helpful assistant"],
      ["placeholder", "{chat_history}"],
      ["human", "{input}"],
      ["placeholder", "{agent_scratchpad}"],
    ]);

    const tools: Tool[] = [];
    if (enabledTools.includes("calculator")) {
      tools.push(new Calculator());
    }

    const agent = createToolCallingAgent({
      llm: this.model,
      tools,
      prompt,
    });

    const agentExecutor = new AgentExecutor({
      agent,
      tools,
      returnIntermediateSteps: true,
    });

    return new RunnableWithMessageHistory({
      runnable: agentExecutor,
      getMessageHistory: (sessionId: string) => new DexieChatMemory(sessionId),
      inputMessagesKey: "input",
      historyMessagesKey: "chat_history",
    });
  }
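
  /**
   * Convert uploaded attachments into a HumanMessage whose content parts match
   * what the session's provider accepts.
   */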
  private async createMessageWithAttachments(
    documents?: IDocument[],
    chatSession?: IChatSession,
  ): Promise<HumanMessage> {
    if (!documents || documents.length === 0) {
      return new HumanMessage({ content: "" });
    }

    const currentModel = CHAT_MODELS.find(
      m => m.model === (chatSession?.model || this.config.default_chat_model)
    );
    if (!currentModel) {
      throw new Error(`Model ${chatSession?.model || this.config.default_chat_model} not found in CHAT_MODELS`);
    }

    const processedContent: {
      docs: Document[];
      images: File[];
      audios: File[];
      videos: File[];
      pdfs: File[];
    } = {
      docs: [],
      images: [],
      audios: [],
      videos: [],
      pdfs: [],
    };
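
    // Sort each attachment into a modality bucket. PDFs the model cannot read
    // and generic files fall back to extracted text; other unsupported media
    // are silently skipped.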
    for (const doc of documents) {
      const file = await this.documentManager.getDocument(doc.id);

      switch (doc.type) {
        case "image":
          if (currentModel.modalities.includes(MODALITIES.image)) {
            processedContent.images.push(file);
          }
          break;
        case "audio":
          if (currentModel.modalities.includes(MODALITIES.audio)) {
            processedContent.audios.push(file);
          }
          break;
        case "video":
          if (currentModel.modalities.includes(MODALITIES.video)) {
            processedContent.videos.push(file);
          }
          break;
        case "pdf":
          if (currentModel.modalities.includes(MODALITIES.pdf)) {
            processedContent.pdfs.push(file);
          } else {
            processedContent.docs.push(...(await this.documentManager.loadDocument(doc.id)));
          }
          break;
        default:
          processedContent.docs.push(...(await this.documentManager.loadDocument(doc.id)));
          break;
      }
    }
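
    // Each provider expects a different content-part schema: OpenAI takes
    // image_url parts, Anthropic adds base64 document blocks for PDFs, Gemini
    // accepts generic media parts, and the Ollama/HF paths here only pass text.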
    const providerFormatters = {
      [PROVIDERS.openai]: async () => {
        const content = [];

        for (const image of processedContent.images) {
          const base64 = Buffer.from(await image.arrayBuffer()).toString("base64");
          content.push({
            type: "image_url",
            image_url: {
              url: `data:${image.type};base64,${base64}`,
            },
          });
        }

        for (const doc of processedContent.docs) {
          content.push({
            type: "text",
            text: `File name: ${doc.metadata.name}\nFile content: ${doc.pageContent}`,
          });
        }

        return content;
      },

      [PROVIDERS.anthropic]: async () => {
        const content = [];

        for (const image of processedContent.images) {
          const base64 = Buffer.from(await image.arrayBuffer()).toString("base64");
          content.push({
            type: "image_url",
            image_url: {
              url: `data:${image.type};base64,${base64}`,
            },
          });
        }

        for (const pdf of processedContent.pdfs) {
          content.push({
            type: "document",
            source: {
              type: "base64",
              data: Buffer.from(await pdf.arrayBuffer()).toString("base64"),
              media_type: "application/pdf",
            },
          });
        }

        for (const doc of processedContent.docs) {
          content.push({
            type: "text",
            text: `File name: ${doc.metadata.name}\nFile content: ${doc.pageContent}`,
          });
        }

        return content;
      },

      [PROVIDERS.ollama]: async () => {
        return processedContent.docs.map(doc => ({
          type: "text",
          text: `File name: ${doc.metadata.name}\nFile content: ${doc.pageContent}`,
        }));
      },

      [PROVIDERS.gemini]: async () => {
        const content = [];

        const mediaFiles = [
          ...processedContent.images,
          ...processedContent.audios,
          ...processedContent.videos,
          ...processedContent.pdfs,
        ];
        for (const media of mediaFiles) {
          content.push({
            type: "media",
            mimeType: media.type,
            data: Buffer.from(await media.arrayBuffer()).toString("base64"),
          });
        }

        for (const doc of processedContent.docs) {
          content.push({
            type: "text",
            text: `File name: ${doc.metadata.name}\nFile content: ${doc.pageContent}`,
          });
        }

        return content;
      },

      [PROVIDERS.huggingface]: async () => {
        return processedContent.docs.map(doc => ({
          type: "text",
          text: `File name: ${doc.metadata.name}\nFile content: ${doc.pageContent}`,
        }));
      },
    };

    const formatter = providerFormatters[currentModel.provider];
    if (!formatter) {
      throw new Error(`Provider ${currentModel.provider} not implemented for message attachments`);
    }

    const content = await formatter();

    return new HumanMessage({
      content,
      response_metadata: {
        documents: documents.map(document => ({
          id: document.id,
          name: document.name,
          source: document.path,
          type: document.type,
          createdAt: document.createdAt,
        })),
      },
    });
  }
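
  /**
   * Run one chat turn through the agent, persisting attachments and tool calls
   * to the session history and yielding stream/tool lifecycle events.
   */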
  async *chat(
    sessionId: string,
    input: string,
    documents?: IDocument[],
  ) {
    if (!this.config) {
      await this.initializeConfig();
    }

    const memory = new DexieChatMemory(sessionId);
    await memory.initialize();

    const chatSession = await memory.db.table("sessions").get(sessionId);

    this.model = await this.getChatModel(chatSession?.model || this.config.default_chat_model);

    // Embeddings are optional; only resolve them when a model is configured.
    const embeddingModelName = chatSession?.embedding_model || this.config.default_embedding_model;
    this.embeddings = embeddingModelName ? await this.getEmbeddingModel(embeddingModelName) : null;

    const agent = await this.getAgent(chatSession?.enabled_tools || []);

    const documentMessage = await this.createMessageWithAttachments(documents, chatSession);
    if (documentMessage.content && documentMessage.content.length > 0) {
      await memory.addMessage(documentMessage);
    }

    const eventStream = await agent.streamEvents(
      { input },
      {
        configurable: {
          sessionId,
        },
        version: "v2",
        signal: this.controller.signal,
      }
    );

    let currentResponse = "";
    for await (const event of eventStream) {
      if (event.event === "on_chat_model_stream") {
        // `chunk` is an AIMessageChunk; accumulate its text content rather
        // than the chunk object itself.
        const chunk = event.data?.chunk;
        if (chunk?.content) {
          currentResponse += chunk.content;
          yield { type: "stream", content: chunk.content };
        }
      } else if (event.event === "on_chat_model_end") {
        yield { type: "end", content: currentResponse, usageMetadata: event.data?.output?.usage_metadata };
      } else if (event.event === "on_tool_start") {
        yield { type: "tool_start", name: event.name, input: event.data?.input };
      } else if (event.event === "on_tool_end") {
        // Persist the tool result; the tool name stands in for a call id here.
        await memory.addMessage(new ToolMessage({
          tool_call_id: event.name,
          content: event.data?.output,
          name: event.name,
          status: "success",
          artifact: event.data,
          response_metadata: {
            input: event.data?.input,
            timestamp: Date.now(),
          },
        }));
        yield { type: "tool_end", name: event.name, output: event.data?.output };
      }
    }
  }
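
  /** One-shot helper: invoke the default chat model with an optional system prompt, without history or tools. */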
  async chatChain(
    input: string | HumanMessage,
    systemPrompt?: string,
  ) {
    if (!this.config) {
      await this.initializeConfig();
    }

    const model = await this.getChatModel(this.config.default_chat_model);
    const humanMessage = typeof input === "string" ? new HumanMessage(input) : input;
    return await model.invoke([
      { type: "system", content: systemPrompt || "You are a helpful assistant" },
      humanMessage,
    ]);
  }
}