File size: 3,456 Bytes
9c9e5d3
 
df46f1b
9c9e5d3
312de82
 
d4febae
b93a813
312de82
 
b93a813
9c9e5d3
e781b23
 
 
 
 
9c9e5d3
 
 
 
e781b23
312de82
 
 
 
b93a813
 
d4febae
9c9e5d3
df46f1b
9c9e5d3
e781b23
 
 
9c9e5d3
 
d4febae
 
e781b23
 
9c9e5d3
df46f1b
 
 
 
 
9c9e5d3
 
 
 
 
 
 
d4febae
 
 
 
df46f1b
 
 
 
9c9e5d3
 
 
 
 
 
 
e781b23
 
df46f1b
 
 
 
 
 
e781b23
 
 
 
 
 
9c9e5d3
 
 
 
 
e781b23
9c9e5d3
 
 
 
e781b23
 
 
 
 
9c9e5d3
e781b23
 
9c9e5d3
 
 
 
e781b23
9c9e5d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import { HfInference } from '@huggingface/inference'
import { RepoFile } from './types.mts'
import { createLlamaCoderPrompt } from './createLlamaCoderPrompt.mts'
import { parseTutorial } from './parseTutorial.mts'
import { getGradioApp } from './getGradioApp.mts'
import { getStreamlitApp } from './getStreamlitApp.mts'
import { getWebApp } from './getWebApp.mts'
import { getReactApp } from './getReactApp.mts'
import { isStreamlitAppPrompt } from './isStreamlitAppPrompt.mts'
import { isPythonOrGradioAppPrompt } from './isPythonOrGradioAppPrompt.mts'
import { isReactAppPrompt } from './isReactAppPrompt.mts'

/** Minimum number of characters accepted for a prompt. */
const MIN_PROMPT_LENGTH = 2

/**
 * Markers that indicate the model has finished its answer (or started a new
 * turn). Generation stops at the first one found and everything from that
 * marker onward is discarded, so none of them leak into the parsed files.
 */
const STOP_MARKERS = [
  '<|end|>',
  '<|eot_id|>',
  '<|start_header_id|>assistant<|end_header_id|>',
  '</s>',
  '[ENDINSTRUCTION]',
  '[/TASK]',
  '<|assistant|>',
] as const

const LONGEST_STOP_MARKER = Math.max(...STOP_MARKERS.map((marker) => marker.length))

/**
 * Returns the index of the earliest stop marker in `text` that starts at or
 * after `fromIndex`, or -1 when none is present.
 */
const findStopMarker = (text: string, fromIndex = 0): number => {
  let earliest = -1
  for (const marker of STOP_MARKERS) {
    const index = text.indexOf(marker, fromIndex)
    if (index !== -1 && (earliest === -1 || index < earliest)) {
      earliest = index
    }
  }
  return earliest
}

/**
 * Generates the source files of a small app from a natural-language prompt by
 * streaming a completion from Meta-Llama-3-70B-Instruct through the Hugging
 * Face inference API.
 *
 * The app template (Streamlit, Python/Gradio, React, or plain web app) is
 * chosen from the prompt. Its `prefix` seeds the model's answer and its
 * static `files` are appended after the generated ones.
 *
 * @param prompt - user description of the app; must be at least 2 characters
 * @param token - Hugging Face access token passed to `HfInference`
 * @param onProgress - called first with the template prefix (its return value
 *   is ignored there), then with every streamed token; returning a falsy value
 *   for a token aborts the generation
 * @returns the files parsed from the generated text followed by the template's
 *   files, or an empty array if generation was aborted or the inference failed
 * @throws Error if the prompt is shorter than 2 characters
 */
export const generateFiles = async (
  prompt: string,
  token: string,
  onProgress: (chunk: string) => boolean
  ) => {
  if (`${prompt}`.length < MIN_PROMPT_LENGTH) {
    // report the required length (the original message interpolated the prompt itself)
    throw new Error(`prompt too short, please enter at least ${MIN_PROMPT_LENGTH} characters`)
  }

  const { prefix, files, instructions } =
  isStreamlitAppPrompt(prompt)
    ? getStreamlitApp(prompt)
    : isPythonOrGradioAppPrompt(prompt)
    ? getGradioApp(prompt)
    : isReactAppPrompt(prompt)
    ? getReactApp(prompt)
    : getWebApp(prompt)

  const inputs = createLlamaCoderPrompt(instructions) + "\nSure! Here are the source files:\n" + prefix

  let isAbortedOrFailed = false

  let tutorial = prefix

  try {
    const hf = new HfInference(token)

    onProgress(prefix)

    for await (const output of hf.textGenerationStream({

      model: "meta-llama/Meta-Llama-3-70B-Instruct",
      // model: "codellama/CodeLlama-34b-Instruct-hf",
      // model: "ise-uiuc/Magicoder-CL-7B" // not supported by Hugging Face right now (no stream + max token is 250)

      inputs,
      parameters: {
        do_sample: true,

        // for  "codellama/CodeLlama-34b-Instruct-hf":
        //  `inputs` tokens + `max_new_tokens` must be <= 8192
        //  error: `inputs` must have less than 4096 tokens.

        // for  "tiiuae/falcon-180B-chat":
        //  `inputs` tokens + `max_new_tokens` must be <= 8192
        //  error: `inputs` must have less than 4096 tokens.

        // for Llama-3 it is 8192
        max_new_tokens: 8192,
        temperature: 0.8,
        return_full_text: false,
      }
    })) {
      const chunk = output.token.text

      // Any marker completed by this chunk must start within the last
      // (LONGEST_STOP_MARKER - 1) characters of the previous text, so only that
      // tail needs scanning. This keeps the loop linear in the output length
      // and means the template prefix alone can never trigger a stop.
      const scanFrom = Math.max(0, tutorial.length - LONGEST_STOP_MARKER + 1)

      tutorial += chunk
      process.stdout.write(chunk)
      // res.write(output.token.text)

      const stopAt = findStopMarker(tutorial, scanFrom)
      if (stopAt !== -1) {
        // drop the marker and anything after it so it cannot end up in a file
        tutorial = tutorial.slice(0, stopAt)
        break
      }
      if (!onProgress(chunk)) {
        console.log("aborting the LLM generation")
        isAbortedOrFailed = true
        break
      }
    }

  } catch (e) {
    isAbortedOrFailed = true
    console.log("failed:")
    console.log(e)
  }

  if (isAbortedOrFailed) {
    console.log("the request was aborted, so we return an empty list")
    return []
  }

  console.log("analyzing the generated instructions..")
  const generatedFiles = parseTutorial(tutorial).map(({ filename, content }) => ({
    // file names must not contain spaces
    path: `${filename || ""}`.trim().replaceAll(" ", ""),
    content: `${content || ""}`
  } as RepoFile))
  // skip entries with an empty path or empty content
  .filter(res => res.path.length && res.content.length)

  return [...generatedFiles, ...files]
}