Spaces:
Sleeping
Sleeping
import { AutoCastable, ResourcePolicyDenyError, Also, Prop } from 'civkit/civ-rpc'; | |
import { AsyncService } from 'civkit/async-service'; | |
import { getTraceId } from 'civkit/async-context'; | |
import { singleton, container } from 'tsyringe'; | |
import { RateLimitTriggeredError } from './lib/errors'; | |
import { FirestoreRecord } from './lib/firestore'; | |
import { GlobalLogger } from './lib/logger'; | |
/**
 * Lifecycle state of a recorded API call, stored in the `status` field of `APICall`.
 *
 * `RateLimitControl` counts only `SUCCESS` and `PENDING` calls towards a quota;
 * `ERROR` calls are not counted.
 */
export enum API_CALL_STATUS {
    SUCCESS = 'success',
    ERROR = 'error',
    PENDING = 'pending',
}
/**
 * One recorded API call, persisted in the `apiRoll` Firestore collection.
 *
 * Records are counted by `RateLimitControl` to enforce per-UID / per-IP quotas.
 * The index signature lets callers attach ad-hoc fields (e.g. `error`) beyond the declared props.
 * NOTE(review): `dictOf: Object` presumably lets civkit accept those extra keys when casting — confirm.
 */
@Also({ dictOf: Object })
export class APICall extends FirestoreRecord {
    static override collectionName = 'apiRoll';

    /** Trace id of the request that made the call; defaults to the current async-context trace id. */
    @Prop({
        required: true,
        defaultFactory: () => getTraceId()
    })
    traceId!: string;

    /** Caller's user id, when known. */
    @Prop()
    uid?: string;

    /** Caller's IP address, when known. */
    @Prop()
    ip?: string;

    /**
     * Free-form labels used to scope rate limits (matched with `array-contains-any`).
     * A factory is used instead of a literal `[]` so every record gets its own array:
     * `tag()` mutates this array in place and must never touch a shared default.
     */
    @Prop({
        arrayOf: String,
        defaultFactory: () => [],
    })
    tags!: string[];

    /** Creation time; defaults to "now" when the record is built. */
    @Prop({
        required: true,
        defaultFactory: () => new Date(),
    })
    createdAt!: Date;

    @Prop()
    completedAt?: Date;

    /** Current state of the call; starts out as `PENDING`. */
    @Prop({
        required: true,
        default: API_CALL_STATUS.PENDING,
    })
    status!: API_CALL_STATUS;

    /**
     * Defaults to 90 days after creation.
     * NOTE(review): presumably consumed by a Firestore TTL policy on this collection — confirm.
     */
    @Prop({
        required: true,
        defaultFactory: () => new Date(Date.now() + 1000 * 60 * 60 * 24 * 90),
    })
    expireAt!: Date;

    [k: string]: any;

    /**
     * Add each given tag that is not already present, preserving order.
     * Mutates `this.tags` in place; nothing is persisted until `save()` is called.
     */
    tag(...tags: string[]) {
        for (const t of tags) {
            if (!this.tags.includes(t)) {
                this.tags.push(t);
            }
        }
    }

    /** Persist this record through the static `save` of its concrete class. */
    save() {
        return (this.constructor as typeof APICall).save(this);
    }
}
/**
 * Declarative rate-limit rule: at most `occurrence` calls per sliding window of `periodSeconds`.
 *
 * The rule can optionally be restricted to an activity interval `[notBefore, notAfter]`,
 * where both bounds are inclusive and either may be omitted.
 */
export class RateLimitDesc extends AutoCastable {
    /** Maximum number of calls allowed within the window. */
    @Prop({
        default: 1000
    })
    occurrence!: number;

    /** Window length in seconds. */
    @Prop({
        default: 3600
    })
    periodSeconds!: number;

    /** Rule is inactive before this instant, if set. */
    @Prop()
    notBefore?: Date;

    /** Rule is inactive after this instant, if set. */
    @Prop()
    notAfter?: Date;

    /**
     * Whether this rule applies at the given instant.
     *
     * @param now instant to evaluate at; defaults to the current time. Passing it explicitly lets
     *   several rules be evaluated against one consistent instant.
     * @returns `false` if `now` is before `notBefore` or after `notAfter`, otherwise `true`.
     */
    isEffective(now: Date = new Date()) {
        if (this.notBefore && this.notBefore > now) {
            return false;
        }
        if (this.notAfter && this.notAfter < now) {
            return false;
        }

        return true;
    }
}
/**
 * Enforces per-UID and per-IP call quotas by counting recent `APICall` records in Firestore.
 *
 * A call counts towards a quota when both of these hold:
 * - its `createdAt` is at or after the given point in time;
 * - its `status` is `SUCCESS` or `PENDING`.
 *
 * `ERROR` calls are never counted. When tags are passed, only calls carrying at least one
 * of those tags are counted.
 */
@singleton()
export class RateLimitControl extends AsyncService {

    logger = this.globalLogger.child({ service: this.constructor.name });

    constructor(
        protected globalLogger: GlobalLogger,
    ) {
        super(...arguments);
    }

    override async init() {
        await this.dependencyReady();

        this.emit('ready');
    }

    /**
     * List the counted calls made by `uid` since `pointInTime`, oldest first.
     *
     * @param uid value matched against `APICall.uid`.
     * @param pointInTime inclusive lower bound on `createdAt`.
     * @param tags when non-empty, only calls carrying any of these tags are returned.
     */
    async queryByUid(uid: string, pointInTime: Date, ...tags: string[]) {
        return APICall.fromFirestoreQuery(this.buildCountedCallsQuery('uid', uid, pointInTime, tags));
    }

    /**
     * List the counted calls made from `ip` since `pointInTime`, oldest first.
     *
     * @param ip value matched against `APICall.ip`.
     * @param pointInTime inclusive lower bound on `createdAt`.
     * @param tags when non-empty, only calls carrying any of these tags are returned.
     */
    async queryByIp(ip: string, pointInTime: Date, ...tags: string[]) {
        return APICall.fromFirestoreQuery(this.buildCountedCallsQuery('ip', ip, pointInTime, tags));
    }

    /**
     * Throw if `uid` already has `limit` or more counted calls since `pointInTime`.
     *
     * @returns the current count + 1, i.e. the ordinal the next call would have.
     * @throws ResourcePolicyDenyError when `limit <= 0`.
     * @throws RateLimitTriggeredError when the quota is exhausted; `retryAfter` is in whole seconds.
     */
    async assertUidPeriodicLimit(uid: string, pointInTime: Date, limit: number, ...tags: string[]) {
        return this.assertPeriodicLimitFor('uid', uid, pointInTime, limit, tags);
    }

    /**
     * Throw if `ip` already has `limit` or more counted calls since `pointInTime`.
     *
     * A `limit <= 0` is rejected up front, the same way as for UIDs. Previously this case
     * always tripped `count >= limit` and crashed with a TypeError when no prior call existed.
     *
     * @returns the current count + 1, i.e. the ordinal the next call would have.
     * @throws ResourcePolicyDenyError when `limit <= 0`.
     * @throws RateLimitTriggeredError when the quota is exhausted; `retryAfter` is in whole seconds.
     */
    async assertIPPeriodicLimit(ip: string, pointInTime: Date, limit: number, ...tags: string[]) {
        return this.assertPeriodicLimitFor('ip', ip, pointInTime, limit, tags);
    }

    /**
     * Build an in-memory `APICall` with a freshly allocated Firestore document id.
     *
     * The record is NOT persisted here; the caller must call `save()` on it.
     */
    record(partialRecord: Partial<APICall>) {
        const record = APICall.from(partialRecord);
        // `doc()` without a path only generates an auto id locally; it performs no write.
        const newId = APICall.COLLECTION.doc().id;
        record._id = newId;

        return record;
    }

    // async simpleRPCUidBasedLimit(rpcReflect: RPCReflection, uid: string, tags: string[] = [],
    //     ...inputCriterion: RateLimitDesc[] | [Date, number][]) {
    //     const criterion = inputCriterion.map((c) => { return Array.isArray(c) ? c : this.rateLimitDescToCriterion(c); });
    //     await Promise.all(criterion.map(([pointInTime, n]) =>
    //         this.assertUidPeriodicLimit(uid, pointInTime, n, ...tags)));
    //     const r = this.record({
    //         uid,
    //         tags,
    //     });
    //     r.save().catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     rpcReflect.then(() => {
    //         r.status = API_CALL_STATUS.SUCCESS;
    //         r.save()
    //             .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     });
    //     rpcReflect.catch((err) => {
    //         r.status = API_CALL_STATUS.ERROR;
    //         r.error = err.toString();
    //         r.save()
    //             .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     });
    //     return r;
    // }

    /**
     * Convert a rule into a `[pointInTime, limit]` pair.
     *
     * `pointInTime` is `now - periodSeconds`, i.e. the start of the sliding window.
     */
    rateLimitDescToCriterion(rateLimitDesc: RateLimitDesc) {
        return [new Date(Date.now() - rateLimitDesc.periodSeconds * 1000), rateLimitDesc.occurrence] as [Date, number];
    }

    // async simpleRpcIPBasedLimit(rpcReflect: RPCReflection, ip: string, tags: string[] = [],
    //     ...inputCriterion: RateLimitDesc[] | [Date, number][]) {
    //     const criterion = inputCriterion.map((c) => { return Array.isArray(c) ? c : this.rateLimitDescToCriterion(c); });
    //     await Promise.all(criterion.map(([pointInTime, n]) =>
    //         this.assertIPPeriodicLimit(ip, pointInTime, n, ...tags)));
    //     const r = this.record({
    //         ip,
    //         tags,
    //     });
    //     r.save().catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     rpcReflect.then(() => {
    //         r.status = API_CALL_STATUS.SUCCESS;
    //         r.save()
    //             .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     });
    //     rpcReflect.catch((err) => {
    //         r.status = API_CALL_STATUS.ERROR;
    //         r.error = err.toString();
    //         r.save()
    //             .catch((err) => this.logger.warn(`Failed to save rate limit record`, { err }));
    //     });
    //     return r;
    // }

    /**
     * Query for counted calls where `field == value` and `createdAt >= pointInTime`, oldest first.
     * Shared by every lookup in this service so UID and IP limits cannot drift apart.
     */
    private buildCountedCallsQuery(field: 'uid' | 'ip', value: string, pointInTime: Date, tags: string[]) {
        let q = APICall.COLLECTION
            .orderBy('createdAt', 'asc')
            .where('createdAt', '>=', pointInTime)
            .where('status', 'in', [API_CALL_STATUS.SUCCESS, API_CALL_STATUS.PENDING])
            .where(field, '==', value);
        if (tags.length) {
            q = q.where('tags', 'array-contains-any', tags);
        }

        return q;
    }

    /** Shared implementation of the UID and IP periodic-limit assertions. */
    private async assertPeriodicLimitFor(field: 'uid' | 'ip', value: string, pointInTime: Date, limit: number, tags: string[]) {
        // 'uid' -> 'UID', 'ip' -> 'IP'; keeps the user-facing messages identical to the historical ones.
        const scope = field.toUpperCase();
        if (limit <= 0) {
            throw new ResourcePolicyDenyError(`This ${scope}(${value}) is not allowed to call this endpoint (rate limit quota is 0).`);
        }

        const q = this.buildCountedCallsQuery(field, value, pointInTime, tags);
        const count = (await q.count().get()).data().count;
        if (count < limit) {
            return count + 1;
        }

        // The query is ordered by createdAt ascending, so this is the oldest counted call.
        const [oldest] = await APICall.fromFirestoreQuery(q.limit(1));
        const oldestCreatedAt: Date | undefined = oldest?.createdAt;
        // When pointInTime is `now - period` (see rateLimitDescToCriterion), the oldest call leaves the
        // window after (createdAt - pointInTime). If the record vanished between the count and this
        // fetch (status change / expiry), nothing is blocking any more, so report 0 instead of crashing.
        const dtMs = oldestCreatedAt ? Math.abs(oldestCreatedAt.valueOf() - pointInTime.valueOf()) : 0;
        const dtSec = Math.ceil(dtMs / 1000);

        throw RateLimitTriggeredError.from({
            message: `Per ${scope} rate limit exceeded (${tags.join(',') || 'called'} ${limit} times since ${pointInTime})`,
            retryAfter: dtSec,
        });
    }
}
// Resolved from tsyringe's root container when this module is first imported; importers share this instance.
const instance = container.resolve(RateLimitControl);

export default instance;