// Source: DeepResearch/jina-ai/src/patch-express.ts (commit 0bcc252 "init", uploaded by zhengr, 5.49 kB)
import { ApplicationError, Prop, RPC_CALL_ENVIRONMENT } from "civkit/civ-rpc";
import { marshalErrorLike } from "civkit/lang";
import { randomUUID } from "crypto";
import { once } from "events";
import type { NextFunction, Request, Response } from "express";
import { JinaEmbeddingsAuthDTO } from "./dto/jina-embeddings-auth";
import rateLimitControl, { API_CALL_STATUS, RateLimitDesc } from "./rate-limit";
import asyncLocalContext from "./lib/async-context";
import globalLogger from "./lib/logger";
import { InsufficientBalanceError } from "./lib/errors";
import { FirestoreRecord } from "./lib/firestore";
import cors from "cors";
// Called at module load; the return value is not awaited or otherwise used here.
// NOTE(review): presumably this triggers logger initialization — confirm in ./lib/logger.
globalLogger.serviceReady();
// Child logger used by everything in this module; entries carry the service name below.
const logger = globalLogger.child({ service: 'JinaAISaaSMiddleware' });
// Application identifier passed to the rate limiter (policy lookup, limit checks, record tag)
// and embedded in the usage-report label (`reader-${appName}`).
const appName = 'DEEPRESEARCH';
/**
 * A single knowledge entry persisted through `FirestoreRecord` under the
 * `knowledgeItems` collection name.
 *
 * `jinaAiMiddleware` creates one instance per element of
 * `ctx.promptContext.knowledge` once the response has closed, stamping each
 * with the caller's `uid` and the request's `traceId`.
 */
export class KnowledgeItem extends FirestoreRecord {
    static override collectionName = 'knowledgeItems';

    /** Trace id of the request that produced this item (required). */
    @Prop({ required: true })
    traceId!: string;

    /** Id of the user the item belongs to (required). */
    @Prop({ required: true })
    uid!: string;

    /** Question text; defaults to an empty string. */
    @Prop({ default: '' })
    question!: string;

    /** Answer text; defaults to an empty string. */
    @Prop({ default: '' })
    answer!: string;

    /** Free-form type label; defaults to an empty string. */
    @Prop({ default: '' })
    type!: string;

    /** Array of reference objects; defaults to an empty array. */
    @Prop({ arrayOf: Object, default: [] })
    references!: any[];

    /** Creation timestamp; defaults to the time the default factory runs. */
    @Prop({ defaultFactory: () => new Date() })
    createdAt!: Date;

    /** Last-update timestamp; defaults to the time the default factory runs. */
    @Prop({ defaultFactory: () => new Date() })
    updatedAt!: Date;
}
// CORS handler built with the package's default options. jinaAiMiddleware invokes it
// by hand before writing an error response, so those responses carry CORS headers too.
const corsMiddleware = cors();
/**
 * Express middleware that authenticates, rate-limits and bills a request.
 *
 * Short-circuits:
 * - `/ping` is answered directly with HTTP 200.
 * - Paths starting with `/v1/models`, and any method other than GET/POST,
 *   are handed to `next()` without any accounting.
 *
 * For every other request the work runs inside `asyncLocalContext.run`:
 * 1. The context receives a trace id (from `x-request-id`, `request-id`, the
 *    first segment of `x-cloud-trace-context`, or a fresh UUID), a start time
 *    and the client ip.
 * 2. The caller is authenticated and must have a positive wallet balance.
 * 3. Rate limits are asserted: the policy from the auth DTO if it returns one,
 *    otherwise 30 calls per 60s when `speed_level >= 2`, else 10 per 60s.
 * 4. A rate-limit record is saved, `next()` is called, and the middleware
 *    waits for the response's `close` event.
 * 5. After close, any `ctx.chargeAmount` is reported as usage, the record is
 *    updated with status/charge and saved again, and any
 *    `ctx.promptContext.knowledge` entries are stored as `KnowledgeItem`s.
 *
 * Failures: if no headers were sent yet, a JSON `{ error }` body is written
 * (with CORS headers applied). An `ApplicationError` uses its `code` as the
 * HTTP status (500 if it does not parse) and is not logged; anything else
 * yields 500 `Internal` and is logged. Errors raised after headers were sent
 * are only logged.
 *
 * @param req  Incoming Express request.
 * @param res  Express response; its `close` event gates the billing step.
 * @param next Continuation invoked to run the downstream handlers.
 */
export const jinaAiMiddleware = (req: Request, res: Response, next: NextFunction) => {
    // Liveness probe.
    // NOTE(review): the body is 'pone', not 'pong' — kept verbatim; confirm whether intentional.
    if (req.path === '/ping') {
        res.status(200).end('pone');
        return;
    }

    // Model listing and non GET/POST traffic bypass auth, limits and billing.
    if (req.path.startsWith('/v1/models') || !(req.method === 'POST' || req.method === 'GET')) {
        next();
        return;
    }

    asyncLocalContext.run(async () => {
        const cloudTraceId = req.get('x-cloud-trace-context')?.split('/')[0];
        const store = asyncLocalContext.ctx;
        store.traceId = req.get('x-request-id') || req.get('request-id') || cloudTraceId || randomUUID();
        store.traceT0 = new Date();
        store.ip = req.ip;

        // Both saves of the rate-limit record are best-effort and share this handler.
        const warnSaveFailed = (err: any) =>
            logger.warn(`Failed to save rate limit record`, { err: marshalErrorLike(err) });

        try {
            const auth = JinaEmbeddingsAuthDTO.from({
                [RPC_CALL_ENVIRONMENT]: { req, res },
            });
            const account = await auth.assertUser();
            const uid = await auth.assertUID();

            const hasFunds = account.wallet.total_balance > 0;
            if (!hasFunds) {
                throw new InsufficientBalanceError(`Account balance not enough to run this query, please recharge.`);
            }

            await rateLimitControl.serviceReady();

            // Fall back to a single per-minute limit when the auth DTO supplies no policy.
            const limits = auth.getRateLimits(appName) || [
                RateLimitDesc.from({
                    occurrence: parseInt(account.metadata?.speed_level) >= 2 ? 30 : 10,
                    periodSeconds: 60,
                }),
            ];
            const windows = limits.map((desc) => rateLimitControl.rateLimitDescToCriterion(desc));
            const limitChecks = windows.map(([since, quota]) =>
                rateLimitControl.assertUidPeriodicLimit(uid, since, quota, appName)
            );
            await Promise.all(limitChecks);

            const usageRecord = rateLimitControl.record({ uid, tags: [appName] });
            usageRecord.save().catch(warnSaveFailed);

            // Subscribe to 'close' before handing off, so the event cannot be missed.
            const responseClosed = once(res, 'close');
            next();
            await responseClosed;

            const chargeAmount = store.chargeAmount;
            if (chargeAmount) {
                // Usage reporting is fire-and-forget; failures are only logged.
                auth.reportUsage(chargeAmount, `reader-${appName}`).catch((err) => {
                    logger.warn(`Unable to report usage for ${uid}`, { err: marshalErrorLike(err) });
                });
                usageRecord.chargeAmount = chargeAmount;
            }

            // Only an exact 200 counts as success.
            if (res.statusCode === 200) {
                usageRecord.status = API_CALL_STATUS.SUCCESS;
            } else {
                usageRecord.status = API_CALL_STATUS.ERROR;
            }
            usageRecord.save().catch(warnSaveFailed);

            logger.info(`HTTP ${res.statusCode} for request ${store.traceId} after ${Date.now() - store.traceT0.valueOf()}ms`, {
                uid,
                chargeAmount,
            });

            const knowledge = store.promptContext?.knowledge;
            if (knowledge?.length) {
                const pendingSaves = knowledge.map((entry: any) => {
                    const item = KnowledgeItem.from({
                        ...entry,
                        uid,
                        traceId: store.traceId,
                    });

                    return KnowledgeItem.save(item);
                });
                Promise.all(pendingSaves).catch((err: any) => {
                    logger.warn(`Failed to save knowledge`, { err: marshalErrorLike(err) });
                });
            }
        } catch (err: any) {
            if (!res.headersSent) {
                // Error responses skip the downstream chain, so apply CORS headers manually.
                corsMiddleware(req, res, () => 'noop');

                const isAppError = err instanceof ApplicationError;
                const httpStatus = isAppError ? (parseInt(err.code as string) || 500) : 500;
                const message = isAppError ? err.message : 'Internal';
                res.status(httpStatus).json({ error: message });

                // Application errors are reported to the client only, not logged.
                if (isAppError) {
                    return;
                }
            }

            logger.error(`Error in billing middleware`, { err: marshalErrorLike(err) });
            if (err.stack) {
                logger.error(err.stack);
            }
        }
    });
}