// Scraped from a Hugging Face Spaces file view (revision 0bcc252).
import { z } from 'zod';
import {generateObject, LanguageModelUsage, NoObjectGeneratedError} from "ai";
import {TokenTracker} from "./token-tracker";
import {getModel, ToolName, getToolConfig} from "../config";
/**
 * Outcome of a successful structured-generation attempt.
 */
interface GenerateObjectResult<T> {
  /** Token usage reported for the attempt that produced `object`. */
  usage: LanguageModelUsage;
  /** The generated object, typed by the caller-supplied schema. */
  object: T;
}
/**
 * Arguments accepted by `ObjectGeneratorSafe.generateObject`.
 */
interface GenerateOptions<T> {
  /** Prompt sent to the model. */
  prompt: string;
  /** Zod schema describing the object to generate. */
  schema: z.ZodType<T>;
  /** Tool name; selects both the model and its generation config (maxTokens, temperature). */
  model: ToolName;
}
/**
 * Wraps the AI SDK's `generateObject` with a chain of recovery strategies and
 * records token usage for every attempt that produces a result.
 *
 * Recovery order when the primary call fails:
 *  1. `JSON.parse` the raw text attached to a `NoObjectGeneratedError`.
 *  2. Ask the `'fallback'` model to extract the object from that raw text.
 *  3. `JSON.parse` the fallback model's own raw text if it also fails.
 */
export class ObjectGeneratorSafe {
  private readonly tokenTracker: TokenTracker;

  /**
   * @param tokenTracker - Receives usage for every successful attempt; a fresh
   *   `TokenTracker` is created when omitted.
   */
  constructor(tokenTracker?: TokenTracker) {
    this.tokenTracker = tokenTracker ?? new TokenTracker();
  }

  /**
   * Generates an object for `schema`, falling back to more lenient strategies
   * when the primary model's output is rejected.
   *
   * @param options - Tool name (selects the model and its maxTokens/temperature),
   *   zod schema, and prompt.
   * @returns The object plus the token usage of the attempt that produced it.
   * @throws The primary call's error when it is not a `NoObjectGeneratedError`;
   *   otherwise the error from the fallback stage when every recovery step fails.
   */
  async generateObject<T>(options: GenerateOptions<T>): Promise<GenerateObjectResult<T>> {
    const { model, schema, prompt } = options;
    const { maxTokens, temperature } = getToolConfig(model);

    try {
      // Primary attempt with the model configured for this tool.
      const result = await generateObject({
        model: getModel(model),
        schema,
        prompt,
        maxTokens,
        temperature,
      });
      this.tokenTracker.trackUsage(model, result.usage);
      return result;
    } catch (error) {
      // First fallback: manually parse the raw text attached to the error.
      try {
        const errorResult = await this.handleGenerateObjectError<T>(error);
        this.tokenTracker.trackUsage(model, errorResult.usage);
        return errorResult;
      } catch (parseError) {
        // Only a NoObjectGeneratedError carries model output worth re-extracting;
        // anything else (network, auth, config, ...) is rethrown unchanged.
        if (!NoObjectGeneratedError.isInstance(parseError)) {
          throw error;
        }
        const failedOutput = (parseError as any).text;
        return await this.generateWithFallbackModel<T>(model, schema, failedOutput);
      }
    }
  }

  /**
   * Second fallback: asks the `'fallback'` model to extract an object matching
   * `schema` from the primary model's rejected output. If that call also fails,
   * its own raw text is parsed manually as a last resort.
   *
   * Usage is recorded under `'fallback'` because that model consumed the tokens.
   *
   * @throws The fallback call's error when it cannot be recovered manually.
   */
  private async generateWithFallbackModel<T>(
    model: ToolName,
    schema: z.ZodType<T>,
    failedOutput: string | undefined,
  ): Promise<GenerateObjectResult<T>> {
    // Resolved lazily (only on this path) so a missing or broken fallback
    // configuration cannot mask unrelated errors from the primary call.
    const fallbackModel = getModel('fallback');
    console.error(`${model} failed on object generation ${failedOutput} -> manual parsing failed again -> trying fallback model`, fallbackModel);

    try {
      const { maxTokens, temperature } = getToolConfig('fallback');
      const fallbackResult = await generateObject({
        model: fallbackModel,
        schema,
        prompt: `Extract the desired information from this text: \n ${failedOutput}`,
        maxTokens,
        temperature,
      });
      this.tokenTracker.trackUsage('fallback', fallbackResult.usage);
      return fallbackResult;
    } catch (fallbackError) {
      // Last chance: manually parse the fallback model's own raw output, and
      // track its usage just like the primary manual-parse path does.
      const lastChanceResult = await this.handleGenerateObjectError<T>(fallbackError);
      this.tokenTracker.trackUsage('fallback', lastChanceResult.usage);
      return lastChanceResult;
    }
  }

  /**
   * Recovers from a `NoObjectGeneratedError` by `JSON.parse`-ing the raw text
   * the model produced.
   *
   * NOTE(review): the parsed value is NOT validated against the schema; it is
   * only cast to `T`, so callers may receive a non-conforming object.
   *
   * @throws `error` itself, unchanged, when it is not a `NoObjectGeneratedError`
   *   or its text is not valid JSON.
   */
  private async handleGenerateObjectError<T>(error: unknown): Promise<GenerateObjectResult<T>> {
    if (!NoObjectGeneratedError.isInstance(error)) {
      throw error;
    }
    console.error('Object not generated according to schema, fallback to manual JSON parsing');

    let partialResponse: unknown;
    try {
      partialResponse = JSON.parse((error as any).text);
    } catch {
      // Surface the SDK error rather than the less informative SyntaxError.
      throw error;
    }
    return {
      object: partialResponse as T,
      usage: (error as any).usage,
    };
  }
}