File size: 2,912 Bytes
755dd12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import { fetchEventData } from 'fetch-sse';
import { httpRequest }  from '../utils/utils';
import memoryCache from '../cache';
import { BaseChat } from './base/base';
import { IChatInputMessage, IStreamHandler } from '../interface';
import { type MemoryCache } from 'cache-manager';

/** OAuth endpoint used to exchange the API key/secret for an access token. */
const TokenUrl = 'https://aip.baidubce.com/oauth/2.0/token';

/** Root of the chat endpoint; the model name is appended as a path segment. */
const BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat';

/**
 * Chat client for Baidu's wenxinworkshop chat endpoint.
 *
 * Credentials are read from the `BAIDU_KEY` / `BAIDU_SECRET` environment
 * variables when the instance is constructed. Every request is authorised
 * with an OAuth access token that is fetched on demand and stored in the
 * cache imported from `../cache`.
 */
export class BaiduChat implements BaseChat {
  /**
   * Upper bound for the token cache TTL, in milliseconds.
   *
   * This is the largest delay a Node.js timer accepts (2^31 - 1 ms, roughly
   * 24.8 days); larger delays overflow and fire almost immediately.
   * NOTE(review): the cap only matters if the configured cache store
   * schedules expiry with a timer — confirm against `../cache`. Caching for
   * less than the real token lifetime is always safe.
   */
  private static readonly MAX_TOKEN_TTL_MS = 0x7fffffff;

  private key?: string;
  private secret?: string;
  private cache: MemoryCache;
  public platform = 'baidu';

  constructor() {
    this.key = process.env.BAIDU_KEY;
    this.secret = process.env.BAIDU_SECRET;
    this.cache = memoryCache;
  }

  /**
   * Sends a non-streaming chat request and returns the `result` field of the
   * response.
   *
   * @param messages - Conversation messages, forwarded as-is in the body.
   * @param model - Model name; appended to the chat URL as a path segment.
   * @param system - Optional system prompt, forwarded as `system`.
   * @returns The `result` field of the parsed JSON response.
   * @throws Error formatted as `<error_code>: <error_msg>` when the response
   *   carries an `error_code`; also propagates `getAccessToken` failures.
   */
  public async chat(
    messages: IChatInputMessage[],
    model: string,
    system?: string
  ) {
    const token = await this.getAccessToken();
    const res = await httpRequest({
      endpoint: `${BASE_URL}/${model}`,
      method: 'POST',
      query: {
        access_token: token
      },
      data: JSON.stringify({
        messages,
        system,
        stream: false,
      })
    });
    const data = await res.json();
    if (data.error_code) {
      const msg = `${data.error_code}: ${data.error_msg}`;
      throw new Error(msg);
    }
    return data.result;
  }

  /**
   * Sends a streaming chat request and forwards each received chunk to
   * `onMessage`.
   *
   * Each event's `data` payload is parsed as JSON and its `result` field
   * (or `''` when absent) is passed to the handler. The handler's second
   * argument is always `false` here.
   *
   * @param messages - Conversation messages, forwarded as-is in the body.
   * @param onMessage - Handler invoked once per received event.
   * @param model - Model name; appended to the chat URL as a path segment.
   * @param system - Optional system prompt, forwarded as `system`.
   */
  public async chatStream(
    messages: IChatInputMessage[],
    onMessage: IStreamHandler,
    model: string,
    system?: string
  ): Promise<void> {
    const token = await this.getAccessToken();
    const url = `${BASE_URL}/${model}?access_token=${token}`;
    const abort = new AbortController();
    await fetchEventData(url, {
      method: 'POST',
      data: {
        messages,
        system,
        stream: true,
      },
      signal: abort.signal,
      onMessage: (eventData) => {
        const data = eventData?.data;
        // Events without a payload are treated as an empty object.
        const result = JSON.parse(data || '{}');
        const msg = result.result ?? '';
        onMessage(msg, false);
      }
    });
  }

  /**
   * Returns a Baidu access token, requesting a new one only when none is
   * cached under the API key.
   *
   * Baidu access tokens are valid for 30 days by default and `expires_in` is
   * reported in seconds. The token is cached for that lifetime minus a
   * 10-second safety margin, converted to the milliseconds that
   * cache-manager v5 expects and capped at `MAX_TOKEN_TTL_MS`.
   *
   * @returns The access token string.
   * @throws Error when the key or secret is missing, when the token endpoint
   *   reports an `error`, or when the response has no `access_token`.
   */
  protected async getAccessToken(): Promise<string> {
    if (!this.key || !this.secret) {
      throw new Error('Invalid Baidu params: key or secret');
    }
    const { key, secret } = this;
    const cachedToken: string | undefined = await this.cache.get(key);
    if (cachedToken) {
      return cachedToken;
    }
    const res = await httpRequest({
      method: 'POST',
      endpoint: TokenUrl,
      query: {
        grant_type: 'client_credentials',
        client_id: key,
        client_secret: secret,
      }
    });
    const data = await res.json();
    if (data?.error) {
      throw new Error(data.error);
    }
    const { access_token, expires_in } = data ?? {};
    if (typeof access_token !== 'string' || access_token === '') {
      // Fail here rather than caching and sending an unusable token.
      throw new Error('Invalid Baidu token response: missing access_token');
    }
    const ttl = BaiduChat.tokenTtlMs(expires_in);
    if (ttl !== undefined) {
      // Awaited so a failing cache write surfaces to the caller instead of
      // becoming an unhandled rejection.
      await this.cache.set(key, access_token, ttl);
    }
    return access_token;
  }

  /**
   * Converts the `expires_in` value of a token response (seconds) into a
   * cache TTL in milliseconds.
   *
   * @param expiresIn - Raw `expires_in` value from the token response.
   * @returns The TTL in milliseconds, or `undefined` when the value is not a
   *   finite number or leaves no lifetime after the 10-second margin, in
   *   which case the token should not be cached.
   */
  private static tokenTtlMs(expiresIn: unknown): number | undefined {
    const usableSeconds = Number(expiresIn) - 10;
    if (!Number.isFinite(usableSeconds) || usableSeconds <= 0) {
      return undefined;
    }
    return Math.min(usableSeconds * 1000, BaiduChat.MAX_TOKEN_TTL_MS);
  }
}

/**
 * Module-level `BaiduChat` instance. It is constructed when this module is
 * first loaded, so `BAIDU_KEY` and `BAIDU_SECRET` are read from `process.env`
 * at that point.
 */
export const baidu = new BaiduChat();