"use server"

import { StableDiffusionParams } from "@/types"

import {
  serverHuggingfaceApiKey,
  serverHuggingfaceInferenceApiFileType,
  serverHuggingfaceInferenceApiModel,
  serverHuggingfaceInferenceApiModelRefinerModel,
  serverHuggingfaceInferenceApiModelTrigger
} from "./config"
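
/**
 * Server action that renders an image through the Hugging Face Inference API:
 * a text-to-image pass on the base model, then an image-to-image pass on the
 * refiner model (the refiner step is non-blocking: if it fails, the base image
 * is returned instead).
 *
 * Returns the image as a base64 data URI, or an empty string on failure.
 * Note: `negativePrompt` and `seed` are accepted but not currently forwarded
 * to the API.
 *
 * @example
 * // illustrative call only; the parameter values below are made up
 * const dataUri = await stableDiffusion({
 *   prompt: "a watercolor panda",
 *   negativePrompt: "",
 *   guidanceScale: 7.5,
 *   seed: 42,
 *   width: 1024,
 *   height: 1024,
 *   numInferenceSteps: 25,
 *   hfApiKey: "", // empty: fall back to the server-side HF_API_TOKEN
 * })
 */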
export async function stableDiffusion({
  prompt,
  negativePrompt,
  guidanceScale,
  seed,
  width,
  height,
  numInferenceSteps,
  hfApiKey,
}: StableDiffusionParams) {
  // throw new Error("Planned maintenance")

  if (!prompt) {
    const error = `cannot call the rendering API without a prompt, aborting..`
    console.error(error)
    throw new Error(error)
  }
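
  // a key passed in by the caller takes precedence over the server-side HF_API_TOKEN;
  // the model ids, the optional trigger word and the response file type come from ./config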
  let huggingfaceApiKey = hfApiKey || serverHuggingfaceApiKey
  let huggingfaceInferenceApiModel = serverHuggingfaceInferenceApiModel
  let huggingfaceInferenceApiModelRefinerModel = serverHuggingfaceInferenceApiModelRefinerModel
  let huggingfaceInferenceApiModelTrigger = serverHuggingfaceInferenceApiModelTrigger
  let huggingfaceInferenceApiFileType = serverHuggingfaceInferenceApiFileType

  try {
    if (!huggingfaceApiKey) {
      throw new Error(`invalid huggingfaceApiKey, you need to configure your HF_API_TOKEN`)
    }

    if (!huggingfaceInferenceApiModel) {
      throw new Error(`invalid huggingfaceInferenceApiModel, you need to configure your HF_INFERENCE_API_BASE_MODEL`)
    }

    if (!huggingfaceInferenceApiModelRefinerModel) {
      throw new Error(`invalid huggingfaceInferenceApiModelRefinerModel, you need to configure your HF_INFERENCE_API_REFINER_MODEL`)
    }

    const baseModelUrl = `https://api-inference.huggingface.co/models/${huggingfaceInferenceApiModel}`

    // prepend the model's optional trigger word (if one is configured) to the user prompt
    const positivePrompt = [
      huggingfaceInferenceApiModelTrigger || "",
      prompt,
    ].filter(x => x).join(", ")
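
    // first pass: text-to-image on the base model, sending the assembled prompt as `inputs`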
    const res = await fetch(baseModelUrl, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Accept: huggingfaceInferenceApiFileType,
        Authorization: `Bearer ${huggingfaceApiKey}`,
      },
      body: JSON.stringify({
        inputs: positivePrompt,
        parameters: {
          num_inference_steps: numInferenceSteps,
          guidance_scale: guidanceScale,
          width,
          height,
        },

        // this doesn't do what you think it does
        use_cache: false, // withCache,
      }),
      cache: "no-store",
      // we can also use this (see https://vercel.com/blog/vercel-cache-api-nextjs-cache)
      // next: { revalidate: 1 }
    })

    // Recommendation: handle errors
    if (res.status !== 200) {
      const content = await res.text()
      console.error(content)
      // This will activate the closest `error.js` Error Boundary
      throw new Error('Failed to fetch data')
    }

    const blob = await res.arrayBuffer()

    const contentType = res.headers.get('content-type')

    let assetUrl = `data:${contentType};base64,${Buffer.from(blob).toString('base64')}`
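
    // keep the base image as a data URI: this is what gets returned if the refiner
    // pass below fails, since that error is caught and only logged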
    try {
      const refinerModelUrl = `https://api-inference.huggingface.co/models/${huggingfaceInferenceApiModelRefinerModel}`

      const res = await fetch(refinerModelUrl, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
          Authorization: `Bearer ${huggingfaceApiKey}`,
        },
        body: JSON.stringify({
          inputs: Buffer.from(blob).toString('base64'),
          parameters: {
            prompt: positivePrompt,
            num_inference_steps: numInferenceSteps,
            guidance_scale: guidanceScale,
            width,
            height,
          },

          // this doesn't do what you think it does
          use_cache: false, // withCache,
        }),
        cache: "no-store",
        // we can also use this (see https://vercel.com/blog/vercel-cache-api-nextjs-cache)
        // next: { revalidate: 1 }
      })

      // Recommendation: handle errors
      if (res.status !== 200) {
        const content = await res.json()
        // if (content.error.include("currently loading")) {
        // console.log("refiner isn't ready yet")
        throw new Error(content?.error || 'Failed to fetch data')
      }

      const refinedBlob = await res.arrayBuffer()

      const contentType = res.headers.get('content-type')

      assetUrl = `data:${contentType};base64,${Buffer.from(refinedBlob).toString('base64')}`
    } catch (err) {
      console.log(`Refiner step failed, but this is not a blocker. Error details: ${err}`)
    }

    return assetUrl
  } catch (err) {
    console.error(err)
    return ""
  }
}