import type { ChatCompletionChunk as WebLlmChatCompletionChunk } from "@mlc-ai/web-llm";
import type { ChatMessage } from "gpt-tokenizer/GptEncoding";
import type { ChatCompletionChunk } from "openai/resources/chat/completions.mjs";
import type { Stream } from "openai/streaming.mjs";
import {
  getQuery,
  getSearchPromise,
  getSearchResults,
  getSettings,
  getTextGenerationState,
  updateTextGenerationState,
} from "./pubSub";
import { defaultSettings } from "./settings";
import { getSystemPrompt } from "./systemPrompt";

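/** Error used to signal that chat generation stopped abnormally (e.g. it was interrupted). */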
export class ChatGenerationError extends Error {
  constructor(message: string) {
    super(message);
    this.name = "ChatGenerationError";
  }
}

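/**
 * Formats the top text search results (up to the configured
 * `searchResultsToConsider` limit) as a Markdown list, optionally including
 * each result's URL. Returns "None." when there are no results.
 */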
export function getFormattedSearchResults(shouldIncludeUrl: boolean) {
  const searchResults = getSearchResults().textResults.slice(
    0,
    getSettings().searchResultsToConsider,
  );

  if (searchResults.length === 0) return "None.";

  if (shouldIncludeUrl) {
    return searchResults
      .map(
        ([title, snippet, url], index) =>
          `${index + 1}. [${title}](${url}) | ${snippet}`,
      )
      .join("\n");
  }

  return searchResults
    .map(([title, snippet]) => `- ${title} | ${snippet}`)
    .join("\n");
}

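/**
 * Waits for pending search results before responding, whenever the settings
 * require any search results to be considered, updating the generation state
 * in the meantime.
 */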
export async function canStartResponding() {
  if (getSettings().searchResultsToConsider > 0) {
    updateTextGenerationState("awaitingSearchResults");
    await getSearchPromise();
  }
}

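/**
 * Returns the baseline streaming parameters shared by chat-completion
 * requests, seeded from the default inference settings.
 */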
export function getDefaultChatCompletionCreateParamsStreaming() {
  return {
    stream: true,
    max_tokens: 2048,
    temperature: defaultSettings.inferenceTemperature,
    top_p: defaultSettings.inferenceTopP,
    frequency_penalty: defaultSettings.inferenceFrequencyPenalty,
    presence_penalty: defaultSettings.inferencePresencePenalty,
  } as const;
}

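/**
 * Accumulates streamed completion chunks into a single message, invoking
 * `onChunk` with the full text so far after each delta. Aborts the request
 * and throws a ChatGenerationError when generation is interrupted.
 *
 * @example
 * // Minimal usage sketch. Assumes `openai` is an already-configured OpenAI
 * // client and `render` is a hypothetical UI callback; neither is defined
 * // in this module.
 * const stream = await openai.chat.completions.create({
 *   ...getDefaultChatCompletionCreateParamsStreaming(),
 *   model: "gpt-4o-mini",
 *   messages: [{ role: "user", content: getQuery() }],
 * });
 * const reply = await handleStreamingResponse(stream, (text) => render(text));
 */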
export async function handleStreamingResponse(
  completion:
    | Stream<ChatCompletionChunk>
    | AsyncIterable<WebLlmChatCompletionChunk>,
  onChunk: (streamedMessage: string) => void,
  options?: {
    abortController?: { abort: () => void };
    shouldUpdateGeneratingState?: boolean;
  },
) {
  let streamedMessage = "";

  for await (const chunk of completion) {
    // Optional chaining guards against chunks whose `choices` array is
    // empty (some providers emit such chunks, e.g. a final usage-only one).
    const deltaContent = chunk.choices[0]?.delta.content;

    if (deltaContent) {
      streamedMessage += deltaContent;
      onChunk(streamedMessage);
    }

    if (getTextGenerationState() === "interrupted") {
      if (options?.abortController) {
        options.abortController.abort();
      }
      throw new ChatGenerationError("Chat generation interrupted");
    }

    if (
      options?.shouldUpdateGeneratingState &&
      getTextGenerationState() !== "generating"
    ) {
      updateTextGenerationState("generating");
    }
  }

  return streamedMessage;
}

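/**
 * Builds the default three-message conversation: the system prompt (sent as
 * a user message), a priming "Ok!" assistant acknowledgment, and the user's
 * query.
 */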
export function getDefaultChatMessages(searchResults: string): ChatMessage[] {
  return [
    {
      role: "user",
      content: getSystemPrompt(searchResults),
    },
    { role: "assistant", content: "Ok!" },
    { role: "user", content: getQuery() },
  ];
}