AngoHF committed on
Commit
8d1a039
1 Parent(s): cce5d60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import (
7
  )
8
  from peft import PeftModel
9
  import torch
 
10
 
11
  model_path = "Qwen/Qwen1.5-1.8B-Chat"
12
  lora_path = "AngoHF/EssayGPT" #+ "/checkpoint-100"
@@ -31,11 +32,7 @@ model = PeftModel.from_pretrained(model, lora_path)
31
  model = model.merge_and_unload()
32
  model.eval()
33
 
34
- # model.config.use_cache = True
35
- # model.to("cpu")
36
- # model.save_pretrained("/data/ango/EssayGPT")
37
-
38
- # tokenizer.save_pretrained("/data/ango/EssayGPT")
39
 
40
 
41
  MAX_MATERIALS = 4
@@ -55,11 +52,14 @@ def call(related_materials, materials, question):
55
  add_generation_prompt=True
56
  )
57
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
58
- print(len(model_inputs.input_ids[0]))
 
59
  generated_ids = model.generate(
60
  model_inputs.input_ids,
61
  max_length=8096
62
  )
 
 
63
  generated_ids = [
64
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
65
  ]
@@ -107,7 +107,7 @@ def build_ui(components):
107
  def run():
108
  app = create_ui()
109
  app.queue()
110
- app.launch(share=True)
111
 
112
 
113
  if __name__ == '__main__':
 
7
  )
8
  from peft import PeftModel
9
  import torch
10
+ import time
11
 
12
  model_path = "Qwen/Qwen1.5-1.8B-Chat"
13
  lora_path = "AngoHF/EssayGPT" #+ "/checkpoint-100"
 
32
  model = model.merge_and_unload()
33
  model.eval()
34
 
35
+ model.config.use_cache = True
 
 
 
 
36
 
37
 
38
  MAX_MATERIALS = 4
 
52
  add_generation_prompt=True
53
  )
54
  model_inputs = tokenizer([text], return_tensors="pt").to(device)
55
+ print(f"Input Token Length: {len(model_inputs.input_ids[0])}")
56
+ start_time = time.time()
57
  generated_ids = model.generate(
58
  model_inputs.input_ids,
59
  max_length=8096
60
  )
61
+
62
+ print(f"Inference Cost Time: {time.time() - start_time}")
63
  generated_ids = [
64
  output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
65
  ]
 
107
  def run():
108
  app = create_ui()
109
  app.queue()
110
+ app.launch()
111
 
112
 
113
  if __name__ == '__main__':