File size: 2,776 Bytes
1a8c098
 
e8b5344
 
 
 
 
2cadf2a
e8b5344
 
2cadf2a
e8b5344
1a8c098
 
 
 
 
 
 
 
 
 
 
 
 
 
9662103
1a8c098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52803ee
 
1a8c098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import { browser } from "$app/environment";
import { goto } from "$app/navigation";
import { defaultGenerationConfig } from "$lib/components/InferencePlayground/generationConfigSettings";
import {
	defaultSystemMessage,
	FEATURED_MODELS_IDS,
} from "$lib/components/InferencePlayground/inferencePlaygroundUtils";
import type { Conversation, ConversationMessage, Session } from "$lib/components/InferencePlayground/types";

import { models } from "$lib/stores/models";
import { get, writable } from "svelte/store";

/**
 * Creates the inference-playground session store.
 *
 * The initial value is built inside the writable's start notifier from the current URL.
 * - Each `modelId` search param that matches a known model becomes a conversation.
 * - Each conversation is paired positionally with the `provider` param at the same index.
 * - At most two conversations are created, matching the tuple type used in the cast below.
 *
 * If no `modelId` resolves, a single conversation is used. Its model is the first featured
 * model, else the first known model, else an empty placeholder.
 *
 * The returned `set`/`update` wrap the underlying writable. In the browser they also rewrite
 * the `modelId`/`provider` search params via `goto(..., { replaceState: true })`.
 */
function createSessionStore() {
	const store = writable<Session>(undefined, (set, update) => {
		// `set`/`update` here are the raw writable's functions, so building the initial
		// state from the URL does not itself trigger a URL rewrite.
		const searchParams = new URLSearchParams(browser ? window.location.search : undefined);
		const knownModels = get(models);

		const modelIdsFromSearchParam = searchParams.getAll("modelId");
		const providersFromSearchParam = searchParams.getAll("provider");

		const startMessageUser: ConversationMessage = { role: "user", content: "" };
		const firstModelId = modelIdsFromSearchParam[0];
		const systemMessage: ConversationMessage = {
			role: "system",
			content: firstModelId ? (defaultSystemMessage?.[firstModelId] ?? "") : "",
		};

		set({
			conversations: [
				{
					model: knownModels.find(m => FEATURED_MODELS_IDS.includes(m.id)) ??
						knownModels[0] ?? {
							id: "",
							downloads: 0,
							gated: false,
							likes: 0,
							name: "",
							private: false,
							tokenizerConfig: {},
							updatedAt: new Date(),
						},
					config: { ...defaultGenerationConfig },
					messages: [{ ...startMessageUser }],
					systemMessage: { ...systemMessage },
					streaming: true,
				},
			],
		});

		// Resolve every `modelId` to a known model, keeping the provider at the same index.
		// Unknown ids are dropped: a conversation with `model: undefined` would later throw
		// on `c.model.id`. Each conversation gets its own copies of its mutable message
		// objects, so editing one conversation cannot leak into the other.
		const conversationsFromSearchParam = modelIdsFromSearchParam
			.flatMap((id, i) => {
				const model = knownModels.find(m => m.id === id);
				if (!model) return [];
				return [
					{
						model,
						config: { ...defaultGenerationConfig },
						messages: [{ ...startMessageUser }],
						systemMessage: { ...systemMessage },
						streaming: true,
						provider: providersFromSearchParam[i],
					},
				];
			})
			// The session holds one or two conversations; ignore any extra params.
			.slice(0, 2);

		if (conversationsFromSearchParam.length) {
			const conversations = conversationsFromSearchParam as [Conversation] | [Conversation, Conversation];
			update(s => ({ ...s, conversations }));
		}
	});

	/**
	 * Replaces the URL's `modelId`/`provider` search params with the session's values.
	 * All model ids are appended first, then all providers (defaulting to "hf-inference").
	 * This is a no-op outside the browser, where `window` and client navigation are unavailable.
	 */
	const syncSearchParams = (s: Session) => {
		if (!browser) return;

		const query = new URLSearchParams(window.location.search);
		query.delete("modelId");
		query.delete("provider");

		for (const c of s.conversations) query.append("modelId", c.model.id);
		for (const c of s.conversations) query.append("provider", c.provider ?? "hf-inference");

		void goto(`?${query}`, { replaceState: true });
	};

	/** Updates the store and mirrors the resulting session into the URL. */
	const update: typeof store.update = cb => {
		store.update($s => {
			const s = cb($s);
			syncSearchParams(s);
			return s;
		});
	};

	/** Sets the store (routed through `update` so the URL is kept in sync). */
	const set: typeof store.set = (...args) => {
		update(_ => args[0]);
	};

	return { ...store, set, update };
}

/**
 * Shared playground session store, initialised from the URL's `modelId`/`provider` search
 * params. Its `set`/`update` also write those params back to the URL.
 */
export const session = createSessionStore();