tungdop2 committed
Commit affa4ce · 1 Parent(s): feb3cba

fix docker

Files changed (3):
  1. model.py +32 -21
  2. packages.txt +0 -3
  3. requirements.txt +5 -7
model.py CHANGED
@@ -1,6 +1,7 @@
 import os
 import torch
-from vllm import LLM, SamplingParams
+# from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import logging
 
 # Configure logging
@@ -9,41 +10,51 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+
 class ChallengePromptGenerator:
     def __init__(
         self,
-        model_local_dir="./checkpoint-15000",
+        model_local_dir="checkpoint-15000",
     ):
-        self.generator = LLM(
-            model_local_dir,
-        )
-
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.generator = AutoModelForCausalLM.from_pretrained(model_local_dir, device_map=self.device)
+        self.generator.to_bettertransformer()
+        self.tokenizer = AutoTokenizer.from_pretrained(model_local_dir)
 
     def infer_prompt(
-        self,
+        self,
         prompts,
         max_generation_length=77,
         beam_size=1,
         sampling_temperature=0.9,
         sampling_topk=1,
-        sampling_topp=1,
+        sampling_topp=1
     ):
-        added_prompts = [f"{self.generator.get_tokenizer().bos_token} {prompt}" for prompt in prompts]
+        # Add bos
+        prompts = [f"{self.tokenizer.bos_token} {prompt}" for prompt in prompts]
+
+        # Prepare inputs
+        inputs = self.tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=256,
+            add_special_tokens=False
+        ).to(self.device)
 
-        sampling_params = SamplingParams(
-            max_tokens=max_generation_length,
+        # Generate
+        outputs = self.generator.generate(
+            **inputs,
+            max_length=max_generation_length,
+            num_beams=beam_size,
             temperature=sampling_temperature,
             top_k=sampling_topk,
             top_p=sampling_topp,
-            use_beam_search=(beam_size > 1),
+            do_sample=True,
+            pad_token_id=self.tokenizer.pad_token_id
         )
 
-        outputs = self.generator.generate(added_prompts, sampling_params)
-        out = []
-        for i in range(len(outputs)):
-            tmp_out = prompts[i] + outputs[i].outputs[0].text
-            # droop last unfished sentence
-            if tmp_out[-1] != ".":
-                tmp_out = ".".join(tmp_out.split(".")[:-1])
-            out.append(tmp_out)
-        return out
+        # Decode
+        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return decoded_outputs
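For context, a minimal usage sketch of the rewritten ChallengePromptGenerator. The driver script below is hypothetical and not part of this commit; the example prompts and the non-default top-k/top-p values are made up, and the checkpoint-15000 directory is assumed to be available locally as in model.py.

# Hypothetical driver script for the transformers-based generator (not in this commit).
from model import ChallengePromptGenerator

generator = ChallengePromptGenerator(model_local_dir="checkpoint-15000")

# Example inputs, invented for illustration.
prompts = [
    "a cat sitting on",
    "a futuristic city at",
]

results = generator.infer_prompt(
    prompts,
    max_generation_length=77,
    beam_size=1,
    sampling_temperature=0.9,
    sampling_topk=50,    # assumed value; the default in the diff is 1
    sampling_topp=0.95,  # assumed value; the default in the diff is 1
)

for prompt, completion in zip(prompts, results):
    print(f"{prompt!r} -> {completion!r}")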
packages.txt DELETED
@@ -1,3 +0,0 @@
-gcc-12
-g++-12
-libnuma-dev
requirements.txt CHANGED
@@ -1,7 +1,5 @@
-wheel
-packaging
-ninja
-numpy
-gradio
-torch -f https://download.pytorch.org/whl/cpu
-vllm
+fastapi
+uvicorn
+pydantic
+'transformers[torch]'
+optimum