File size: 3,103 Bytes
755dd12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import { BaseChat } from './base/base';
import { IChatInputMessage, IStreamHandler } from '../interface';
import { DefaultSystem } from '../utils/constant';
import { httpRequest } from '../utils/utils';
import { fetchEventData } from 'fetch-sse';

/** Default root of the Gemini REST API (v1beta); `GOOGLE_PROXY_URL` can replace it. */
const BASE_URL = 'https://generativelanguage.googleapis.com/v1beta';

/** Resource path of the `gemini-pro` model, shared by the endpoint paths below. */
const GEMINI_PRO_PATH = '/models/gemini-pro';

/**
 * Endpoint paths appended to the base URL. Each path already begins with '/'.
 * The streaming variant asks for server-sent events via `alt=sse`.
 */
const URLS = {
  geminiPro: `${GEMINI_PRO_PATH}:generateContent`,
  geminiProStream: `${GEMINI_PRO_PATH}:streamGenerateContent?alt=sse`,
};

/**
 * Chat provider backed by Google's Gemini REST API (`gemini-pro` model).
 *
 * Configuration is read from the environment at construction time:
 * - `GOOGLE_KEY`: API key sent with every request.
 * - `GOOGLE_PROXY_URL`: optional replacement for the default API root.
 */
export class GoogleChat implements BaseChat {
  private key?: string;
  private baseUrl?: string;
  public platform = 'google';

  constructor() {
    this.key = process.env.GOOGLE_KEY;
    // `||` (not `??`) so an empty GOOGLE_PROXY_URL also falls back to the default.
    this.baseUrl = process.env.GOOGLE_PROXY_URL || BASE_URL;
    console.log('GoogleAI BaseURL: ', this.baseUrl);
  }

  /**
   * Sends a single, non-streaming chat request and returns the reply text.
   *
   * @param messages - Conversation history; `assistant` turns are sent as `model`.
   * @param model - Currently ignored; the request always targets `gemini-pro`.
   * @returns Text of the first part of the first candidate, or `undefined`
   *   when that candidate carries no text part.
   * @throws Error when the HTTP status is not 200 or no candidate is returned.
   */
  public async chat(
    messages: IChatInputMessage[],
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    model: string
  ) {
    const msgs = this.transformMessage(messages);
    // Use the non-streaming endpoint: the `?alt=sse` stream endpoint returns an
    // SSE body that `res.json()` cannot parse. URLS paths already start with '/'.
    const url = `${this.baseUrl}${URLS.geminiPro}`;
    const res = await httpRequest({
      endpoint: url,
      method: 'POST',
      data: JSON.stringify({
        contents: msgs
      }),
      query: {
        key: this.key,
      },
    });
    const data = await res.json();
    const resMsg = data.candidates?.[0];
    if (res.status !== 200 || !resMsg) {
      // Google APIs report failures as `{ error: { code, message, status } }`;
      // a top-level `message` is kept as a fallback for proxies that flatten it.
      throw new Error(data.error?.message ?? data.message ?? 'Google AI request error.');
    }
    // `parts` may be absent (e.g. a candidate with no content parts).
    return resMsg.content?.parts?.[0]?.text;
  }

  /**
   * Streams a chat completion, forwarding each text chunk to `onMessage`.
   *
   * @param messages - Conversation history; `assistant` turns are sent as `model`.
   * @param onMessage - Called as `(chunk, false)` for every received SSE event
   *   (chunk is `''` when the event carries no text) and as `(null, true)`
   *   when the stream closes.
   * @param model - Currently ignored; the request always targets `gemini-pro`.
   * @param system - Prompt prepended as a `user` turn followed by a canned
   *   `model` reply of `'ok.'`. A falsy value (e.g. `''`) omits it.
   */
  public async chatStream(
    messages: IChatInputMessage[],
    onMessage: IStreamHandler,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    model: string,
    system = DefaultSystem
  ) {
    const msgs = this.transformMessage(messages);
    if (system) {
      // The system prompt is sent as a user/model exchange at the head of the
      // conversation so the real history still starts on a fresh user turn.
      msgs.unshift({
        role: 'user',
        parts: [
          {
            text: system
          }
        ]
      }, {
        role: 'model',
        parts: [
          {
            text: 'ok.'
          }
        ]
      });
    }
    const url = `${this.baseUrl}${URLS.geminiProStream}`;
    const data = {
      contents: msgs
    };
    const abort = new AbortController();
    await fetchEventData(url, {
      method: 'POST',
      data,
      signal: abort.signal,
      headers: {
        'Content-Type': 'application/json',
        'x-goog-api-key': this.key
      },
      onOpen: async () => {
        //
      },
      onMessage: (eventData) => {
        const data = eventData?.data;
        const result = JSON.parse(data || '{}');
        // Optional-chain `parts` too: a chunk may carry content without parts.
        const msg = result.candidates?.[0]?.content?.parts?.[0]?.text ?? '';
        onMessage(msg, false);
      },
      onClose: () => {
        onMessage(null, true);
      },
      onError: (error) => {
        // The request is aborted and the error logged, not rethrown.
        // NOTE(review): confirm whether fetch-sse still calls onClose here, so
        // callers always receive the final `(null, true)` notification.
        abort.abort();
        console.log(error);
      },
    });
  }

  /**
   * Converts chat history into Gemini `contents` entries.
   * `assistant` maps to `model`; every other role is sent as `user`.
   */
  private transformMessage(messages: IChatInputMessage[]) {
    return messages.map(msg => {
      const role = msg.role === 'assistant' ? 'model' : 'user';
      return {
        role,
        parts: [
          {
            text: msg.content,
          },
        ],
      };
    });
  }
}

/** Shared provider instance, configured from the environment when this module loads. */
const googleChat = new GoogleChat();
export { googleChat as google };