kz209 committed
Commit 26b45a8
1 Parent(s): f664ce2
Files changed (1)
  1. pages/summarization_playground.py +6 -13
pages/summarization_playground.py CHANGED
@@ -33,8 +33,7 @@ Back in Boston, Kidd is going to rely on Lively even more. He'll play close to 3
     random_label: ""
 }

-
-def get_model_batch_generation(model_name):
+def model_device_check(model_name):
     global __model_on_gpu__

     if __model_on_gpu__ != model_name:
@@ -47,21 +46,15 @@ def get_model_batch_generation(model_name):
         model[model_name] = Model(model_name)
         __model_on_gpu__ = model_name

-    return model[model_name]

+def get_model_batch_generation(model_name):
+    model_device_check(model_name)

-def generate_answer(sources, model_name, prompt):
-    global __model_on_gpu__
+    return model[model_name]

-    if __model_on_gpu__ != model_name:
-        if __model_on_gpu__:
-            logging.info(f"delete model {__model_on_gpu__}")
-            del model[__model_on_gpu__]
-            gc.collect()
-            torch.cuda.empty_cache()

-        model[model_name] = Model(model_name)
-        __model_on_gpu__ = model_name
+def generate_answer(sources, model_name, prompt):
+    model_device_check(model_name)


     content = prompt + '\n{' + sources + '}\n\nsummary:'
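
For reference, a minimal sketch of how the refactored module might fit together after this commit. The imports, the Model wrapper, and the initial values of the model cache and __model_on_gpu__ are not part of the diff and are assumptions; only model_device_check and its two callers mirror the change above, and the eviction logic inside the helper is the block carried over unchanged from the two functions it replaces.

# Sketch only, not the actual pages/summarization_playground.py: imports,
# module-level state, and the Model wrapper are assumed; model_device_check
# and its two callers follow the diff above.
import gc
import logging

import torch


class Model:
    """Stand-in for the project's model wrapper (assumed interface)."""
    def __init__(self, model_name):
        self.model_name = model_name


# Assumed module-level state: a cache of loaded models and the name of the
# model currently resident on the GPU (a falsy value means none loaded yet).
model = {}
__model_on_gpu__ = ''


def model_device_check(model_name):
    """Ensure model_name is the model on the GPU, evicting the previous one
    first to free VRAM (logic carried over from the old duplicated blocks)."""
    global __model_on_gpu__

    if __model_on_gpu__ != model_name:
        if __model_on_gpu__:
            logging.info(f"delete model {__model_on_gpu__}")
            del model[__model_on_gpu__]
            gc.collect()
            torch.cuda.empty_cache()

        model[model_name] = Model(model_name)
        __model_on_gpu__ = model_name


def get_model_batch_generation(model_name):
    model_device_check(model_name)
    return model[model_name]


def generate_answer(sources, model_name, prompt):
    model_device_check(model_name)
    content = prompt + '\n{' + sources + '}\n\nsummary:'
    return content  # the real function goes on to run generation on the model

Factoring the GPU swap into a single helper removes the duplicated delete/collect/empty_cache block, so get_model_batch_generation and generate_answer can no longer drift apart in how they manage the loaded model.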