Spaces:
Running
Running
Expose sampling controls in assistants (#955) (#959)
Browse files* Expose sampling controls in assistants (#955)
* Make sure all labels have the same font size
* styling
* Add better tooltips
* better padding & wrapping
* Revert "better padding & wrapping"
This reverts commit 1b44086465040f2cb6bc906983cfc8d95820d6fe.
* ui update
* tooltip on mobile
* lint
* Update src/lib/components/AssistantSettings.svelte
Co-authored-by: Mishig <[email protected]>
---------
Co-authored-by: Victor Mustar <[email protected]>
Co-authored-by: Mishig <[email protected]>
- src/lib/components/AssistantSettings.svelte +128 -17
- src/lib/components/HoverTooltip.svelte +12 -0
- src/lib/server/endpoints/anthropic/endpointAnthropic.ts +9 -6
- src/lib/server/endpoints/aws/endpointAws.ts +2 -2
- src/lib/server/endpoints/endpoints.ts +2 -0
- src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts +9 -7
- src/lib/server/endpoints/ollama/endpointOllama.ts +9 -7
- src/lib/server/endpoints/openai/endpointOai.ts +16 -12
- src/lib/server/endpoints/tgi/endpointTgi.ts +2 -2
- src/lib/types/Assistant.ts +6 -0
- src/routes/conversation/[id]/+server.ts +10 -3
- src/routes/settings/(nav)/assistants/[assistantId]/edit/+page.server.ts +20 -0
- src/routes/settings/(nav)/assistants/new/+page.server.ts +20 -0
src/lib/components/AssistantSettings.svelte
CHANGED
@@ -9,11 +9,14 @@
|
|
9 |
import { base } from "$app/paths";
|
10 |
import CarbonPen from "~icons/carbon/pen";
|
11 |
import CarbonUpload from "~icons/carbon/upload";
|
|
|
|
|
12 |
|
13 |
import { useSettingsStore } from "$lib/stores/settings";
|
14 |
import { isHuggingChat } from "$lib/utils/isHuggingChat";
|
15 |
import IconInternet from "./icons/IconInternet.svelte";
|
16 |
import TokensCounter from "./TokensCounter.svelte";
|
|
|
17 |
|
18 |
type ActionData = {
|
19 |
error: boolean;
|
@@ -31,16 +34,22 @@
|
|
31 |
|
32 |
let files: FileList | null = null;
|
33 |
const settings = useSettingsStore();
|
34 |
-
let modelId =
|
35 |
-
assistant?.modelId ?? models.find((_model) => _model.id === $settings.activeModel)?.name;
|
36 |
let systemPrompt = assistant?.preprompt ?? "";
|
37 |
let dynamicPrompt = assistant?.dynamicPrompt ?? false;
|
|
|
38 |
|
39 |
let compress: typeof readAndCompressImage | null = null;
|
40 |
|
41 |
onMount(async () => {
|
42 |
const module = await import("browser-image-resizer");
|
43 |
compress = module.readAndCompressImage;
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
});
|
45 |
|
46 |
let inputMessage1 = assistant?.exampleInputs[0] ?? "";
|
@@ -89,11 +98,12 @@
|
|
89 |
|
90 |
const regex = /{{\s?url=(.+?)\s?}}/g;
|
91 |
$: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
|
|
|
92 |
</script>
|
93 |
|
94 |
<form
|
95 |
method="POST"
|
96 |
-
class="flex h-full flex-col overflow-y-auto p-4 md:p-8"
|
97 |
enctype="multipart/form-data"
|
98 |
use:enhance={async ({ formData }) => {
|
99 |
loading = true;
|
@@ -246,21 +256,122 @@
|
|
246 |
|
247 |
<label>
|
248 |
<div class="mb-1 font-semibold">Model</div>
|
249 |
-
<
|
250 |
-
|
251 |
-
|
252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
>
|
254 |
-
{
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
</label>
|
265 |
|
266 |
<label>
|
|
|
9 |
import { base } from "$app/paths";
|
10 |
import CarbonPen from "~icons/carbon/pen";
|
11 |
import CarbonUpload from "~icons/carbon/upload";
|
12 |
+
import CarbonHelpFilled from "~icons/carbon/help";
|
13 |
+
import CarbonSettingsAdjust from "~icons/carbon/settings-adjust";
|
14 |
|
15 |
import { useSettingsStore } from "$lib/stores/settings";
|
16 |
import { isHuggingChat } from "$lib/utils/isHuggingChat";
|
17 |
import IconInternet from "./icons/IconInternet.svelte";
|
18 |
import TokensCounter from "./TokensCounter.svelte";
|
19 |
+
import HoverTooltip from "./HoverTooltip.svelte";
|
20 |
|
21 |
type ActionData = {
|
22 |
error: boolean;
|
|
|
34 |
|
35 |
let files: FileList | null = null;
|
36 |
const settings = useSettingsStore();
|
37 |
+
let modelId = "";
|
|
|
38 |
let systemPrompt = assistant?.preprompt ?? "";
|
39 |
let dynamicPrompt = assistant?.dynamicPrompt ?? false;
|
40 |
+
let showModelSettings = Object.values(assistant?.generateSettings ?? {}).some((v) => !!v);
|
41 |
|
42 |
let compress: typeof readAndCompressImage | null = null;
|
43 |
|
44 |
onMount(async () => {
|
45 |
const module = await import("browser-image-resizer");
|
46 |
compress = module.readAndCompressImage;
|
47 |
+
|
48 |
+
if (assistant) {
|
49 |
+
modelId = assistant.modelId;
|
50 |
+
} else {
|
51 |
+
modelId = models.find((model) => model.id === $settings.activeModel)?.id ?? models[0].id;
|
52 |
+
}
|
53 |
});
|
54 |
|
55 |
let inputMessage1 = assistant?.exampleInputs[0] ?? "";
|
|
|
98 |
|
99 |
const regex = /{{\s?url=(.+?)\s?}}/g;
|
100 |
$: templateVariables = [...systemPrompt.matchAll(regex)].map((match) => match[1]);
|
101 |
+
$: selectedModel = models.find((m) => m.id === modelId);
|
102 |
</script>
|
103 |
|
104 |
<form
|
105 |
method="POST"
|
106 |
+
class="relative flex h-full flex-col overflow-y-auto p-4 md:p-8"
|
107 |
enctype="multipart/form-data"
|
108 |
use:enhance={async ({ formData }) => {
|
109 |
loading = true;
|
|
|
256 |
|
257 |
<label>
|
258 |
<div class="mb-1 font-semibold">Model</div>
|
259 |
+
<div class="flex gap-2">
|
260 |
+
<select
|
261 |
+
name="modelId"
|
262 |
+
class="w-full rounded-lg border-2 border-gray-200 bg-gray-100 p-2"
|
263 |
+
bind:value={modelId}
|
264 |
+
>
|
265 |
+
{#each models.filter((model) => !model.unlisted) as model}
|
266 |
+
<option value={model.id}>{model.displayName}</option>
|
267 |
+
{/each}
|
268 |
+
<p class="text-xs text-red-500">{getError("modelId", form)}</p>
|
269 |
+
</select>
|
270 |
+
<button
|
271 |
+
type="button"
|
272 |
+
class="flex aspect-square items-center gap-2 whitespace-nowrap rounded-lg border px-3 {showModelSettings
|
273 |
+
? 'border-blue-500/20 bg-blue-50 text-blue-600'
|
274 |
+
: ''}"
|
275 |
+
on:click={() => (showModelSettings = !showModelSettings)}
|
276 |
+
><CarbonSettingsAdjust class="text-xs" /></button
|
277 |
+
>
|
278 |
+
</div>
|
279 |
+
<div
|
280 |
+
class="mt-2 rounded-lg border border-blue-500/20 bg-blue-500/5 px-2 py-0.5"
|
281 |
+
class:hidden={!showModelSettings}
|
282 |
>
|
283 |
+
<p class="text-xs text-red-500">{getError("inputMessage1", form)}</p>
|
284 |
+
<div class="my-2 grid grid-cols-1 gap-2.5 sm:grid-cols-2 sm:grid-rows-2">
|
285 |
+
<label for="temperature" class="flex justify-between">
|
286 |
+
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
|
287 |
+
Temperature
|
288 |
+
|
289 |
+
<HoverTooltip
|
290 |
+
label="Temperature: Controls creativity, higher values allow more variety."
|
291 |
+
>
|
292 |
+
<CarbonHelpFilled
|
293 |
+
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
|
294 |
+
/>
|
295 |
+
</HoverTooltip>
|
296 |
+
</span>
|
297 |
+
<input
|
298 |
+
type="number"
|
299 |
+
name="temperature"
|
300 |
+
min="0.1"
|
301 |
+
max="2"
|
302 |
+
step="0.1"
|
303 |
+
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
|
304 |
+
placeholder={selectedModel?.parameters?.temperature?.toString() ?? "1"}
|
305 |
+
value={assistant?.generateSettings?.temperature ?? ""}
|
306 |
+
/>
|
307 |
+
</label>
|
308 |
+
<label for="top_p" class="flex justify-between">
|
309 |
+
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
|
310 |
+
Top P
|
311 |
+
<HoverTooltip
|
312 |
+
label="Top P: Sets word choice boundaries, lower values tighten focus."
|
313 |
+
>
|
314 |
+
<CarbonHelpFilled
|
315 |
+
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
|
316 |
+
/>
|
317 |
+
</HoverTooltip>
|
318 |
+
</span>
|
319 |
+
|
320 |
+
<input
|
321 |
+
type="number"
|
322 |
+
name="top_p"
|
323 |
+
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
|
324 |
+
min="0.05"
|
325 |
+
max="1"
|
326 |
+
step="0.05"
|
327 |
+
placeholder={selectedModel?.parameters?.top_p?.toString() ?? "1"}
|
328 |
+
value={assistant?.generateSettings?.top_p ?? ""}
|
329 |
+
/>
|
330 |
+
</label>
|
331 |
+
<label for="repetition_penalty" class="flex justify-between">
|
332 |
+
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
|
333 |
+
Repetition penalty
|
334 |
+
<HoverTooltip
|
335 |
+
label="Repetition penalty: Prevents reuse, higher values decrease repetition."
|
336 |
+
>
|
337 |
+
<CarbonHelpFilled
|
338 |
+
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
|
339 |
+
/>
|
340 |
+
</HoverTooltip>
|
341 |
+
</span>
|
342 |
+
<input
|
343 |
+
type="number"
|
344 |
+
name="repetition_penalty"
|
345 |
+
min="0.1"
|
346 |
+
max="2"
|
347 |
+
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
|
348 |
+
placeholder={selectedModel?.parameters?.repetition_penalty?.toString() ?? "1.0"}
|
349 |
+
value={assistant?.generateSettings?.repetition_penalty ?? ""}
|
350 |
+
/>
|
351 |
+
</label>
|
352 |
+
<label for="top_k" class="flex justify-between">
|
353 |
+
<span class="m-1 ml-0 flex items-center gap-1.5 whitespace-nowrap text-sm">
|
354 |
+
Top K <HoverTooltip
|
355 |
+
label="Top K: Restricts word options, lower values for predictability."
|
356 |
+
>
|
357 |
+
<CarbonHelpFilled
|
358 |
+
class="inline text-xxs text-gray-500 group-hover/tooltip:text-blue-600"
|
359 |
+
/>
|
360 |
+
</HoverTooltip>
|
361 |
+
</span>
|
362 |
+
<input
|
363 |
+
type="number"
|
364 |
+
name="top_k"
|
365 |
+
min="5"
|
366 |
+
max="100"
|
367 |
+
step="5"
|
368 |
+
class="w-20 rounded-lg border-2 border-gray-200 bg-gray-100 px-2 py-1"
|
369 |
+
placeholder={selectedModel?.parameters?.top_k?.toString() ?? "50"}
|
370 |
+
value={assistant?.generateSettings?.top_k ?? ""}
|
371 |
+
/>
|
372 |
+
</label>
|
373 |
+
</div>
|
374 |
+
</div>
|
375 |
</label>
|
376 |
|
377 |
<label>
|
src/lib/components/HoverTooltip.svelte
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script lang="ts">
|
2 |
+
export let label = "";
|
3 |
+
</script>
|
4 |
+
|
5 |
+
<div class="group/tooltip md:relative">
|
6 |
+
<slot />
|
7 |
+
<div
|
8 |
+
class="invisible absolute z-10 w-64 whitespace-normal rounded-md bg-black p-2 text-center text-white group-hover/tooltip:visible group-active/tooltip:visible max-sm:left-1/2 max-sm:-translate-x-1/2"
|
9 |
+
>
|
10 |
+
{label}
|
11 |
+
</div>
|
12 |
+
</div>
|
src/lib/server/endpoints/anthropic/endpointAnthropic.ts
CHANGED
@@ -32,7 +32,7 @@ export async function endpointAnthropic(
|
|
32 |
defaultQuery,
|
33 |
});
|
34 |
|
35 |
-
return async ({ messages, preprompt }) => {
|
36 |
let system = preprompt;
|
37 |
if (messages?.[0]?.from === "system") {
|
38 |
system = messages[0].content;
|
@@ -49,15 +49,18 @@ export async function endpointAnthropic(
|
|
49 |
}[];
|
50 |
|
51 |
let tokenId = 0;
|
|
|
|
|
|
|
52 |
return (async function* () {
|
53 |
const stream = anthropic.messages.stream({
|
54 |
model: model.id ?? model.name,
|
55 |
messages: messagesFormatted,
|
56 |
-
max_tokens:
|
57 |
-
temperature:
|
58 |
-
top_p:
|
59 |
-
top_k:
|
60 |
-
stop_sequences:
|
61 |
system,
|
62 |
});
|
63 |
while (true) {
|
|
|
32 |
defaultQuery,
|
33 |
});
|
34 |
|
35 |
+
return async ({ messages, preprompt, generateSettings }) => {
|
36 |
let system = preprompt;
|
37 |
if (messages?.[0]?.from === "system") {
|
38 |
system = messages[0].content;
|
|
|
49 |
}[];
|
50 |
|
51 |
let tokenId = 0;
|
52 |
+
|
53 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
54 |
+
|
55 |
return (async function* () {
|
56 |
const stream = anthropic.messages.stream({
|
57 |
model: model.id ?? model.name,
|
58 |
messages: messagesFormatted,
|
59 |
+
max_tokens: parameters?.max_new_tokens,
|
60 |
+
temperature: parameters?.temperature,
|
61 |
+
top_p: parameters?.top_p,
|
62 |
+
top_k: parameters?.top_k,
|
63 |
+
stop_sequences: parameters?.stop,
|
64 |
system,
|
65 |
});
|
66 |
while (true) {
|
src/lib/server/endpoints/aws/endpointAws.ts
CHANGED
@@ -36,7 +36,7 @@ export async function endpointAws(
|
|
36 |
region,
|
37 |
});
|
38 |
|
39 |
-
return async ({ messages, preprompt, continueMessage }) => {
|
40 |
const prompt = await buildPrompt({
|
41 |
messages,
|
42 |
continueMessage,
|
@@ -46,7 +46,7 @@ export async function endpointAws(
|
|
46 |
|
47 |
return textGenerationStream(
|
48 |
{
|
49 |
-
parameters: { ...model.parameters, return_full_text: false },
|
50 |
model: url,
|
51 |
inputs: prompt,
|
52 |
},
|
|
|
36 |
region,
|
37 |
});
|
38 |
|
39 |
+
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
|
40 |
const prompt = await buildPrompt({
|
41 |
messages,
|
42 |
continueMessage,
|
|
|
46 |
|
47 |
return textGenerationStream(
|
48 |
{
|
49 |
+
parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
|
50 |
model: url,
|
51 |
inputs: prompt,
|
52 |
},
|
src/lib/server/endpoints/endpoints.ts
CHANGED
@@ -10,12 +10,14 @@ import {
|
|
10 |
endpointAnthropic,
|
11 |
endpointAnthropicParametersSchema,
|
12 |
} from "./anthropic/endpointAnthropic";
|
|
|
13 |
|
14 |
// parameters passed when generating text
|
15 |
export interface EndpointParameters {
|
16 |
messages: Omit<Conversation["messages"][0], "id">[];
|
17 |
preprompt?: Conversation["preprompt"];
|
18 |
continueMessage?: boolean; // used to signal that the last message will be extended
|
|
|
19 |
}
|
20 |
|
21 |
interface CommonEndpoint {
|
|
|
10 |
endpointAnthropic,
|
11 |
endpointAnthropicParametersSchema,
|
12 |
} from "./anthropic/endpointAnthropic";
|
13 |
+
import type { Model } from "$lib/types/Model";
|
14 |
|
15 |
// parameters passed when generating text
|
16 |
export interface EndpointParameters {
|
17 |
messages: Omit<Conversation["messages"][0], "id">[];
|
18 |
preprompt?: Conversation["preprompt"];
|
19 |
continueMessage?: boolean; // used to signal that the last message will be extended
|
20 |
+
generateSettings?: Partial<Model["parameters"]>;
|
21 |
}
|
22 |
|
23 |
interface CommonEndpoint {
|
src/lib/server/endpoints/llamacpp/endpointLlamacpp.ts
CHANGED
@@ -19,7 +19,7 @@ export function endpointLlamacpp(
|
|
19 |
input: z.input<typeof endpointLlamacppParametersSchema>
|
20 |
): Endpoint {
|
21 |
const { url, model } = endpointLlamacppParametersSchema.parse(input);
|
22 |
-
return async ({ messages, preprompt, continueMessage }) => {
|
23 |
const prompt = await buildPrompt({
|
24 |
messages,
|
25 |
continueMessage,
|
@@ -27,6 +27,8 @@ export function endpointLlamacpp(
|
|
27 |
model,
|
28 |
});
|
29 |
|
|
|
|
|
30 |
const r = await fetch(`${url}/completion`, {
|
31 |
method: "POST",
|
32 |
headers: {
|
@@ -35,12 +37,12 @@ export function endpointLlamacpp(
|
|
35 |
body: JSON.stringify({
|
36 |
prompt,
|
37 |
stream: true,
|
38 |
-
temperature:
|
39 |
-
top_p:
|
40 |
-
top_k:
|
41 |
-
stop:
|
42 |
-
repeat_penalty:
|
43 |
-
n_predict:
|
44 |
cache_prompt: true,
|
45 |
}),
|
46 |
});
|
|
|
19 |
input: z.input<typeof endpointLlamacppParametersSchema>
|
20 |
): Endpoint {
|
21 |
const { url, model } = endpointLlamacppParametersSchema.parse(input);
|
22 |
+
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
|
23 |
const prompt = await buildPrompt({
|
24 |
messages,
|
25 |
continueMessage,
|
|
|
27 |
model,
|
28 |
});
|
29 |
|
30 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
31 |
+
|
32 |
const r = await fetch(`${url}/completion`, {
|
33 |
method: "POST",
|
34 |
headers: {
|
|
|
37 |
body: JSON.stringify({
|
38 |
prompt,
|
39 |
stream: true,
|
40 |
+
temperature: parameters.temperature,
|
41 |
+
top_p: parameters.top_p,
|
42 |
+
top_k: parameters.top_k,
|
43 |
+
stop: parameters.stop,
|
44 |
+
repeat_penalty: parameters.repetition_penalty,
|
45 |
+
n_predict: parameters.max_new_tokens,
|
46 |
cache_prompt: true,
|
47 |
}),
|
48 |
});
|
src/lib/server/endpoints/ollama/endpointOllama.ts
CHANGED
@@ -14,7 +14,7 @@ export const endpointOllamaParametersSchema = z.object({
|
|
14 |
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
|
15 |
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
|
16 |
|
17 |
-
return async ({ messages, preprompt, continueMessage }) => {
|
18 |
const prompt = await buildPrompt({
|
19 |
messages,
|
20 |
continueMessage,
|
@@ -22,6 +22,8 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
|
|
22 |
model,
|
23 |
});
|
24 |
|
|
|
|
|
25 |
const r = await fetch(`${url}/api/generate`, {
|
26 |
method: "POST",
|
27 |
headers: {
|
@@ -32,12 +34,12 @@ export function endpointOllama(input: z.input<typeof endpointOllamaParametersSch
|
|
32 |
model: ollamaName ?? model.name,
|
33 |
raw: true,
|
34 |
options: {
|
35 |
-
top_p:
|
36 |
-
top_k:
|
37 |
-
temperature:
|
38 |
-
repeat_penalty:
|
39 |
-
stop:
|
40 |
-
num_predict:
|
41 |
},
|
42 |
}),
|
43 |
});
|
|
|
14 |
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
|
15 |
const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);
|
16 |
|
17 |
+
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
|
18 |
const prompt = await buildPrompt({
|
19 |
messages,
|
20 |
continueMessage,
|
|
|
22 |
model,
|
23 |
});
|
24 |
|
25 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
26 |
+
|
27 |
const r = await fetch(`${url}/api/generate`, {
|
28 |
method: "POST",
|
29 |
headers: {
|
|
|
34 |
model: ollamaName ?? model.name,
|
35 |
raw: true,
|
36 |
options: {
|
37 |
+
top_p: parameters.top_p,
|
38 |
+
top_k: parameters.top_k,
|
39 |
+
temperature: parameters.temperature,
|
40 |
+
repeat_penalty: parameters.repetition_penalty,
|
41 |
+
stop: parameters.stop,
|
42 |
+
num_predict: parameters.max_new_tokens,
|
43 |
},
|
44 |
}),
|
45 |
});
|
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
@@ -38,7 +38,7 @@ export async function endpointOai(
|
|
38 |
});
|
39 |
|
40 |
if (completion === "completions") {
|
41 |
-
return async ({ messages, preprompt, continueMessage }) => {
|
42 |
const prompt = await buildPrompt({
|
43 |
messages,
|
44 |
continueMessage,
|
@@ -46,21 +46,23 @@ export async function endpointOai(
|
|
46 |
model,
|
47 |
});
|
48 |
|
|
|
|
|
49 |
return openAICompletionToTextGenerationStream(
|
50 |
await openai.completions.create({
|
51 |
model: model.id ?? model.name,
|
52 |
prompt,
|
53 |
stream: true,
|
54 |
-
max_tokens:
|
55 |
-
stop:
|
56 |
-
temperature:
|
57 |
-
top_p:
|
58 |
-
frequency_penalty:
|
59 |
})
|
60 |
);
|
61 |
};
|
62 |
} else if (completion === "chat_completions") {
|
63 |
-
return async ({ messages, preprompt }) => {
|
64 |
let messagesOpenAI = messages.map((message) => ({
|
65 |
role: message.from,
|
66 |
content: message.content,
|
@@ -74,16 +76,18 @@ export async function endpointOai(
|
|
74 |
messagesOpenAI[0].content = preprompt ?? "";
|
75 |
}
|
76 |
|
|
|
|
|
77 |
return openAIChatToTextGenerationStream(
|
78 |
await openai.chat.completions.create({
|
79 |
model: model.id ?? model.name,
|
80 |
messages: messagesOpenAI,
|
81 |
stream: true,
|
82 |
-
max_tokens:
|
83 |
-
stop:
|
84 |
-
temperature:
|
85 |
-
top_p:
|
86 |
-
frequency_penalty:
|
87 |
})
|
88 |
);
|
89 |
};
|
|
|
38 |
});
|
39 |
|
40 |
if (completion === "completions") {
|
41 |
+
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
|
42 |
const prompt = await buildPrompt({
|
43 |
messages,
|
44 |
continueMessage,
|
|
|
46 |
model,
|
47 |
});
|
48 |
|
49 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
50 |
+
|
51 |
return openAICompletionToTextGenerationStream(
|
52 |
await openai.completions.create({
|
53 |
model: model.id ?? model.name,
|
54 |
prompt,
|
55 |
stream: true,
|
56 |
+
max_tokens: parameters?.max_new_tokens,
|
57 |
+
stop: parameters?.stop,
|
58 |
+
temperature: parameters?.temperature,
|
59 |
+
top_p: parameters?.top_p,
|
60 |
+
frequency_penalty: parameters?.repetition_penalty,
|
61 |
})
|
62 |
);
|
63 |
};
|
64 |
} else if (completion === "chat_completions") {
|
65 |
+
return async ({ messages, preprompt, generateSettings }) => {
|
66 |
let messagesOpenAI = messages.map((message) => ({
|
67 |
role: message.from,
|
68 |
content: message.content,
|
|
|
76 |
messagesOpenAI[0].content = preprompt ?? "";
|
77 |
}
|
78 |
|
79 |
+
const parameters = { ...model.parameters, ...generateSettings };
|
80 |
+
|
81 |
return openAIChatToTextGenerationStream(
|
82 |
await openai.chat.completions.create({
|
83 |
model: model.id ?? model.name,
|
84 |
messages: messagesOpenAI,
|
85 |
stream: true,
|
86 |
+
max_tokens: parameters?.max_new_tokens,
|
87 |
+
stop: parameters?.stop,
|
88 |
+
temperature: parameters?.temperature,
|
89 |
+
top_p: parameters?.top_p,
|
90 |
+
frequency_penalty: parameters?.repetition_penalty,
|
91 |
})
|
92 |
);
|
93 |
};
|
src/lib/server/endpoints/tgi/endpointTgi.ts
CHANGED
@@ -16,7 +16,7 @@ export const endpointTgiParametersSchema = z.object({
|
|
16 |
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
|
17 |
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
|
18 |
|
19 |
-
return async ({ messages, preprompt, continueMessage }) => {
|
20 |
const prompt = await buildPrompt({
|
21 |
messages,
|
22 |
preprompt,
|
@@ -26,7 +26,7 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
|
|
26 |
|
27 |
return textGenerationStream(
|
28 |
{
|
29 |
-
parameters: { ...model.parameters, return_full_text: false },
|
30 |
model: url,
|
31 |
inputs: prompt,
|
32 |
accessToken,
|
|
|
16 |
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
|
17 |
const { url, accessToken, model, authorization } = endpointTgiParametersSchema.parse(input);
|
18 |
|
19 |
+
return async ({ messages, preprompt, continueMessage, generateSettings }) => {
|
20 |
const prompt = await buildPrompt({
|
21 |
messages,
|
22 |
preprompt,
|
|
|
26 |
|
27 |
return textGenerationStream(
|
28 |
{
|
29 |
+
parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
|
30 |
model: url,
|
31 |
inputs: prompt,
|
32 |
accessToken,
|
src/lib/types/Assistant.ts
CHANGED
@@ -19,6 +19,12 @@ export interface Assistant extends Timestamps {
|
|
19 |
allowedDomains: string[];
|
20 |
allowedLinks: string[];
|
21 |
};
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
dynamicPrompt?: boolean;
|
23 |
searchTokens: string[];
|
24 |
}
|
|
|
19 |
allowedDomains: string[];
|
20 |
allowedLinks: string[];
|
21 |
};
|
22 |
+
generateSettings?: {
|
23 |
+
temperature?: number;
|
24 |
+
top_p?: number;
|
25 |
+
repetition_penalty?: number;
|
26 |
+
top_k?: number;
|
27 |
+
};
|
28 |
dynamicPrompt?: boolean;
|
29 |
searchTokens: string[];
|
30 |
}
|
src/routes/conversation/[id]/+server.ts
CHANGED
@@ -338,8 +338,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
338 |
|
339 |
// check if assistant has a rag
|
340 |
const assistant = await collections.assistants.findOne<
|
341 |
-
Pick<Assistant, "rag" | "dynamicPrompt">
|
342 |
-
>(
|
|
|
|
|
|
|
343 |
|
344 |
const assistantHasRAG =
|
345 |
ENABLE_ASSISTANTS_RAG === "true" &&
|
@@ -403,12 +406,15 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
403 |
|
404 |
const previousText = messageToWriteTo.content;
|
405 |
|
|
|
|
|
406 |
try {
|
407 |
const endpoint = await model.getEndpoint();
|
408 |
for await (const output of await endpoint({
|
409 |
messages: processedMessages,
|
410 |
preprompt,
|
411 |
continueMessage: isContinue,
|
|
|
412 |
})) {
|
413 |
// if not generated_text is here it means the generation is not done
|
414 |
if (!output.generated_text) {
|
@@ -448,10 +454,11 @@ export async function POST({ request, locals, params, getClientAddress }) {
|
|
448 |
}
|
449 |
}
|
450 |
} catch (e) {
|
|
|
451 |
update({ type: "status", status: "error", message: (e as Error).message });
|
452 |
} finally {
|
453 |
// check if no output was generated
|
454 |
-
if (messageToWriteTo.content === previousText) {
|
455 |
update({
|
456 |
type: "status",
|
457 |
status: "error",
|
|
|
338 |
|
339 |
// check if assistant has a rag
|
340 |
const assistant = await collections.assistants.findOne<
|
341 |
+
Pick<Assistant, "rag" | "dynamicPrompt" | "generateSettings">
|
342 |
+
>(
|
343 |
+
{ _id: conv.assistantId },
|
344 |
+
{ projection: { rag: 1, dynamicPrompt: 1, generateSettings: 1 } }
|
345 |
+
);
|
346 |
|
347 |
const assistantHasRAG =
|
348 |
ENABLE_ASSISTANTS_RAG === "true" &&
|
|
|
406 |
|
407 |
const previousText = messageToWriteTo.content;
|
408 |
|
409 |
+
let hasError = false;
|
410 |
+
|
411 |
try {
|
412 |
const endpoint = await model.getEndpoint();
|
413 |
for await (const output of await endpoint({
|
414 |
messages: processedMessages,
|
415 |
preprompt,
|
416 |
continueMessage: isContinue,
|
417 |
+
generateSettings: assistant?.generateSettings,
|
418 |
})) {
|
419 |
// if not generated_text is here it means the generation is not done
|
420 |
if (!output.generated_text) {
|
|
|
454 |
}
|
455 |
}
|
456 |
} catch (e) {
|
457 |
+
hasError = true;
|
458 |
update({ type: "status", status: "error", message: (e as Error).message });
|
459 |
} finally {
|
460 |
// check if no output was generated
|
461 |
+
if (!hasError && messageToWriteTo.content === previousText) {
|
462 |
update({
|
463 |
type: "status",
|
464 |
status: "error",
|
src/routes/settings/(nav)/assistants/[assistantId]/edit/+page.server.ts
CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
|
|
25 |
ragDomainList: z.preprocess(parseStringToList, z.string().array()),
|
26 |
ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
|
27 |
dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
});
|
29 |
|
30 |
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
|
@@ -143,6 +157,12 @@ export const actions: Actions = {
|
|
143 |
},
|
144 |
dynamicPrompt: parse.data.dynamicPrompt,
|
145 |
searchTokens: generateSearchTokens(parse.data.name),
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
},
|
147 |
}
|
148 |
);
|
|
|
25 |
ragDomainList: z.preprocess(parseStringToList, z.string().array()),
|
26 |
ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
|
27 |
dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
|
28 |
+
temperature: z
|
29 |
+
.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
|
30 |
+
.transform((v) => (v === "" ? undefined : v)),
|
31 |
+
top_p: z
|
32 |
+
.union([z.literal(""), z.coerce.number().min(0.05).max(1)])
|
33 |
+
.transform((v) => (v === "" ? undefined : v)),
|
34 |
+
|
35 |
+
repetition_penalty: z
|
36 |
+
.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
|
37 |
+
.transform((v) => (v === "" ? undefined : v)),
|
38 |
+
|
39 |
+
top_k: z
|
40 |
+
.union([z.literal(""), z.coerce.number().min(5).max(100)])
|
41 |
+
.transform((v) => (v === "" ? undefined : v)),
|
42 |
});
|
43 |
|
44 |
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
|
|
|
157 |
},
|
158 |
dynamicPrompt: parse.data.dynamicPrompt,
|
159 |
searchTokens: generateSearchTokens(parse.data.name),
|
160 |
+
generateSettings: {
|
161 |
+
temperature: parse.data.temperature,
|
162 |
+
top_p: parse.data.top_p,
|
163 |
+
repetition_penalty: parse.data.repetition_penalty,
|
164 |
+
top_k: parse.data.top_k,
|
165 |
+
},
|
166 |
},
|
167 |
}
|
168 |
);
|
src/routes/settings/(nav)/assistants/new/+page.server.ts
CHANGED
@@ -25,6 +25,20 @@ const newAsssistantSchema = z.object({
|
|
25 |
ragDomainList: z.preprocess(parseStringToList, z.string().array()),
|
26 |
ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
|
27 |
dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
});
|
29 |
|
30 |
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
|
@@ -125,6 +139,12 @@ export const actions: Actions = {
|
|
125 |
},
|
126 |
dynamicPrompt: parse.data.dynamicPrompt,
|
127 |
searchTokens: generateSearchTokens(parse.data.name),
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
});
|
129 |
|
130 |
// add insertedId to user settings
|
|
|
25 |
ragDomainList: z.preprocess(parseStringToList, z.string().array()),
|
26 |
ragAllowAll: z.preprocess((v) => v === "true", z.boolean()),
|
27 |
dynamicPrompt: z.preprocess((v) => v === "on", z.boolean()),
|
28 |
+
temperature: z
|
29 |
+
.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
|
30 |
+
.transform((v) => (v === "" ? undefined : v)),
|
31 |
+
top_p: z
|
32 |
+
.union([z.literal(""), z.coerce.number().min(0.05).max(1)])
|
33 |
+
.transform((v) => (v === "" ? undefined : v)),
|
34 |
+
|
35 |
+
repetition_penalty: z
|
36 |
+
.union([z.literal(""), z.coerce.number().min(0.1).max(2)])
|
37 |
+
.transform((v) => (v === "" ? undefined : v)),
|
38 |
+
|
39 |
+
top_k: z
|
40 |
+
.union([z.literal(""), z.coerce.number().min(5).max(100)])
|
41 |
+
.transform((v) => (v === "" ? undefined : v)),
|
42 |
});
|
43 |
|
44 |
const uploadAvatar = async (avatar: File, assistantId: ObjectId): Promise<string> => {
|
|
|
139 |
},
|
140 |
dynamicPrompt: parse.data.dynamicPrompt,
|
141 |
searchTokens: generateSearchTokens(parse.data.name),
|
142 |
+
generateSettings: {
|
143 |
+
temperature: parse.data.temperature,
|
144 |
+
top_p: parse.data.top_p,
|
145 |
+
repetition_penalty: parse.data.repetition_penalty,
|
146 |
+
top_k: parse.data.top_k,
|
147 |
+
},
|
148 |
});
|
149 |
|
150 |
// add insertedId to user settings
|