import { browser } from "$app/environment";
import { goto } from "$app/navigation";
import { defaultGenerationConfig } from "$lib/components/InferencePlayground/generationConfigSettings";
import { defaultSystemMessage } from "$lib/components/InferencePlayground/inferencePlaygroundUtils";
import { PipelineTag, type Conversation, type ConversationMessage, type Session } from "$lib/types";

import { models } from "$lib/stores/models";
import { get, writable } from "svelte/store";
import { getTrending } from "$lib/utils/model";

/**
 * Session store for the inference playground.
 *
 * State is seeded from the page's `?modelId=…&provider=…` search params when
 * running in the browser, and the wrapped `update`/`set` below keep those
 * params in sync, so a session can be restored or shared via its URL.
 */
function createSessionStore() {
	const store = writable<Session>(undefined, (set, update) => {
		// Restore any previous selection from the URL; on the server there is
		// no `window`, so fall back to an empty query string.
		const searchParams = new URLSearchParams(browser ? window.location.search : undefined);

		const modelIdsFromSearchParam = searchParams.getAll("modelId");
		const modelsFromSearchParam = modelIdsFromSearchParam?.map(id => get(models).find(model => model.id === id));

		const providersFromSearchParam = searchParams.getAll("provider");

		const startMessageUser: ConversationMessage = { role: "user", content: "" };
		const systemMessage: ConversationMessage = {
			role: "system",
			content: modelIdsFromSearchParam?.[0] ? (defaultSystemMessage?.[modelIdsFromSearchParam[0]] ?? "") : "",
		};

		const $models = get(models);
		const featured = getTrending($models);

		// Seed a single conversation, preferring a trending model and falling
		// back to an empty placeholder so the UI always has a valid shape.
		set({
			conversations: [
				{
					model: featured[0] ??
						$models[0] ?? {
							_id: "",
							inferenceProviderMapping: [],
							pipeline_tag: PipelineTag.TextGeneration,
							trendingScore: 0,
							tags: ["text-generation"],
							id: "",
							tokenizerConfig: {},
							config: {
								architectures: [] as string[],
								model_type: "",
								tokenizer_config: {},
							},
						},
					config: { ...defaultGenerationConfig },
					messages: [{ ...startMessageUser }],
					systemMessage,
					streaming: true,
				},
			],
		});

		// If the URL named specific models, replace the default conversation
		// with one conversation per requested model, paired with its provider.
		if (modelsFromSearchParam?.length) {
			const conversations = modelsFromSearchParam.map((model, i) => {
				return {
					model,
					config: { ...defaultGenerationConfig },
					messages: [{ ...startMessageUser }],
					systemMessage,
					streaming: true,
					provider: providersFromSearchParam?.[i],
				};
			}) as [Conversation] | [Conversation, Conversation];
			update(s => ({ ...s, conversations }));
		}
	});

	// Wrap `update` so every state change also mirrors the selected models and
	// providers back into the query string.
	const update: typeof store.update = cb => {
		const prevQuery = window.location.search;
		const query = new URLSearchParams(window.location.search);
		query.delete("modelId");
		query.delete("provider");

		store.update($s => {
			const s = cb($s);

			const modelIds = s.conversations.map(c => c.model.id);
			modelIds.forEach(m => query.append("modelId", m));

			const providers = s.conversations.map(c => c.provider ?? "hf-inference");
			providers.forEach(p => query.append("provider", p));

			const newQuery = query.toString();
			// `prevQuery` starts with "?", so drop it before comparing; only
			// navigate when the query actually changed.
			if (newQuery !== prevQuery.slice(1)) {
				goto(`?${query}`, { replaceState: true });
			}

			return s;
		});
	};

	// Route `set` through the wrapped `update` so direct sets also sync the URL.
	const set: typeof store.set = (...args) => {
		update(_ => args[0]);
	};

	return { ...store, set, update };
}

export const session = createSessionStore();
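
// Usage sketch (illustrative only — the consuming component and the import
// path "$lib/stores/session" are assumptions, not confirmed by this file):
//
//   import { session } from "$lib/stores/session";
//
//   // `$session` subscribes like any Svelte writable store; any update also
//   // re-syncs `?modelId=…&provider=…` in the URL via `replaceState`.
//   session.update(s => ({
//   	...s,
//   	conversations: s.conversations.map(c => ({ ...c, streaming: false })) as typeof s.conversations,
//   }));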