Spaces:
Sleeping
Sleeping
import { ApplicationError, Prop, RPC_CALL_ENVIRONMENT } from "civkit/civ-rpc"; | |
import { marshalErrorLike } from "civkit/lang"; | |
import { randomUUID } from "crypto"; | |
import { once } from "events"; | |
import type { NextFunction, Request, Response } from "express"; | |
import { JinaEmbeddingsAuthDTO } from "./dto/jina-embeddings-auth"; | |
import rateLimitControl, { API_CALL_STATUS, RateLimitDesc } from "./rate-limit"; | |
import asyncLocalContext from "./lib/async-context"; | |
import globalLogger from "./lib/logger"; | |
import { InsufficientBalanceError } from "./lib/errors"; | |
import { FirestoreRecord } from "./lib/firestore"; | |
import cors from "cors"; | |
/** Application tag passed to rate-limit lookups, call records and usage reports. */
const appName = 'DEEPRESEARCH';

// Start the global logger's readiness sequence; the returned value is not awaited here.
void globalLogger.serviceReady();

/** Logger scoped to this middleware module. */
const logger = globalLogger.child({ service: 'JinaAISaaSMiddleware' });
/**
 * Firestore record (collection `knowledgeItems`) holding one knowledge entry,
 * linked back to the originating request by `traceId` and to the caller by `uid`.
 *
 * Field metadata (required/default/element type) is declared through civkit's
 * `@Prop` decorator.
 */
export class KnowledgeItem extends FirestoreRecord {
    static override collectionName = 'knowledgeItems';

    /** Trace id of the request the entry was produced for. */
    @Prop({
        required: true
    })
    traceId!: string;

    /** UID of the authenticated caller. */
    @Prop({
        required: true
    })
    uid!: string;

    /** Question text; defaults to an empty string. */
    @Prop({
        default: ''
    })
    question!: string;

    /** Answer text; defaults to an empty string. */
    @Prop({
        default: ''
    })
    answer!: string;

    /** Entry type label (allowed values are not constrained here); defaults to an empty string. */
    @Prop({
        default: ''
    })
    type!: string;

    /**
     * Array of plain objects; defaults to an empty array.
     * NOTE(review): presumably source references for the answer; the element shape is not fixed here.
     */
    @Prop({
        arrayOf: Object,
        default: []
    })
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- element schema is not defined in this module
    references!: any[];

    /** Creation timestamp; defaults to the time the record object is constructed. */
    @Prop({
        defaultFactory: () => new Date()
    })
    createdAt!: Date;

    /**
     * Last-update timestamp; defaults to the time the record object is constructed.
     * NOTE(review): nothing in this module refreshes it on later writes.
     */
    @Prop({
        defaultFactory: () => new Date()
    })
    updatedAt!: Date;
}
/** Shared CORS handler; used to attach CORS headers to error responses written by this middleware. */
const corsMiddleware = cors();

/**
 * Fallback rate-limit policy for callers whose auth DTO supplies none.
 *
 * @param speedLevel Raw `user.metadata?.speed_level` value. If it parses to an integer >= 2,
 *   the limit is 30 calls per 60 s; otherwise (including `undefined` or non-numeric values)
 *   it is 10 calls per 60 s.
 * @returns A single-element array of `RateLimitDesc`.
 */
function defaultRateLimitPolicy(speedLevel: unknown) {
    // Explicit radix: without it parseInt would read a '0x'-prefixed string as hexadecimal.
    const occurrence = parseInt(String(speedLevel), 10) >= 2 ? 30 : 10;

    return [
        RateLimitDesc.from({
            occurrence,
            periodSeconds: 60
        })
    ];
}

/**
 * Saves an API-call record without awaiting it.
 * A failed save is logged as a warning and never thrown.
 */
function saveApiRollInBackground(apiRoll: { save(): Promise<unknown> }) {
    void apiRoll.save().catch((err) => logger.warn(`Failed to save rate limit record`, { err: marshalErrorLike(err) }));
}

/**
 * Persists knowledge entries as `KnowledgeItem` records without awaiting the writes.
 * A failed write is logged as a warning.
 *
 * `uid` and `traceId` are applied after each entry is spread, so an entry cannot override them.
 * A synchronous throw from `KnowledgeItem.from` still propagates to the caller.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- entry shape comes from the untyped request context
function persistKnowledgeInBackground(knowledge: readonly any[], uid: string, traceId: string) {
    void Promise.all(knowledge.map((entry) => KnowledgeItem.save(
        KnowledgeItem.from({
            ...entry,
            uid,
            traceId,
        })
    ))).catch((err: any) => {
        logger.warn(`Failed to save knowledge`, { err: marshalErrorLike(err) });
    });
}

/**
 * Express middleware that authenticates, rate-limits and bills DeepResearch API calls.
 *
 * - `/ping` is answered directly with status 200.
 * - `/v1/models*` paths and methods other than GET/POST are passed to `next()`
 *   without any of the checks below.
 * - Every other request runs inside a fresh async-local context. The middleware:
 *   1. assigns `ctx.traceId`, `ctx.traceT0` and `ctx.ip`;
 *   2. asserts the user and a positive wallet balance;
 *   3. enforces the per-UID rate limits;
 *   4. records the call, then hands off to `next()`.
 * - After the response emits 'close', it:
 *   - reports `ctx.chargeAmount` as usage (if set);
 *   - stores the final status on the call record;
 *   - logs the timing;
 *   - persists any `ctx.promptContext.knowledge` entries.
 * - Errors raised before headers are sent become JSON error responses with CORS headers:
 *   - `ApplicationError`s use their numeric `code` as the HTTP status (falling back to 500)
 *     and are not logged;
 *   - any other error yields `500 { error: 'Internal' }` and is logged.
 */
export const jinaAiMiddleware = (req: Request, res: Response, next: NextFunction) => {
    if (req.path === '/ping') {
        // NOTE(review): the body is 'pone' rather than 'pong'; kept as-is in case probes match on it.
        res.status(200).end('pone');
        return;
    }
    // Model listing bypasses this middleware's auth, rate limiting and billing.
    if (req.path.startsWith('/v1/models')) {
        next();
        return;
    }
    // Only GET/POST are metered; anything else (e.g. OPTIONS) passes through untouched.
    if (req.method !== 'POST' && req.method !== 'GET') {
        next();
        return;
    }

    asyncLocalContext.run(async () => {
        // Prefer caller-supplied request ids, then the Google Cloud trace header, then a fresh UUID.
        const googleTraceId = req.get('x-cloud-trace-context')?.split('/')?.[0];
        const ctx = asyncLocalContext.ctx;
        ctx.traceId = req.get('x-request-id') || req.get('request-id') || googleTraceId || randomUUID();
        ctx.traceT0 = new Date();
        ctx.ip = req?.ip;

        try {
            const authDto = JinaEmbeddingsAuthDTO.from({
                [RPC_CALL_ENVIRONMENT]: { req, res }
            });
            const user = await authDto.assertUser();
            const uid = await authDto.assertUID();
            if (!(user.wallet.total_balance > 0)) {
                throw new InsufficientBalanceError(`Account balance not enough to run this query, please recharge.`);
            }

            await rateLimitControl.serviceReady();
            const rateLimitPolicy = authDto.getRateLimits(appName) || defaultRateLimitPolicy(user.metadata?.speed_level);
            const criterions = rateLimitPolicy.map((c) => rateLimitControl.rateLimitDescToCriterion(c));
            await Promise.all(criterions.map(([pointInTime, n]) => rateLimitControl.assertUidPeriodicLimit(uid, pointInTime, n, appName)));

            // Record the call up front; the record is saved again with its final status after the response closes.
            const apiRoll = rateLimitControl.record({ uid, tags: [appName] });
            saveApiRollInBackground(apiRoll);

            // Subscribe to 'close' before handing off to downstream handlers.
            const pResClose = once(res, 'close');
            next();
            await pResClose;

            // Downstream handlers may set ctx.chargeAmount; bill it only once the response is done.
            const chargeAmount = ctx.chargeAmount;
            if (chargeAmount) {
                authDto.reportUsage(chargeAmount, `reader-${appName}`).catch((err) => {
                    logger.warn(`Unable to report usage for ${uid}`, { err: marshalErrorLike(err) });
                });
                apiRoll.chargeAmount = chargeAmount;
            }
            apiRoll.status = res.statusCode === 200 ? API_CALL_STATUS.SUCCESS : API_CALL_STATUS.ERROR;
            saveApiRollInBackground(apiRoll);

            logger.info(`HTTP ${res.statusCode} for request ${ctx.traceId} after ${Date.now() - ctx.traceT0.valueOf()}ms`, {
                uid,
                chargeAmount,
            });

            if (ctx.promptContext?.knowledge?.length) {
                persistKnowledgeInBackground(ctx.promptContext.knowledge, uid, ctx.traceId);
            }
        } catch (err: any) {
            if (!res.headersSent) {
                // Attach CORS headers so browser clients can read the error body.
                corsMiddleware(req, res, () => 'noop');
                if (err instanceof ApplicationError) {
                    // ApplicationErrors carry their HTTP status in `code`; they are answered without logging.
                    res.status(parseInt(err.code as string, 10) || 500).json({ error: err.message });
                    return;
                }
                res.status(500).json({ error: 'Internal' });
            }
            logger.error(`Error in billing middleware`, { err: marshalErrorLike(err) });
            if (err.stack) {
                logger.error(err.stack);
            }
        }
    });
};