Antonio Ramos antoniora commited on
Commit
f12455d
·
unverified ·
1 Parent(s): 4538f1d

Add langserve endpoint (#1009)

Browse files

* Add support for langserve endpoints

* Add support for langserve endpoints

* Fix linting

* Fix linting issues

* Fix issue import

---------

Co-authored-by: antoniora <[email protected]>

README.md CHANGED
@@ -618,6 +618,24 @@ MODELS=`[
618
 
619
  ```
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  ### Custom endpoint authorization
622
 
623
  #### Basic and Bearer
 
618
 
619
  ```
620
 
621
+ ##### LangServe
622
+
623
+ LangChain applications that are deployed using LangServe can be called with the following config:
624
+
625
+ ```
626
+ MODELS=`[
627
+ //...
628
+ {
629
+ "name": "summarization-chain", //model-name
630
+ "endpoints" : [{
631
+ "type": "langserve",
632
+ "url" : "http://127.0.0.1:8100",
633
+ }]
634
+ },
635
+ ]`
636
+
637
+ ```
638
+
639
  ### Custom endpoint authorization
640
 
641
  #### Basic and Bearer
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -17,6 +17,9 @@ import endpointCloudflare, {
17
  endpointCloudflareParametersSchema,
18
  } from "./cloudflare/endpointCloudflare";
19
  import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
 
 
 
20
 
21
  // parameters passed when generating text
22
  export interface EndpointParameters {
@@ -48,6 +51,7 @@ export const endpoints = {
48
  vertex: endpointVertex,
49
  cloudflare: endpointCloudflare,
50
  cohere: endpointCohere,
 
51
  };
52
 
53
  export const endpointSchema = z.discriminatedUnion("type", [
@@ -60,5 +64,6 @@ export const endpointSchema = z.discriminatedUnion("type", [
60
  endpointVertexParametersSchema,
61
  endpointCloudflareParametersSchema,
62
  endpointCohereParametersSchema,
 
63
  ]);
64
  export default endpoints;
 
17
  endpointCloudflareParametersSchema,
18
  } from "./cloudflare/endpointCloudflare";
19
  import { endpointCohere, endpointCohereParametersSchema } from "./cohere/endpointCohere";
20
+ import endpointLangserve, {
21
+ endpointLangserveParametersSchema,
22
+ } from "./langserve/endpointLangserve";
23
 
24
  // parameters passed when generating text
25
  export interface EndpointParameters {
 
51
  vertex: endpointVertex,
52
  cloudflare: endpointCloudflare,
53
  cohere: endpointCohere,
54
+ langserve: endpointLangserve,
55
  };
56
 
57
  export const endpointSchema = z.discriminatedUnion("type", [
 
64
  endpointVertexParametersSchema,
65
  endpointCloudflareParametersSchema,
66
  endpointCohereParametersSchema,
67
+ endpointLangserveParametersSchema,
68
  ]);
69
  export default endpoints;
src/lib/server/endpoints/langserve/endpointLangserve.ts ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { buildPrompt } from "$lib/buildPrompt";
2
+ import { z } from "zod";
3
+ import type { Endpoint } from "../endpoints";
4
+ import type { TextGenerationStreamOutput } from "@huggingface/inference";
5
+
6
+ export const endpointLangserveParametersSchema = z.object({
7
+ weight: z.number().int().positive().default(1),
8
+ model: z.any(),
9
+ type: z.literal("langserve"),
10
+ url: z.string().url(),
11
+ });
12
+
13
+ export function endpointLangserve(
14
+ input: z.input<typeof endpointLangserveParametersSchema>
15
+ ): Endpoint {
16
+ const { url, model } = endpointLangserveParametersSchema.parse(input);
17
+
18
+ return async ({ messages, preprompt, continueMessage }) => {
19
+ const prompt = await buildPrompt({
20
+ messages,
21
+ continueMessage,
22
+ preprompt,
23
+ model,
24
+ });
25
+
26
+ const r = await fetch(`${url}/stream`, {
27
+ method: "POST",
28
+ headers: {
29
+ "Content-Type": "application/json",
30
+ },
31
+ body: JSON.stringify({
32
+ input: { text: prompt },
33
+ }),
34
+ });
35
+
36
+ if (!r.ok) {
37
+ throw new Error(`Failed to generate text: ${await r.text()}`);
38
+ }
39
+
40
+ const encoder = new TextDecoderStream();
41
+ const reader = r.body?.pipeThrough(encoder).getReader();
42
+
43
+ return (async function* () {
44
+ let stop = false;
45
+ let generatedText = "";
46
+ let tokenId = 0;
47
+ let accumulatedData = ""; // Buffer to accumulate data chunks
48
+
49
+ while (!stop) {
50
+ // Read the stream and log the outputs to console
51
+ const out = (await reader?.read()) ?? { done: false, value: undefined };
52
+
53
+ // If it's done, we cancel
54
+ if (out.done) {
55
+ reader?.cancel();
56
+ return;
57
+ }
58
+
59
+ if (!out.value) {
60
+ return;
61
+ }
62
+
63
+ // Accumulate the data chunk
64
+ accumulatedData += out.value;
65
+ // Keep read data to check event type
66
+ const eventData = out.value;
67
+
68
+ // Process each complete JSON object in the accumulated data
69
+ while (accumulatedData.includes("\n")) {
70
+ // Assuming each JSON object ends with a newline
71
+ const endIndex = accumulatedData.indexOf("\n");
72
+ let jsonString = accumulatedData.substring(0, endIndex).trim();
73
+ // Remove the processed part from the buffer
74
+
75
+ accumulatedData = accumulatedData.substring(endIndex + 1);
76
+
77
+ // Stopping with end event
78
+ if (eventData.startsWith("event: end")) {
79
+ stop = true;
80
+ yield {
81
+ token: {
82
+ id: tokenId++,
83
+ text: "",
84
+ logprob: 0,
85
+ special: true,
86
+ },
87
+ generated_text: generatedText,
88
+ details: null,
89
+ } satisfies TextGenerationStreamOutput;
90
+ reader?.cancel();
91
+ continue;
92
+ }
93
+
94
+ if (eventData.startsWith("event: data") && jsonString.startsWith("data: ")) {
95
+ jsonString = jsonString.slice(6);
96
+ let data = null;
97
+
98
+ // Handle the parsed data
99
+ try {
100
+ data = JSON.parse(jsonString);
101
+ } catch (e) {
102
+ console.error("Failed to parse JSON", e);
103
+ console.error("Problematic JSON string:", jsonString);
104
+ continue; // Skip this iteration and try the next chunk
105
+ }
106
+ // Assuming content within data is a plain string
107
+ if (data) {
108
+ generatedText += data;
109
+ const output: TextGenerationStreamOutput = {
110
+ token: {
111
+ id: tokenId++,
112
+ text: data,
113
+ logprob: 0,
114
+ special: false,
115
+ },
116
+ generated_text: null,
117
+ details: null,
118
+ };
119
+ yield output;
120
+ }
121
+ }
122
+ }
123
+ }
124
+ })();
125
+ };
126
+ }
127
+
128
+ export default endpointLangserve;
src/lib/server/models.ts CHANGED
@@ -177,6 +177,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
177
  return await endpoints.cloudflare(args);
178
  case "cohere":
179
  return await endpoints.cohere(args);
 
 
180
  default:
181
  // for legacy reason
182
  return endpoints.tgi(args);
 
177
  return await endpoints.cloudflare(args);
178
  case "cohere":
179
  return await endpoints.cohere(args);
180
+ case "langserve":
181
+ return await endpoints.langserve(args);
182
  default:
183
  // for legacy reason
184
  return endpoints.tgi(args);