黄腾 aopstudio committed on
Commit 36d0b06 · Parent: b37fedc

add support for LocalLLM (#1744)


### What problem does this PR solve?

Add support for LocalLLM: the `LocalLLM` chat model in `rag/llm/chat_model.py` now talks to a local Hugging Face model served over gRPC by a new Jina deployment (`rag/svr/jina_server.py`), replacing the previous multiprocessing RPC proxy. Both blocking (`chat`) and streaming (`chat_streamly`) completions are supported.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

Files changed (2)
  1. rag/llm/chat_model.py +36 -23
  2. rag/svr/jina_server.py +93 -0
rag/llm/chat_model.py CHANGED
```diff
@@ -27,6 +27,8 @@ from groq import Groq
 import os
 import json
 import requests
+import asyncio
+from rag.svr.jina_server import Prompt,Generation
 
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
@@ -381,8 +383,10 @@ class LocalLLM(Base):
 
         def __conn(self):
             from multiprocessing.connection import Client
-            self._connection = Client(
-                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')
+
+            self._connection = Client(
+                (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
+            )
 
         def __getattr__(self, name):
             import pickle
@@ -390,8 +394,7 @@ class LocalLLM(Base):
             def do_rpc(*args, **kwargs):
                 for _ in range(3):
                     try:
-                        self._connection.send(
-                            pickle.dumps((name, args, kwargs)))
+                        self._connection.send(pickle.dumps((name, args, kwargs)))
                         return pickle.loads(self._connection.recv())
                     except Exception as e:
                         self.__conn()
@@ -399,35 +402,45 @@ class LocalLLM(Base):
 
             return do_rpc
 
-    def __init__(self, key, model_name="glm-3-turbo"):
-        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)
+    def __init__(self, key, model_name):
+        from jina import Client
 
-    def chat(self, system, history, gen_conf):
-        if system:
-            history.insert(0, {"role": "system", "content": system})
-        try:
-            ans = self.client.chat(
-                history,
-                gen_conf
-            )
-            return ans, num_tokens_from_string(ans)
-        except Exception as e:
-            return "**ERROR**: " + str(e), 0
+        self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
-    def chat_streamly(self, system, history, gen_conf):
+    def _prepare_prompt(self, system, history, gen_conf):
         if system:
             history.insert(0, {"role": "system", "content": system})
-        token_count = 0
+        if "max_tokens" in gen_conf:
+            gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
+        return Prompt(message=history, gen_conf=gen_conf)
+
+    def _stream_response(self, endpoint, prompt):
         answer = ""
         try:
-            for ans in self.client.chat_streamly(history, gen_conf):
-                answer += ans
-                token_count += 1
-                yield answer
+            res = self.client.stream_doc(
+                on=endpoint, inputs=prompt, return_type=Generation
+            )
+            loop = asyncio.get_event_loop()
+            try:
+                while True:
+                    answer = loop.run_until_complete(res.__anext__()).text
+                    yield answer
+            except StopAsyncIteration:
+                pass
         except Exception as e:
             yield answer + "\n**ERROR**: " + str(e)
+        yield num_tokens_from_string(answer)
+
+    def chat(self, system, history, gen_conf):
+        prompt = self._prepare_prompt(system, history, gen_conf)
+        chat_gen = self._stream_response("/chat", prompt)
+        ans = next(chat_gen)
+        total_tokens = next(chat_gen)
+        return ans, total_tokens
 
-        yield token_count
+    def chat_streamly(self, system, history, gen_conf):
+        prompt = self._prepare_prompt(system, history, gen_conf)
+        return self._stream_response("/stream", prompt)
 
 
 class VolcEngineChat(Base):
```
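For orientation, here is a hedged, client-side sketch of how the reworked `LocalLLM` could be exercised once the Jina deployment from `rag/svr/jina_server.py` is listening on gRPC port 12345. It is not part of the diff; the system prompt, user message, and model name are placeholders. Note the streaming contract preserved from the old implementation: `chat_streamly` yields the text generated so far and, as its final item, an integer token count, which is why `chat` reads its internal generator twice.

```python
# Hypothetical usage sketch (not part of this PR); assumes the Jina deployment
# started by rag/svr/jina_server.py is already serving on gRPC port 12345.
from rag.llm.chat_model import LocalLLM

llm = LocalLLM(key="", model_name="local-model")  # key is ignored for local serving

# Blocking call: returns the full answer plus a token count.
answer, tokens = llm.chat(
    "You are a helpful assistant.",
    [{"role": "user", "content": "What is RAGFlow?"}],
    {"temperature": 0.7, "max_tokens": 128},  # max_tokens is renamed to max_new_tokens internally
)
print(answer, tokens)

# Streaming call: yields the text produced so far; the last item is the token count.
for chunk in llm.chat_streamly(
    "You are a helpful assistant.",
    [{"role": "user", "content": "What is RAGFlow?"}],
    {"max_tokens": 128},
):
    print(chunk)
```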
rag/svr/jina_server.py ADDED
```python
from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
```
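As a quick sanity check of the new server, the sketch below assumes an already-running deployment and streams tokens from it directly with the same `Client`/`stream_doc` calls that `LocalLLM` uses. This is illustrative only and not part of the PR; the model name in the launch command and the prompt text are placeholders.

```python
# Hypothetical smoke test (not part of this PR). Assumes the server was launched with e.g.:
#   python rag/svr/jina_server.py --model_name <hf-model-or-path> --port 12345
import asyncio

from jina import Client

from rag.svr.jina_server import Prompt, Generation


async def main():
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Say hello."}],
        gen_conf={"max_new_tokens": 32},
    )
    # /stream yields one cumulative Generation per decoding step;
    # /chat returns the full answer in a single Generation.
    async for doc in client.stream_doc(on="/stream", inputs=prompt, return_type=Generation):
        print(doc.text)


if __name__ == "__main__":
    asyncio.run(main())
```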