jbilcke-hf HF staff committed on
Commit
ef22617
·
1 Parent(s): ceb44b0

working on SDXL + segmentation

Browse files
src/index.mts CHANGED
@@ -4,7 +4,7 @@ import path from "node:path"
4
  import { validate as uuidValidate } from "uuid"
5
  import express from "express"
6
 
7
- import { Video, VideoStatus, VideoAPIRequest, RenderRequest, RenderAPIResponse } from "./types.mts"
8
  import { parseVideoRequest } from "./utils/parseVideoRequest.mts"
9
  import { savePendingVideo } from "./scheduler/savePendingVideo.mts"
10
  import { getVideo } from "./scheduler/getVideo.mts"
@@ -48,8 +48,8 @@ app.post("/render", async (req, res) => {
48
  return
49
  }
50
 
51
- let result: RenderAPIResponse = {
52
- videoUrl: "",
53
  maskBase64: "",
54
  error: "",
55
  segments: []
@@ -83,15 +83,51 @@ app.post("/render", async (req, res) => {
83
  }
84
  })
85
 
 
86
  /*
87
  app.post("/segment", async (req, res) => {
88
- const payload = req.body as
 
 
 
 
 
 
 
 
 
 
89
  try {
90
- await segmentImage()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  }
92
  })
93
  */
94
 
 
 
95
  app.post("/:ownerId", async (req, res) => {
96
  const request = req.body as VideoAPIRequest
97
 
 
4
  import { validate as uuidValidate } from "uuid"
5
  import express from "express"
6
 
7
+ import { Video, VideoStatus, VideoAPIRequest, RenderRequest, RenderedScene } from "./types.mts"
8
  import { parseVideoRequest } from "./utils/parseVideoRequest.mts"
9
  import { savePendingVideo } from "./scheduler/savePendingVideo.mts"
10
  import { getVideo } from "./scheduler/getVideo.mts"
 
48
  return
49
  }
50
 
51
+ let result: RenderedScene = {
52
+ assetUrl: "",
53
  maskBase64: "",
54
  error: "",
55
  segments: []
 
83
  }
84
  })
85
 
86
+ // a "fast track" pipeline
87
  /*
88
  app.post("/segment", async (req, res) => {
89
+
90
+ const request = req.body as RenderRequest
91
+ console.log(req.body)
92
+
93
+ let result: RenderedScene = {
94
+ assetUrl: "",
95
+ maskBase64: "",
96
+ error: "",
97
+ segments: []
98
+ }
99
+
100
  try {
101
+ result = await renderScene(request)
102
+ } catch (err) {
103
+ // console.log("failed to render scene!")
104
+ result.error = `failed to render scene: ${err}`
105
+ }
106
+
107
+ if (result.error === "already rendering") {
108
+ console.log("server busy")
109
+ res.status(200)
110
+ res.write(JSON.stringify({ url: "", error: result.error }))
111
+ res.end()
112
+ return
113
+ } else if (result.error.length > 0) {
114
+ // console.log("server error")
115
+ res.status(500)
116
+ res.write(JSON.stringify({ url: "", error: result.error }))
117
+ res.end()
118
+ return
119
+ } else {
120
+ // console.log("all good")
121
+ res.status(200)
122
+ res.write(JSON.stringify(result))
123
+ res.end()
124
+ return
125
  }
126
  })
127
  */
128
 
129
+
130
+
131
  app.post("/:ownerId", async (req, res) => {
132
  const request = req.body as VideoAPIRequest
133
 
src/production/renderScene.mts CHANGED
@@ -1,104 +1,13 @@
1
- import { v4 as uuidv4 } from "uuid"
2
-
3
- import { ImageSegment, RenderAPIResponse, RenderRequest } from "../types.mts"
4
- import { downloadFileToTmp } from "../utils/downloadFileToTmp.mts"
5
- import { generateSeed } from "../utils/generateSeed.mts"
6
- import { getValidNumber } from "../utils/getValidNumber.mts"
7
- import { generateVideo } from "./generateVideo.mts"
8
- import { getFirstVideoFrame } from "../utils/getFirstVideoFrame.mts"
9
- import { segmentImage } from "../utils/segmentImage.mts"
10
-
11
- const state = {
12
- isRendering: false
13
- }
14
-
15
- const seed = generateSeed()
16
-
17
- export async function renderScene(scene: RenderRequest): Promise<RenderAPIResponse> {
18
- // console.log("renderScene")
19
-
20
- // let's disable this for now
21
- // this is only reliable if nothing crashes anyway..
22
- /*
23
- if (state.isRendering) {
24
- // console.log("renderScene: isRendering")
25
- return {
26
- videoUrl: "",
27
- error: "already rendering",
28
- maskBase64: "",
29
- segments: [],
30
- }
31
- }
32
- */
33
-
34
- // onsole.log("marking as isRendering")
35
- state.isRendering = true
36
-
37
- let url = ""
38
- let error = ""
39
-
40
- try {
41
- url = await generateVideo(scene.prompt, {
42
- seed: getValidNumber(scene.seed, 0, 4294967295, generateSeed()),
43
- nbFrames: getValidNumber(scene.nbFrames, 8, 24, 16), // 2 seconds by default
44
- nbSteps: getValidNumber(scene.nbSteps, 1, 50, 10), // use 10 by default to go fast, but not too sloppy
45
- })
46
- // console.log("successfull generation")
47
- error = ""
48
- } catch (err) {
49
- error = `failed to render scene: ${err}`
50
  }
51
-
52
-
53
-
54
- // TODO add segmentation here
55
- const actionnables = Array.isArray(scene.actionnables) ? scene.actionnables : []
56
-
57
- let mask = ""
58
- let segments: ImageSegment[] = []
59
-
60
- if (actionnables.length > 0) {
61
- console.log("we have some actionnables:", actionnables)
62
- if (scene.segmentation === "firstframe") {
63
- console.log("going to grab the first frame")
64
- const tmpVideoFilePath = await downloadFileToTmp(url, `${uuidv4()}`)
65
- console.log("downloaded the first frame to ", tmpVideoFilePath)
66
- const firstFrameFilePath = await getFirstVideoFrame(tmpVideoFilePath)
67
- console.log("downloaded the first frame to ", firstFrameFilePath)
68
-
69
- if (!firstFrameFilePath) {
70
- console.error("failed to get the image")
71
- error = "failed to segment the image"
72
- } else {
73
- console.log("got the first frame! segmenting..")
74
- const result = await segmentImage(firstFrameFilePath, actionnables)
75
- mask = result.pngInBase64
76
- segments = result.segments
77
- // console.log("success!", { segments })
78
- }
79
- /*
80
- const jpgBase64 = await getFirstVideoFrame(tmpVideoFileName)
81
- if (!jpgBase64) {
82
- console.error("failed to get the image")
83
- error = "failed to segment the image"
84
- } else {
85
- console.log(`got the first frame (${jpgBase64.length})`)
86
-
87
- console.log("TODO: call segmentImage with the base64 image")
88
- await segmentImage()
89
- }
90
- */
91
- }
92
- }
93
-
94
- // console.log("marking as not rendering anymore")
95
- state.isRendering = false
96
- error = ""
97
-
98
- return {
99
- videoUrl: url,
100
- error,
101
- maskBase64: mask,
102
- segments
103
- } as RenderAPIResponse
104
  }
 
1
+ import { RenderedScene, RenderRequest } from "../types.mts"
2
+ import { renderStaticScene } from "./renderStaticScene.mts"
3
+ import { renderVideoScene } from "./renderVideoScene.mts"
4
+
5
/**
 * Dispatch a render request to the matching pipeline.
 *
 * A request asking for exactly one frame is rendered by `renderStaticScene`;
 * every other request (including a missing `nbFrames`) goes to
 * `renderVideoScene`.
 *
 * @param scene the render request to process
 * @returns whatever the selected pipeline resolves to
 */
export async function renderScene(scene: RenderRequest): Promise<RenderedScene> {
  const wantsSingleFrame = scene?.nbFrames === 1

  // guard clause: anything that is not a one-frame request is a video
  if (!wantsSingleFrame) {
    console.log(`calling renderVideoScene`)
    return renderVideoScene(scene)
  }

  console.log(`calling renderStaticScene`)
  return renderStaticScene(scene)
}
src/production/renderStaticScene.mts ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import path from "node:path"
2
+
3
+ import { v4 as uuidv4 } from "uuid"
4
+ import tmpDir from "temp-dir"
5
+
6
+ import { ImageSegment, RenderedScene, RenderRequest } from "../types.mts"
7
+ import { downloadFileToTmp } from "../utils/downloadFileToTmp.mts"
8
+ import { segmentImage } from "../utils/segmentImage.mts"
9
+ import { generateImageSDXLAsBase64 } from "../utils/generateImageSDXL.mts"
10
+ import { writeBase64ToFile } from "../utils/writeBase64ToFile.mts"
11
+
12
/**
 * Render a scene as a single image and, when the request lists actionnables,
 * segment that image.
 *
 * Flow:
 *  1. call `generateImageSDXLAsBase64` with the scene prompt (1024x512);
 *  2. if that fails or yields nothing, return a scene whose `error` is set
 *     and whose other fields are empty;
 *  3. if `scene.actionnables` is a non-empty array, write the image to a
 *     temporary PNG and pass it to `segmentImage`.
 *
 * Errors thrown while writing the temporary file or while segmenting are not
 * caught here; they propagate to the caller.
 *
 * @param scene the render request (prompt, optional seed / nbSteps / actionnables)
 * @returns the image in `assetUrl`, plus the mask and segments when segmentation ran
 */
export async function renderStaticScene(scene: RenderRequest): Promise<RenderedScene> {

  // built once so that what gets logged is exactly what gets sent
  // NOTE(review): `||` turns a seed or nbSteps of 0 into undefined — confirm
  // that a seed of 0 is really meant to be treated as "not provided"
  const generationParams = {
    positivePrompt: scene.prompt,
    seed: scene.seed || undefined,
    nbSteps: scene.nbSteps || undefined,
    width: 1024,
    height: 512
  }

  let imageBase64 = ""

  try {
    console.log(`calling generateImageSDXLAsBase64 with: `, JSON.stringify(generationParams, null, 2))
    imageBase64 = await generateImageSDXLAsBase64(generationParams)

    // validate BEFORE touching the value: slicing an undefined result would
    // otherwise replace this message with an unrelated TypeError
    if (!imageBase64?.length) {
      throw new Error(`the generated image is empty`)
    }
    console.log("successful generation!", imageBase64.slice(0, 30))
  } catch (err) {
    return {
      // always a string, even if the generator resolved to undefined
      assetUrl: "",
      error: `failed to render scene: ${err}`,
      maskBase64: "",
      segments: []
    } as RenderedScene
  }

  const actionnables = Array.isArray(scene.actionnables) ? scene.actionnables : []

  let mask = ""
  let segments: ImageSegment[] = []

  if (actionnables.length > 0) {
    console.log("we have some actionnables:", actionnables)
    console.log("going to grab the first frame")

    // NOTE(review): this temporary file is never deleted — consider cleaning
    // it up once segmentImage no longer needs it
    const tmpImageFilePath = path.join(tmpDir, `${uuidv4()}.png`)

    console.log("beginning:", imageBase64.slice(0, 100))
    await writeBase64ToFile(imageBase64, tmpImageFilePath)
    console.log("wrote the image to ", tmpImageFilePath)

    console.log("got the first frame! segmenting..")
    const result = await segmentImage(tmpImageFilePath, actionnables)
    mask = result.pngInBase64
    segments = result.segments
    console.log("success!", { segments })
  } else {
    console.log("no actionnables: just returning the image, then")
  }

  return {
    assetUrl: imageBase64,
    error: "",
    maskBase64: mask,
    segments
  } as RenderedScene
}
src/production/renderVideoScene.mts ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { v4 as uuidv4 } from "uuid"
2
+
3
+ import { ImageSegment, RenderedScene, RenderRequest } from "../types.mts"
4
+ import { downloadFileToTmp } from "../utils/downloadFileToTmp.mts"
5
+ import { generateSeed } from "../utils/generateSeed.mts"
6
+ import { getValidNumber } from "../utils/getValidNumber.mts"
7
+ import { generateVideo } from "./generateVideo.mts"
8
+ import { getFirstVideoFrame } from "../utils/getFirstVideoFrame.mts"
9
+ import { segmentImage } from "../utils/segmentImage.mts"
10
+
11
/**
 * Render a scene as a video and, when requested, segment its first frame.
 *
 * Flow:
 *  1. call `generateVideo` with the scene prompt and validated seed / frame
 *     count / step count;
 *  2. if that fails or yields an empty URL, return immediately with `error`
 *     set (same contract as `renderStaticScene`) instead of carrying on with
 *     an empty URL;
 *  3. if `scene.actionnables` is a non-empty array and
 *     `scene.segmentation === "firstframe"`, download the video, extract its
 *     first frame and pass it to `segmentImage`.
 *
 * Errors thrown by the download, frame extraction or segmentation helpers are
 * not caught here; they propagate to the caller.
 *
 * @param scene the render request
 * @returns the video URL in `assetUrl`, plus the mask and segments when segmentation ran
 */
export async function renderVideoScene(scene: RenderRequest): Promise<RenderedScene> {

  let url = ""

  try {
    url = await generateVideo(scene.prompt, {
      seed: getValidNumber(scene.seed, 0, 2147483647, generateSeed()),
      nbFrames: getValidNumber(scene.nbFrames, 8, 24, 16), // 2 seconds by default
      nbSteps: getValidNumber(scene.nbSteps, 1, 50, 10), // use 10 by default to go fast, but not too sloppy
    })
    if (!url?.length) {
      throw new Error(`url for the generated image is empty`)
    }
  } catch (err) {
    // report the failure: previously the error was cleared before returning,
    // so callers received an empty asset with no error at all
    return {
      assetUrl: "",
      error: `failed to render scene: ${err}`,
      maskBase64: "",
      segments: []
    } as RenderedScene
  }

  const actionnables = Array.isArray(scene.actionnables) ? scene.actionnables : []

  let mask = ""
  let segments: ImageSegment[] = []

  if (actionnables.length > 0) {
    console.log("we have some actionnables:", actionnables)
    if (scene.segmentation === "firstframe") {
      console.log("going to grab the first frame")
      const tmpVideoFilePath = await downloadFileToTmp(url, `${uuidv4()}`)
      console.log("downloaded the video to ", tmpVideoFilePath)
      const firstFrameFilePath = await getFirstVideoFrame(tmpVideoFilePath)
      console.log("downloaded the first frame to ", firstFrameFilePath)

      if (!firstFrameFilePath) {
        // NOTE(review): the previous implementation cleared this error before
        // returning, so a missing first frame is kept non-fatal here: the
        // video is returned without mask or segments. Confirm this is the
        // behaviour clients expect.
        console.error("failed to get the image")
      } else {
        console.log("got the first frame! segmenting..")
        const result = await segmentImage(firstFrameFilePath, actionnables)
        mask = result.pngInBase64
        segments = result.segments
      }
    }
  }

  return {
    assetUrl: url,
    error: "",
    maskBase64: mask,
    segments
  } as RenderedScene
}
src/types.mts CHANGED
@@ -302,12 +302,13 @@ export interface ImageSegmentationRequest {
302
  export interface ImageSegment {
303
  id: number
304
  box: number[]
 
305
  label: string
306
  score: number
307
  }
308
 
309
- export interface RenderAPIResponse {
310
- videoUrl: string
311
  error: string
312
  maskBase64: string
313
  segments: ImageSegment[]
 
302
  export interface ImageSegment {
303
  id: number
304
  box: number[]
305
+ color: number[]
306
  label: string
307
  score: number
308
  }
309
 
310
+ export interface RenderedScene {
311
+ assetUrl: string
312
  error: string
313
  maskBase64: string
314
  segments: ImageSegment[]
src/utils/generateImage.mts ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { HfInference } from "@huggingface/inference"
2
+ import { getValidNumber } from "./getValidNumber.mts";
3
+ import { generateSeed } from "./generateSeed.mts";
4
+
5
const hf = new HfInference(process.env.VC_HF_API_TOKEN)

/**
 * Generate an image with the "stabilityai/stable-diffusion-2-1" model through
 * the Hugging Face inference client.
 *
 * The positive prompt is suffixed with a few quality keywords and the
 * negative prompt with a few defect keywords before being sent.
 *
 * @param options.positivePrompt what to draw (required)
 * @param options.negativePrompt what to avoid
 * @param options.seed    accepted for signature compatibility but not
 *                        forwarded (it was never sent by the previous
 *                        implementation either).
 *                        NOTE(review): confirm whether the installed
 *                        `textToImage` accepts a seed parameter
 * @param options.width   validated with `getValidNumber` (256..1024, fallback 512)
 * @param options.height  validated with `getValidNumber` (256..1024, fallback 512)
 * @param options.nbSteps validated with `getValidNumber` (5..50, fallback 25)
 * @returns the raw bytes of the generated image
 * @throws Error("missing prompt") when the positive prompt is empty
 */
export async function generateImage(options: {
  positivePrompt: string;
  negativePrompt: string;
  seed?: number;
  width?: number;
  height?: number;
  nbSteps?: number;
}): Promise<Buffer> {

  const positivePrompt = options?.positivePrompt || ""
  if (!positivePrompt) {
    throw new Error("missing prompt")
  }
  const negativePrompt = options?.negativePrompt || ""
  const width = getValidNumber(options?.width, 256, 1024, 512)
  const height = getValidNumber(options?.height, 256, 1024, 512)
  const nbSteps = getValidNumber(options?.nbSteps, 5, 50, 25)

  const blob = await hf.textToImage({
    inputs: [
      positivePrompt,
      "beautiful",
      "award winning",
      "intricate details",
      "high resolution"
    ].filter(word => word)
    .join(", "),
    model: "stabilityai/stable-diffusion-2-1",
    parameters: {
      negative_prompt: [
        negativePrompt,
        "blurry",
        // "artificial",
        // "cropped",
        "low quality",
        "ugly"
      ].filter(word => word)
      .join(", "),
      // these were validated above but previously never sent to the API
      width,
      height,
      num_inference_steps: nbSteps
    }
  })

  return Buffer.from(await blob.arrayBuffer())
}
src/utils/generateImageSDXL.mts ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { client } from "@gradio/client"
2
+
3
+ import { generateSeed } from "../utils/generateSeed.mts"
4
+ import { getValidNumber } from "./getValidNumber.mts"
5
+
6
// we don't use replicas yet, because it ain't easy to get their hostname
const instances: string[] = [
  `${process.env.VC_SDXL_SPACE_API_URL_1 || ""}`,
].filter(instance => instance?.length > 0)

/**
 * Generate an image by calling the "/run" endpoint of an SDXL Gradio space.
 *
 * Instances are taken from the module-level `instances` list in rotation: the
 * head is removed, used, and appended back to the tail.
 *
 * @param options.positivePrompt what to draw (required)
 * @param options.negativePrompt what to avoid
 * @param options.seed    validated with `getValidNumber` (0..2147483647, fallback: `generateSeed()`)
 * @param options.width   validated with `getValidNumber` (256..1024, fallback 512)
 * @param options.height  validated with `getValidNumber` (256..1024, fallback 512)
 * @param options.nbSteps validated with `getValidNumber` (10..100, fallback 20)
 * @returns the first element of the response data when it is a string, "" otherwise
 * @throws Error("missing prompt") when the positive prompt is empty
 * @throws Error when no SDXL instance is configured
 */
export async function generateImageSDXLAsBase64(options: {
  positivePrompt: string;
  negativePrompt?: string;
  seed?: number;
  width?: number;
  height?: number;
  nbSteps?: number;
}): Promise<string> {

  const positivePrompt = options?.positivePrompt || ""
  if (!positivePrompt) {
    throw new Error("missing prompt")
  }
  const negativePrompt = options?.negativePrompt || ""
  const seed = getValidNumber(options?.seed, 0, 2147483647, generateSeed())
  const width = getValidNumber(options?.width, 256, 1024, 512)
  const height = getValidNumber(options?.height, 256, 1024, 512)
  // lower bound is 10 to match the range documented for both step sliders below
  const nbSteps = getValidNumber(options?.nbSteps, 10, 100, 20)

  // rotate through the configured instances; fail loudly when there is none,
  // instead of pushing `undefined` back into the pool and handing it to client()
  const instance = instances.shift()
  if (!instance) {
    throw new Error("no SDXL instance configured (VC_SDXL_SPACE_API_URL_1 is empty)")
  }
  instances.push(instance)

  const positive = [
    positivePrompt,
    "beautiful",
    "award winning",
    "intricate details",
    "high resolution"
  ].filter(word => word)
  .join(", ")

  const negative = [
    negativePrompt,
    "blurry",
    // "artificial",
    // "cropped",
    "low quality",
    "ugly"
  ].filter(word => word)
  .join(", ")

  const api = await client(instance, {
    hf_token: `${process.env.VC_HF_API_TOKEN}` as any
  })

  const rawResponse = (await api.predict("/run", [
    positive, // string in 'Prompt' Textbox component
    negative, // string in 'Negative prompt' Textbox component
    positive, // string in 'Prompt 2' Textbox component
    negative, // string in 'Negative prompt 2' Textbox component
    true, // boolean in 'Use negative prompt' Checkbox component
    false, // boolean in 'Use prompt 2' Checkbox component
    false, // boolean in 'Use negative prompt 2' Checkbox component
    seed, // number (numeric value between 0 and 2147483647) in 'Seed' Slider component
    width, // number (numeric value between 256 and 1024) in 'Width' Slider component
    height, // number (numeric value between 256 and 1024) in 'Height' Slider component
    7, // number (numeric value between 1 and 20) in 'Guidance scale for base' Slider component
    7, // number (numeric value between 1 and 20) in 'Guidance scale for refiner' Slider component
    nbSteps, // number (numeric value between 10 and 100) in 'Number of inference steps for base' Slider component
    nbSteps, // number (numeric value between 10 and 100) in 'Number of inference steps for refiner' Slider component
    true, // boolean in 'Apply refiner' Checkbox component
  ])) as any

  // honour the declared string return type even when the response is malformed
  const image = rawResponse?.data?.[0]
  return typeof image === "string" ? image : ""
}
src/utils/parseShotRequest.mts CHANGED
@@ -44,7 +44,7 @@ export const parseShotRequest = async (sequence: VideoSequence, maybeShotMeta: P
44
  actorDialoguePrompt: `${maybeShotMeta.actorDialoguePrompt || ""}`,
45
 
46
  // a video sequence SHOULD NOT HAVE a consistent seed, to avoid weird geometry similarities
47
- seed: getValidNumber(maybeShotMeta.seed, 0, 4294967295, generateSeed()),
48
 
49
  // a video sequence SHOULD HAVE a consistent grain
50
  noise: sequence.noise,
 
44
  actorDialoguePrompt: `${maybeShotMeta.actorDialoguePrompt || ""}`,
45
 
46
  // a video sequence SHOULD NOT HAVE a consistent seed, to avoid weird geometry similarities
47
+ seed: getValidNumber(maybeShotMeta.seed, 0, 2147483647, generateSeed()),
48
 
49
  // a video sequence SHOULD HAVE a consistent grain
50
  noise: sequence.noise,
src/utils/parseVideoRequest.mts CHANGED
@@ -57,7 +57,7 @@ export const parseVideoRequest = async (ownerId: string, request: VideoAPIReques
57
  // describe the main actor dialogue line
58
  actorDialoguePrompt: `${request.sequence.actorDialoguePrompt || ''}`,
59
 
60
- seed: getValidNumber(request.sequence.seed, 0, 4294967295, generateSeed()),
61
 
62
  noise: request.sequence.noise === true,
63
  noiseAmount: request.sequence.noise === true ? 2 : 0,
 
57
  // describe the main actor dialogue line
58
  actorDialoguePrompt: `${request.sequence.actorDialoguePrompt || ''}`,
59
 
60
+ seed: getValidNumber(request.sequence.seed, 0, 2147483647, generateSeed()),
61
 
62
  noise: request.sequence.noise === true,
63
  noiseAmount: request.sequence.noise === true ? 2 : 0,
src/utils/segmentImage.mts CHANGED
@@ -29,7 +29,7 @@ export async function segmentImage(
29
 
30
  const browser = await puppeteer.launch({
31
  headless: true,
32
- protocolTimeout: 70000,
33
  })
34
 
35
  const page = await browser.newPage()
@@ -42,8 +42,6 @@ export async function segmentImage(
42
  // console.log(`uploading file..`)
43
  await fileField.uploadFile(inputImageFilePath)
44
 
45
- await sleep(500)
46
-
47
  const firstTextarea = await page.$('textarea[data-testid="textbox"]')
48
 
49
  const conceptsToDetect = actionnables.join(" . ")
@@ -52,13 +50,13 @@ export async function segmentImage(
52
  // console.log('looking for the button to submit')
53
  const submitButton = await page.$('button.lg')
54
 
55
- await sleep(500)
56
 
57
  // console.log('clicking on the button')
58
  await submitButton.click()
59
 
60
  await page.waitForSelector('img[data-testid="detailed-image"]', {
61
- timeout: 70000, // need to be large enough in case someone else attemps to use our space
62
  })
63
 
64
  const maskUrl = await page.$$eval('img[data-testid="detailed-image"]', el => el.map(x => x.getAttribute("src"))[0])
 
29
 
30
  const browser = await puppeteer.launch({
31
  headless: true,
32
+ protocolTimeout: 120000,
33
  })
34
 
35
  const page = await browser.newPage()
 
42
  // console.log(`uploading file..`)
43
  await fileField.uploadFile(inputImageFilePath)
44
 
 
 
45
  const firstTextarea = await page.$('textarea[data-testid="textbox"]')
46
 
47
  const conceptsToDetect = actionnables.join(" . ")
 
50
  // console.log('looking for the button to submit')
51
  const submitButton = await page.$('button.lg')
52
 
53
+ await sleep(200)
54
 
55
  // console.log('clicking on the button')
56
  await submitButton.click()
57
 
58
  await page.waitForSelector('img[data-testid="detailed-image"]', {
59
+ timeout: 120000, // need to be large enough in case someone else attemps to use our space
60
  })
61
 
62
  const maskUrl = await page.$$eval('img[data-testid="detailed-image"]', el => el.map(x => x.getAttribute("src"))[0])
src/utils/segmentImageFromURL.mts ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { v4 as uuidv4 } from "uuid"
2
+
3
+ import { downloadFileToTmp } from "./downloadFileToTmp.mts"
4
+ import { segmentImage } from "./segmentImage.mts"
5
+
6
+ // TODO we should use an inference endpoint instead
7
+
8
+ // note: on a large T4 (8 vCPU)
9
+ // it takes about 30 seconds to compute
10
/**
 * Download an image to a temporary file, then run `segmentImage` on it.
 *
 * @param inputUrl     location of the image to download
 * @param actionnables the concepts to look for; must contain at least one entry
 * @returns whatever `segmentImage` resolves to for the downloaded file
 * @throws Error when `actionnables` is missing or empty
 */
export async function segmentImageFromURL(
  inputUrl: string,
  actionnables: string[]
) {
  const hasActionnables = Boolean(actionnables?.length)
  if (!hasActionnables) {
    throw new Error("cannot segment image without actionnables!")
  }

  console.log(`segmenting image from URL: "${inputUrl}"`)

  // a random name is used for the temporary copy of the image
  const localImagePath = await downloadFileToTmp(inputUrl, `${uuidv4()}`)
  const segmentation = await segmentImage(localImagePath, actionnables)

  console.log("image has been segmented!", segmentation)
  return segmentation
}
src/utils/writeBase64ToFile.mts ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { promises as fs } from "node:fs"
2
+
3
/**
 * Decode a base64 payload and write the resulting bytes to `filePath`.
 *
 * `content` may be a data URL (e.g. "data:image/png;base64,....") or a bare
 * base64 string: when a comma is present, everything up to and including the
 * first comma is treated as the data-URL header and dropped.
 *
 * @param content  base64 text, with or without a data-URL header
 * @param filePath destination path
 * @throws Error when `content` carries no payload
 * @throws the underlying file system error when the write fails (it is logged
 *         first, then rethrown so the caller does not go on to use a file
 *         that was never written)
 */
export async function writeBase64ToFile(content: string, filePath: string): Promise<void> {

  // strip the "data:image/png;base64," header when there is one
  const text = typeof content === "string" ? content : ""
  const separatorIndex = text.indexOf(",")
  const base64Data = separatorIndex >= 0 ? text.slice(separatorIndex + 1) : text

  if (!base64Data) {
    throw new Error("cannot write file: the base64 content is empty")
  }

  // convert base64 to binary
  const data = Buffer.from(base64Data, "base64")

  try {
    await fs.writeFile(filePath, data)
    console.log("File written successfully")
  } catch (error) {
    console.error("An error occurred:", error)
    throw error
  }
}