Thomas G. Lopes commited on
Commit
9ab40fd
·
1 Parent(s): 2488073
src/lib/components/InferencePlayground/InferencePlaygroundModelSelector.svelte CHANGED
@@ -15,6 +15,7 @@
15
  import ModelSelectorModal from "./InferencePlaygroundModelSelectorModal.svelte";
16
  import { defaultSystemMessage } from "./inferencePlaygroundUtils";
17
  import { createSelect, createSync } from "@melt-ui/svelte";
 
18
 
19
  export let conversation: Conversation;
20
 
@@ -41,31 +42,6 @@
41
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
42
  $: modelName = conversation.model.id.split("/")[1] ?? "";
43
  const id = crypto.randomUUID();
44
-
45
- // Provider
46
- async function loadProviders(modelId: string) {
47
- if (!browser) return;
48
- providerMap = {};
49
- const res = await fetchHuggingFaceModel(modelId, $token.value);
50
- providerMap = res.inferenceProviderMapping;
51
- if (conversation.provider ?? "" in providerMap) return;
52
- conversation.provider = randomPick(Object.keys(providerMap));
53
- }
54
-
55
- let providerMap: InferenceProviderMapping = {};
56
- $: modelId = conversation.model.id;
57
- $: loadProviders(modelId);
58
- $: provider = conversation.provider;
59
-
60
- const {
61
- elements: { trigger, menu, option },
62
- states: { selected },
63
- } = createSelect<string, false>();
64
- const sync = createSync({ selected });
65
- $: sync.selected(
66
- conversation.provider ? { value: conversation.provider } : undefined,
67
- p => (conversation.provider = p?.value)
68
- );
69
  </script>
70
 
71
  <div class="flex flex-col gap-2">
@@ -97,35 +73,4 @@
97
  />
98
  {/if}
99
 
100
- <div class="flex flex-col gap-2">
101
- <!--
102
- <label class="flex items-baseline gap-2 text-sm font-medium text-gray-900 dark:text-white">
103
- Providers<span class="text-xs font-normal text-gray-400"></span>
104
- </label>
105
- -->
106
-
107
- <button
108
- {...$trigger}
109
- use:trigger
110
- class="relative flex items-center justify-between gap-6 overflow-hidden rounded-lg border bg-gray-100/80 px-3 py-1.5 leading-tight whitespace-nowrap shadow-sm hover:brightness-95 dark:border-gray-700 dark:bg-gray-800 dark:hover:brightness-110"
111
- >
112
- <div class="flex items-center gap-1 text-sm text-gray-500 dark:text-gray-300">
113
- <IconProvider provider={conversation.provider} />
114
- {conversation.provider ?? "loading"}
115
- </div>
116
- <IconCaret classNames="text-xl bg-gray-100 dark:bg-gray-600 rounded-sm size-4 flex-none absolute right-2" />
117
- </button>
118
-
119
- <div {...$menu} use:menu class="rounded-lg border bg-gray-100/80 dark:border-gray-700 dark:bg-gray-800">
120
- {#each Object.keys(providerMap) as provider (provider)}
121
- <div {...$option({ value: provider })} use:option class="group p-1 text-sm dark:text-white">
122
- <div
123
- class="flex items-center gap-2 rounded-md px-2 py-1 group-data-[highlighted]:bg-gray-200 dark:group-data-[highlighted]:bg-gray-700"
124
- >
125
- <IconProvider {provider} />
126
- {provider}
127
- </div>
128
- </div>
129
- {/each}
130
- </div>
131
- </div>
 
15
  import ModelSelectorModal from "./InferencePlaygroundModelSelectorModal.svelte";
16
  import { defaultSystemMessage } from "./inferencePlaygroundUtils";
17
  import { createSelect, createSync } from "@melt-ui/svelte";
18
+ import ProviderSelect from "./InferencePlaygroundProviderSelect.svelte";
19
 
20
  export let conversation: Conversation;
21
 
 
42
  $: nameSpace = conversation.model.id.split("/")[0] ?? "";
43
  $: modelName = conversation.model.id.split("/")[1] ?? "";
44
  const id = crypto.randomUUID();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  </script>
46
 
47
  <div class="flex flex-col gap-2">
 
73
  />
74
  {/if}
75
 
76
+ <ProviderSelect bind:conversation />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/lib/components/InferencePlayground/InferencePlaygroundProviderSelect.svelte ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import type { Conversation } from "./types";
3
+
4
+ import { browser } from "$app/environment";
5
+ import { fetchHuggingFaceModel, type InferenceProviderMapping } from "$lib/fetchers/providers";
6
+ import { token } from "$lib/stores/token";
7
+ import { randomPick } from "$lib/utils/array";
8
+ import { createSelect, createSync } from "@melt-ui/svelte";
9
+ import IconCaret from "../Icons/IconCaret.svelte";
10
+ import IconProvider from "../Icons/IconProvider.svelte";
11
+
12
+ export let conversation: Conversation;
13
+
14
+ async function loadProviders(modelId: string) {
15
+ console.log(modelId);
16
+ if (!browser) return;
17
+ providerMap = {};
18
+ const res = await fetchHuggingFaceModel(modelId, $token.value);
19
+ providerMap = res.inferenceProviderMapping;
20
+ // Commented out. I'm not sure we want to maintain, or always random pick
21
+ // if ((conversation.provider ?? "") in providerMap) return;
22
+ conversation.provider = randomPick(Object.keys(providerMap));
23
+ }
24
+
25
+ let providerMap: InferenceProviderMapping = {};
26
+ $: modelId = conversation.model.id;
27
+ $: loadProviders(modelId);
28
+
29
+ const {
30
+ elements: { trigger, menu, option },
31
+ states: { selected },
32
+ } = createSelect<string, false>();
33
+ const sync = createSync({ selected });
34
+ $: sync.selected(
35
+ conversation.provider ? { value: conversation.provider } : undefined,
36
+ p => (conversation.provider = p?.value)
37
+ );
38
+
39
+ const nameMap: Record<string, string> = {
40
+ "sambanova": "SambaNova",
41
+ "fal": "fal",
42
+ "cerebras": "Cerebras",
43
+ "replicate": "Replicate",
44
+ "black-forest-labs": "Black Forest Labs",
45
+ "fireworks-ai": "Fireworks",
46
+ "together": "Together AI",
47
+ "nebius": "Nebius AI Studio",
48
+ "hyperbolic": "Hyperbolic",
49
+ "novita": "Novita",
50
+ "cohere": "Nohere",
51
+ "hf-inference": "HF Inference API",
52
+ };
53
+ const UPPERCASE_WORDS = ["hf", "ai"];
54
+
55
+ function formatName(provider: string) {
56
+ if (provider in nameMap) return nameMap[provider];
57
+
58
+ const words = provider
59
+ .toLowerCase()
60
+ .split("-")
61
+ .map(word => {
62
+ if (UPPERCASE_WORDS.includes(word)) {
63
+ return word.toUpperCase();
64
+ } else {
65
+ return word.charAt(0).toUpperCase() + word.slice(1).toLowerCase();
66
+ }
67
+ });
68
+
69
+ return words.join(" ");
70
+ }
71
+ </script>
72
+
73
+ <div class="flex flex-col gap-2">
74
+ <!--
75
+ <label class="flex items-baseline gap-2 text-sm font-medium text-gray-900 dark:text-white">
76
+ Providers<span class="text-xs font-normal text-gray-400"></span>
77
+ </label>
78
+ -->
79
+
80
+ <button
81
+ {...$trigger}
82
+ use:trigger
83
+ class="relative flex items-center justify-between gap-6 overflow-hidden rounded-lg border bg-gray-100/80 px-3 py-1.5 leading-tight whitespace-nowrap shadow-sm hover:brightness-95 dark:border-gray-700 dark:bg-gray-800 dark:hover:brightness-110"
84
+ >
85
+ <div class="flex items-center gap-1 text-sm text-gray-500 dark:text-gray-300">
86
+ <IconProvider provider={conversation.provider} />
87
+ {formatName(conversation.provider ?? "") ?? "loading"}
88
+ </div>
89
+ <IconCaret classNames="text-xl bg-gray-100 dark:bg-gray-600 rounded-sm size-4 flex-none absolute right-2" />
90
+ </button>
91
+
92
+ <div {...$menu} use:menu class="rounded-lg border bg-gray-100/80 dark:border-gray-700 dark:bg-gray-800">
93
+ {#each Object.keys(providerMap) as provider (provider)}
94
+ <div {...$option({ value: provider })} use:option class="group p-1 text-sm dark:text-white">
95
+ <div
96
+ class="flex items-center gap-2 rounded-md px-2 py-1.5 group-data-[highlighted]:bg-gray-200 dark:group-data-[highlighted]:bg-gray-700"
97
+ >
98
+ <IconProvider {provider} />
99
+ {formatName(provider)}
100
+ </div>
101
+ </div>
102
+ {/each}
103
+ </div>
104
+ </div>
src/lib/components/InferencePlayground/inferencePlaygroundUtils.ts CHANGED
@@ -19,6 +19,7 @@ export async function handleStreamingResponse(
19
  {
20
  model: model.id,
21
  messages,
 
22
  ...conversation.config,
23
  },
24
  { signal: abortController.signal }
@@ -43,6 +44,7 @@ export async function handleNonStreamingResponse(
43
  const response = await hf.chatCompletion({
44
  model: model.id,
45
  messages,
 
46
  ...conversation.config,
47
  });
48
 
 
19
  {
20
  model: model.id,
21
  messages,
22
+ provider: conversation.provider,
23
  ...conversation.config,
24
  },
25
  { signal: abortController.signal }
 
44
  const response = await hf.chatCompletion({
45
  model: model.id,
46
  messages,
47
+ provider: conversation.provider,
48
  ...conversation.config,
49
  });
50