BounharAbdelaziz commited on
Commit
2c7bfb2
·
verified ·
1 Parent(s): c9caa80

preload models for fast run

Browse files
Files changed (1) hide show
  1. app.py +45 -21
app.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  import torch
5
  import spaces
6
 
7
-
8
  # Define model paths
9
  MODEL_PATHS = {
10
  "Terjman-Nano-v2": "BounharAbdelaziz/Terjman-Nano-v2.0",
@@ -16,12 +15,50 @@ MODEL_PATHS = {
16
  # Load environment token
17
  TOKEN = os.environ['TOKEN']
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Translation function for Nano and Large models
20
  @spaces.GPU
21
- def translate_nano_large(text, model_path):
22
- translator = pipeline("translation", model=model_path, token=TOKEN)
23
  translated = translator(
24
- text,
25
  max_length=512,
26
  num_beams=4,
27
  no_repeat_ngram_size=3,
@@ -35,30 +72,17 @@ def translate_nano_large(text, model_path):
35
 
36
  # Translation function for Ultra and Supreme models
37
  @spaces.GPU
38
- def translate_ultra_supreme(text, model_path):
39
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
40
- print(f'[INFO] Using device: {device}')
41
- model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=TOKEN)
42
- tokenizer = AutoTokenizer.from_pretrained(model_path, src_lang="eng_Latn", tgt_lang="ary_Arab", token=TOKEN)
43
- translator = pipeline(
44
- "translation",
45
- model=model,
46
- tokenizer=tokenizer,
47
- max_length=512,
48
- src_lang="eng_Latn", # Keep src_lang and tgt_lang in the pipeline
49
- tgt_lang="ary_Arab",
50
- device=device,
51
- )
52
  translation = translator(text)[0]['translation_text']
53
  return translation
54
 
55
  # Main translation function
56
  def translate_text(text, model_choice):
57
- model_path = MODEL_PATHS[model_choice]
58
  if model_choice in ["Terjman-Nano-v2", "Terjman-Large-v2"]:
59
- return translate_nano_large(text, model_path)
60
  elif model_choice in ["Terjman-Ultra-v2", "Terjman-Supreme-v2"]:
61
- return translate_ultra_supreme(text, model_path)
62
  else:
63
  return "Invalid model selection."
64
 
 
4
  import torch
5
  import spaces
6
 
 
7
  # Define model paths
8
  MODEL_PATHS = {
9
  "Terjman-Nano-v2": "BounharAbdelaziz/Terjman-Nano-v2.0",
 
15
  # Load environment token
16
  TOKEN = os.environ['TOKEN']
17
 
18
+ # Preload models and tokenizers
19
+ def preload_models():
20
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
21
+ print(f"[INFO] Using device: {device}")
22
+
23
+ # Load Nano and Large models
24
+ nano_large_models = {}
25
+ for model_name in ["Terjman-Nano-v2", "Terjman-Large-v2"]:
26
+ print(f"[INFO] Loading {model_name}...")
27
+ translator = pipeline(
28
+ "translation",
29
+ model=MODEL_PATHS[model_name],
30
+ token=TOKEN,
31
+ device=device if device.startswith("cuda") else -1
32
+ )
33
+ nano_large_models[model_name] = translator
34
+
35
+ # Load Ultra and Supreme models
36
+ ultra_supreme_models = {}
37
+ for model_name in ["Terjman-Ultra-v2", "Terjman-Supreme-v2"]:
38
+ print(f"[INFO] Loading {model_name}...")
39
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATHS[model_name], token=TOKEN).to(device)
40
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATHS[model_name], token=TOKEN)
41
+ translator = pipeline(
42
+ "translation",
43
+ model=model,
44
+ tokenizer=tokenizer,
45
+ device=device if device.startswith("cuda") else -1,
46
+ src_lang="eng_Latn",
47
+ tgt_lang="ary_Arab"
48
+ )
49
+ ultra_supreme_models[model_name] = translator
50
+
51
+ return nano_large_models, ultra_supreme_models
52
+
53
+ # Preload all models
54
+ nano_large_models, ultra_supreme_models = preload_models()
55
+
56
  # Translation function for Nano and Large models
57
  @spaces.GPU
58
+ def translate_nano_large(text, model_name):
59
+ translator = nano_large_models[model_name]
60
  translated = translator(
61
+ text,
62
  max_length=512,
63
  num_beams=4,
64
  no_repeat_ngram_size=3,
 
72
 
73
  # Translation function for Ultra and Supreme models
74
  @spaces.GPU
75
+ def translate_ultra_supreme(text, model_name):
76
+ translator = ultra_supreme_models[model_name]
 
 
 
 
 
 
 
 
 
 
 
 
77
  translation = translator(text)[0]['translation_text']
78
  return translation
79
 
80
  # Main translation function
81
  def translate_text(text, model_choice):
 
82
  if model_choice in ["Terjman-Nano-v2", "Terjman-Large-v2"]:
83
+ return translate_nano_large(text, model_choice)
84
  elif model_choice in ["Terjman-Ultra-v2", "Terjman-Supreme-v2"]:
85
+ return translate_ultra_supreme(text, model_choice)
86
  else:
87
  return "Invalid model selection."
88