File size: 4,115 Bytes
ec194c9
 
 
 
 
 
 
 
 
cd4ee95
74bfab8
 
ec194c9
 
cd4ee95
ec194c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74bfab8
ec194c9
 
 
74bfab8
ec194c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"use server"

import Replicate from "replicate"

import { generateSeed } from "../../utils/misc/generateSeed.mts"
import { sleep } from "../../utils/misc/sleep.mts"
import { getNegativePrompt, getPositivePrompt } from "./defaultPrompts.mts"
import { VideoGenerationOptions } from "./types.mts"

// Replicate credentials and model coordinates, read once when this module loads.
// Unset variables are normalized to "" so later checks only have to test for emptiness.
const replicateToken = process.env.VC_REPLICATE_API_TOKEN ?? ""
const replicateModel = process.env.VC_HOTSHOT_XL_REPLICATE_MODEL ?? ""
const replicateModelVersion = process.env.VC_HOTSHOT_XL_REPLICATE_MODEL_VERSION ?? ""

// Fail fast at import time: the client below cannot authenticate without a token.
if (replicateToken.length === 0) {
  throw new Error(`you need to configure your VC_REPLICATE_API_TOKEN`)
}

// Shared Replicate client, used to create predictions.
const replicate = new Replicate({ auth: replicateToken })

/** Delay before the first prediction status check, in ms. */
const initialPollDelayMs = 30000

/** Delay between two prediction status checks, in ms. */
const pollIntervalMs = 5000

/** Maximum number of status checks before giving up (30 x 5s, about 2.5 minutes after the initial delay). */
const maxPollAttempts = 30

/** Subset of the Replicate prediction payload that the poller reads. */
type ReplicatePredictionPayload = {
  status?: unknown
  output?: unknown
  error?: unknown
}

/**
 * Parse a "WIDTHxHEIGHT" string (e.g. "768x320") into positive integer dimensions.
 *
 * @throws Error if either dimension is missing or is not a positive integer
 */
function parseVideoSize(size: string): { width: number; height: number } {
  const [width, height] = size.split("x").map(part => Number(part))
  if (
    typeof width !== "number" || typeof height !== "number" ||
    !Number.isInteger(width) || !Number.isInteger(height) ||
    width <= 0 || height <= 0
  ) {
    throw new Error(`invalid size "${size}", expected something like "768x320"`)
  }
  return { width, height }
}

/**
 * Pick the result URL out of a succeeded prediction's `output` field.
 *
 * Accepts a single non-empty string, or an array whose first non-empty string
 * entry is used.
 * NOTE(review): the model is presumed to return a single URI string; confirm against its output schema.
 *
 * @throws Error if no usable string is present
 */
function getPredictionOutputUrl(output: unknown, predictionId: string): string {
  if (typeof output === "string" && output.length > 0) {
    return output
  }
  if (Array.isArray(output)) {
    const first = output.find((item): item is string => typeof item === "string" && item.length > 0)
    if (first) {
      return first
    }
  }
  throw new Error(`Replicate prediction ${predictionId} succeeded but returned no output`)
}

/**
 * Poll the Replicate HTTP API until the given prediction succeeds, and return its raw `output`.
 *
 * Waits `initialPollDelayMs` first, then checks every `pollIntervalMs`, at most `maxPollAttempts` times.
 * Non-200 responses are retried.
 *
 * @throws Error with the prediction's `error` message when one is reported
 * @throws Error when the prediction ends as "failed" or "canceled"
 * @throws Error("Request time out.") when the attempt budget is exhausted
 */
async function waitForPredictionOutput(predictionId: string): Promise<unknown> {
  await sleep(initialPollDelayMs)

  for (let attempt = 0; attempt < maxPollAttempts; attempt++) {
    await sleep(pollIntervalMs)

    const res = await fetch(`https://api.replicate.com/v1/predictions/${predictionId}`, {
      method: "GET",
      headers: {
        Authorization: `Token ${replicateToken}`,
      },
      cache: "no-store",
    })

    // non-200 answers (rate limiting, transient gateway errors...) are simply retried
    if (res.status !== 200) {
      continue
    }

    // external JSON: every field is typed `unknown` and checked before use
    const prediction = (await res.json()) as ReplicatePredictionPayload | null

    const error = `${prediction?.error || ""}`
    if (error) {
      throw new Error(error)
    }

    if (prediction?.status === "succeeded") {
      return prediction.output
    }

    // a canceled prediction carries no error message but will never complete either
    if (prediction?.status === "failed" || prediction?.status === "canceled") {
      throw new Error(`Replicate prediction ${predictionId} ended with status "${prediction.status}"`)
    }
  }

  // To prevent indefinite polling, give up after maxPollAttempts checks
  throw new Error("Request time out.")
}

/**
 * Generate a video with hotshot through Replicate.
 *
 * Creates a prediction for the configured model version (VC_HOTSHOT_XL_REPLICATE_MODEL_VERSION),
 * waits for it to finish, and returns the URL reported as its output.
 *
 * Note that if nbFrames == 1, then it will generate a jpg.
 *
 * When `replicateLora` (a URL to a LoRA .tar) is non-empty it is used instead of
 * `huggingFaceLora`. Only one LoRA source is ever sent to the model.
 *
 * @returns the prediction's output URL
 * @throws Error if the prompt is empty, the model env variables are missing, `size` is not
 *   "WIDTHxHEIGHT", the prediction fails or is canceled, or it does not finish in time
 */
export async function generateVideoWithHotshotReplicate({
    positivePrompt,
    negativePrompt = "",
    seed,
    nbFrames = 8, // for now the only values that make sense are 1 (for a jpg) or 8 (for a video)
    videoDuration = 1000, // for now Hotshot doesn't really supports anything else
    nbSteps = 30, // when rendering a final video, we want a value like 50 or 70 here
    size = "768x320",

    // for a replicate LoRa this is always the same ("In the style of TOK")
    // triggerWord = "In the style of TOK",

    // for jbilcke-hf/sdxl-cinematic-2 it is "cinematic-2"
    triggerWord = "cinematic-2",

    huggingFaceLora = "jbilcke-hf/sdxl-cinematic-2",

    // url to the weight
    replicateLora,
  }: VideoGenerationOptions): Promise<string> {

  if (!positivePrompt?.length) {
    throw new Error(`prompt is too short!`)
  }

  if (!replicateModel) {
    throw new Error(`you need to configure your VC_HOTSHOT_XL_REPLICATE_MODEL`)
  }

  if (!replicateModelVersion) {
    throw new Error(`you need to configure your VC_HOTSHOT_XL_REPLICATE_MODEL_VERSION`)
  }

  const { width, height } = parseVideoSize(size)

  // pimp the prompt (kept in new locals rather than mutating the parameters)
  const finalPositivePrompt = getPositivePrompt(positivePrompt, triggerWord)
  const finalNegativePrompt = getNegativePrompt(negativePrompt)

  // a Replicate LoRA takes precedence; huggingFaceLora has a non-empty default,
  // so it must not be what decides whether replicateLora gets sent
  const useReplicateLora = Boolean(replicateLora?.length)

  // see an example here:
  // https://replicate.com/p/incraplbv23g3zv6woinhgdira
  // for params and doc see https://replicate.com/cloneofsimo/hotshot-xl-lora-controlnet
  const prediction = await replicate.predictions.create({
    version: replicateModelVersion,
    input: {
      prompt: finalPositivePrompt,
      negative_prompt: finalNegativePrompt,

      // this is not a URL but a model name
      hf_lora_url: useReplicateLora ? undefined : huggingFaceLora,

      // this is a URL to the .tar (we can get it from the "trainings" page)
      replicate_weights_url: useReplicateLora ? replicateLora : undefined,

      width,
      height,

      // those are used to create an upsampling or downsampling
      // original_width: width,
      // original_height: height,
      // target_width: width,
      // target_height: height,

      steps: nbSteps,

      // note: right now it only makes sense to use either 1 (a jpg) or 8 (a video)
      video_length: nbFrames, // nb frames

      video_duration: videoDuration, // video duration in ms

      // only a real finite number is accepted; anything else (undefined, null, NaN) gets a fresh seed
      seed: typeof seed === "number" && Number.isFinite(seed) ? seed : generateSeed(),
    }
  })

  const output = await waitForPredictionOutput(prediction.id)
  return getPredictionOutputUrl(output, prediction.id)
}