Thomas G. Lopes commited on
Commit
1a8c098
·
1 Parent(s): 52803ee

add provider to search params

Browse files
src/lib/components/InferencePlayground/InferencePlayground.svelte CHANGED
@@ -1,14 +1,12 @@
1
  <script lang="ts">
2
  import type { Conversation, ConversationMessage, ModelEntryWithTokenizer } from "./types";
3
 
4
- import { page } from "$app/stores";
5
  import {
6
  handleNonStreamingResponse,
7
  handleStreamingResponse,
8
  isSystemPromptSupported,
9
  } from "./inferencePlaygroundUtils";
10
 
11
- import { goto } from "$app/navigation";
12
  import { models } from "$lib/stores/models";
13
  import { session } from "$lib/stores/session";
14
  import { token } from "$lib/stores/token";
@@ -204,15 +202,6 @@
204
  const newConversation = { ...JSON.parse(JSON.stringify($session.conversations[0])), model };
205
  $session.conversations = [...$session.conversations, newConversation];
206
  generationStats = [generationStats[0], { latency: 0, generatedTokensCount: 0 }];
207
-
208
- // update query param
209
- const url = new URL($page.url);
210
- const queryParamValue = `${$session.conversations[0].model.id},${modelId}`;
211
- url.searchParams.set("modelId", queryParamValue);
212
-
213
- const parentOrigin = "https://huggingface.co";
214
- window.parent.postMessage({ queryString: `modelId=${queryParamValue}` }, parentOrigin);
215
- goto(url.toString(), { replaceState: true });
216
  }
217
 
218
  function removeCompareModal(conversationIdx: number) {
@@ -220,19 +209,6 @@
220
  $session = $session;
221
  generationStats.splice(conversationIdx, 1)[0];
222
  generationStats = generationStats;
223
-
224
- // update query param
225
- const url = new URL($page.url);
226
- const queryParamValue = url.searchParams.get("modelId");
227
- if (queryParamValue) {
228
- const modelIds = queryParamValue.split(",") as [string, string];
229
- const newQueryParamValue = conversationIdx === 1 ? modelIds[0] : modelIds[1];
230
- url.searchParams.set("modelId", newQueryParamValue);
231
-
232
- const parentOrigin = "https://huggingface.co";
233
- window.parent.postMessage({ queryString: `modelId=${newQueryParamValue}` }, parentOrigin);
234
- goto(url.toString(), { replaceState: true });
235
- }
236
  }
237
 
238
  onDestroy(() => {
 
1
  <script lang="ts">
2
  import type { Conversation, ConversationMessage, ModelEntryWithTokenizer } from "./types";
3
 
 
4
  import {
5
  handleNonStreamingResponse,
6
  handleStreamingResponse,
7
  isSystemPromptSupported,
8
  } from "./inferencePlaygroundUtils";
9
 
 
10
  import { models } from "$lib/stores/models";
11
  import { session } from "$lib/stores/session";
12
  import { token } from "$lib/stores/token";
 
202
  const newConversation = { ...JSON.parse(JSON.stringify($session.conversations[0])), model };
203
  $session.conversations = [...$session.conversations, newConversation];
204
  generationStats = [generationStats[0], { latency: 0, generatedTokensCount: 0 }];
 
 
 
 
 
 
 
 
 
205
  }
206
 
207
  function removeCompareModal(conversationIdx: number) {
 
209
  $session = $session;
210
  generationStats.splice(conversationIdx, 1)[0];
211
  generationStats = generationStats;
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  }
213
 
214
  onDestroy(() => {
src/lib/components/InferencePlayground/InferencePlaygroundConversationHeader.svelte CHANGED
@@ -3,13 +3,11 @@
3
 
4
  import { createEventDispatcher } from "svelte";
5
 
6
- import { page } from "$app/stores";
 
7
  import IconCog from "../Icons/IconCog.svelte";
8
  import GenerationConfig from "./InferencePlaygroundGenerationConfig.svelte";
9
  import ModelSelectorModal from "./InferencePlaygroundModelSelectorModal.svelte";
10
- import Avatar from "../Avatar.svelte";
11
- import { goto } from "$app/navigation";
12
- import { models } from "$lib/stores/models";
13
  import InferencePlaygroundProviderSelect from "./InferencePlaygroundProviderSelect.svelte";
14
 
15
  export let conversation: Conversation;
@@ -25,21 +23,7 @@
25
  return;
26
  }
27
  conversation.model = model;
28
-
29
- const url = new URL($page.url);
30
- const queryParamValue = url.searchParams.get("modelId");
31
- if (queryParamValue) {
32
- const modelIds = queryParamValue.split(",") as [string, string];
33
- modelIds[conversationIdx] = newModelId;
34
-
35
- const newQueryParamValue = modelIds.join(",");
36
- url.searchParams.set("modelId", newQueryParamValue);
37
-
38
- const parentOrigin = "https://huggingface.co";
39
- window.parent.postMessage({ queryString: `modelId=${newQueryParamValue}` }, parentOrigin);
40
-
41
- goto(url.toString(), { replaceState: true });
42
- }
43
  }
44
 
45
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
 
3
 
4
  import { createEventDispatcher } from "svelte";
5
 
6
+ import { models } from "$lib/stores/models";
7
+ import Avatar from "../Avatar.svelte";
8
  import IconCog from "../Icons/IconCog.svelte";
9
  import GenerationConfig from "./InferencePlaygroundGenerationConfig.svelte";
10
  import ModelSelectorModal from "./InferencePlaygroundModelSelectorModal.svelte";
 
 
 
11
  import InferencePlaygroundProviderSelect from "./InferencePlaygroundProviderSelect.svelte";
12
 
13
  export let conversation: Conversation;
 
23
  return;
24
  }
25
  conversation.model = model;
26
+ conversation.provider = undefined;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  }
28
 
29
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
src/lib/components/InferencePlayground/InferencePlaygroundModelSelector.svelte CHANGED
@@ -1,9 +1,6 @@
1
  <script lang="ts">
2
  import type { Conversation, ModelEntryWithTokenizer } from "./types";
3
 
4
- import { goto } from "$app/navigation";
5
- import { page } from "$app/stores";
6
-
7
  import { models } from "$lib/stores/models";
8
  import Avatar from "../Avatar.svelte";
9
  import IconCaret from "../Icons/IconCaret.svelte";
@@ -23,14 +20,7 @@
23
  }
24
  conversation.model = model;
25
  conversation.systemMessage = { role: "system", content: defaultSystemMessage?.[modelId] ?? "" };
26
-
27
- const url = new URL($page.url);
28
- url.searchParams.set("modelId", model.id);
29
-
30
- const parentOrigin = "https://huggingface.co";
31
- window.parent.postMessage({ queryString: `modelId=${model.id}` }, parentOrigin);
32
-
33
- goto(url.toString(), { replaceState: true });
34
  }
35
 
36
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
 
1
  <script lang="ts">
2
  import type { Conversation, ModelEntryWithTokenizer } from "./types";
3
 
 
 
 
4
  import { models } from "$lib/stores/models";
5
  import Avatar from "../Avatar.svelte";
6
  import IconCaret from "../Icons/IconCaret.svelte";
 
20
  }
21
  conversation.model = model;
22
  conversation.systemMessage = { role: "system", content: defaultSystemMessage?.[modelId] ?? "" };
23
+ conversation.provider = undefined;
 
 
 
 
 
 
 
24
  }
25
 
26
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
src/lib/components/InferencePlayground/InferencePlaygroundProviderSelect.svelte CHANGED
@@ -15,13 +15,12 @@
15
  export { classes as class };
16
 
17
  async function loadProviders(modelId: string) {
18
- console.log(modelId);
19
  if (!browser) return;
20
  providerMap = {};
21
  const res = await fetchHuggingFaceModel(modelId, $token.value);
22
  providerMap = res.inferenceProviderMapping;
23
  // Commented out. I'm not sure we want to maintain, or always random pick
24
- // if ((conversation.provider ?? "") in providerMap) return;
25
  conversation.provider = randomPick(Object.keys(providerMap));
26
  }
27
 
 
15
  export { classes as class };
16
 
17
  async function loadProviders(modelId: string) {
 
18
  if (!browser) return;
19
  providerMap = {};
20
  const res = await fetchHuggingFaceModel(modelId, $token.value);
21
  providerMap = res.inferenceProviderMapping;
22
  // Commented out. I'm not sure we want to maintain, or always random pick
23
+ if ((conversation.provider ?? "") in providerMap) return;
24
  conversation.provider = randomPick(Object.keys(providerMap));
25
  }
26
 
src/lib/stores/session.ts CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import { defaultGenerationConfig } from "$lib/components/InferencePlayground/generationConfigSettings";
2
  import {
3
  defaultSystemMessage,
@@ -6,50 +8,85 @@ import {
6
  import type { Conversation, ConversationMessage, Session } from "$lib/components/InferencePlayground/types";
7
 
8
  import { models } from "$lib/stores/models";
9
- import { safePage } from "$lib/utils/store";
10
  import { get, writable } from "svelte/store";
11
 
12
- export const session = writable<Session>(undefined, (set, update) => {
13
- const startMessageUser: ConversationMessage = { role: "user", content: "" };
14
- const modelIdsFromQueryParam = get(safePage)?.url?.searchParams?.get("modelId")?.split(",");
15
- const modelsFromQueryParam = modelIdsFromQueryParam?.map(id => get(models).find(model => model.id === id));
16
- const systemMessage: ConversationMessage = {
17
- role: "system",
18
- content: modelIdsFromQueryParam?.[0] ? (defaultSystemMessage?.[modelIdsFromQueryParam[0]] ?? "") : "",
19
- };
 
 
 
 
 
 
20
 
21
- set({
22
- conversations: [
23
- {
24
- model: get(models).find(m => FEATURED_MODELS_IDS.includes(m.id)) ??
25
- get(models)[0] ?? {
26
- id: "",
27
- downloads: 0,
28
- gated: false,
29
- likes: 0,
30
- name: "",
31
- private: false,
32
- tokenizerConfig: {},
33
- updatedAt: new Date(),
34
- },
35
- config: { ...defaultGenerationConfig },
36
- messages: [{ ...startMessageUser }],
37
- systemMessage,
38
- streaming: true,
39
- },
40
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  });
42
 
43
- if (modelsFromQueryParam?.length) {
44
- const conversations = modelsFromQueryParam.map(model => {
45
- return {
46
- model,
47
- config: { ...defaultGenerationConfig },
48
- messages: [{ ...startMessageUser }],
49
- systemMessage,
50
- streaming: true,
51
- };
52
- }) as [Conversation] | [Conversation, Conversation];
53
- update(s => ({ ...s, conversations }));
54
- }
55
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { browser } from "$app/environment";
2
+ import { goto } from "$app/navigation";
3
  import { defaultGenerationConfig } from "$lib/components/InferencePlayground/generationConfigSettings";
4
  import {
5
  defaultSystemMessage,
 
8
  import type { Conversation, ConversationMessage, Session } from "$lib/components/InferencePlayground/types";
9
 
10
  import { models } from "$lib/stores/models";
 
11
  import { get, writable } from "svelte/store";
12
 
13
+ function createSessionStore() {
14
+ const store = writable<Session>(undefined, (set, update) => {
15
+ const searchParams = new URLSearchParams(browser ? window.location.search : undefined);
16
+
17
+ const modelIdsFromSearchParam = searchParams.getAll("modelId");
18
+ const modelsFromSearchParam = modelIdsFromSearchParam?.map(id => get(models).find(model => model.id === id));
19
+
20
+ const providersFromSearchParam = searchParams.getAll("provider");
21
+
22
+ const startMessageUser: ConversationMessage = { role: "user", content: "" };
23
+ const systemMessage: ConversationMessage = {
24
+ role: "system",
25
+ content: modelIdsFromSearchParam?.[0] ? (defaultSystemMessage?.[modelIdsFromSearchParam[0]] ?? "") : "",
26
+ };
27
 
28
+ set({
29
+ conversations: [
30
+ {
31
+ model: get(models).find(m => FEATURED_MODELS_IDS.includes(m.id)) ??
32
+ get(models)[0] ?? {
33
+ id: "",
34
+ downloads: 0,
35
+ gated: false,
36
+ likes: 0,
37
+ name: "",
38
+ private: false,
39
+ tokenizerConfig: {},
40
+ updatedAt: new Date(),
41
+ },
42
+ config: { ...defaultGenerationConfig },
43
+ messages: [{ ...startMessageUser }],
44
+ systemMessage,
45
+ streaming: true,
46
+ },
47
+ ],
48
+ });
49
+
50
+ if (modelsFromSearchParam?.length) {
51
+ const conversations = modelsFromSearchParam.map((model, i) => {
52
+ return {
53
+ model,
54
+ config: { ...defaultGenerationConfig },
55
+ messages: [{ ...startMessageUser }],
56
+ systemMessage,
57
+ streaming: true,
58
+ provider: providersFromSearchParam?.[i],
59
+ };
60
+ }) as [Conversation] | [Conversation, Conversation];
61
+ update(s => ({ ...s, conversations }));
62
+ }
63
  });
64
 
65
+ const update: typeof store.update = cb => {
66
+ const query = new URLSearchParams(window.location.search);
67
+ query.delete("modelId");
68
+ query.delete("provider");
69
+
70
+ store.update($s => {
71
+ const s = cb($s);
72
+
73
+ const modelIds = s.conversations.map(c => c.model.id);
74
+ modelIds.forEach(m => query.append("modelId", m));
75
+
76
+ const providers = s.conversations.map(c => c.provider ?? "hf-inference");
77
+ providers.forEach(p => query.append("provider", p));
78
+
79
+ goto(`?${query}`, { replaceState: true });
80
+
81
+ return s;
82
+ });
83
+ };
84
+
85
+ const set: typeof store.set = (...args) => {
86
+ update(_ => args[0]);
87
+ };
88
+
89
+ return { ...store, set, update };
90
+ }
91
+
92
+ export const session = createSessionStore();