Commit b41b21e (verified) by TiberiuCristianLeon
Parent(s): ad9a50c

Update src/translate/Translate.py

Files changed (1): src/translate/Translate.py (+16 -7)
src/translate/Translate.py CHANGED
@@ -74,14 +74,23 @@ def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
 
 def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
     # Load model directly
-    if '/' not in model:
-        model = 'Gargaz/gemma-2b-romanian-better'
+    model = model if '/' in model else 'Gargaz/gemma-2b-romanian-better'
     # limit max_new_tokens to 150% of the requestValue
-    max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)
-    max_new_tokens = max_new_tokens if max_new_tokens % 2 == 0 else max_new_tokens + 1
-    messages = [{"role": "user", "content": f"Translate this text to Romanian: {requestValue}"}]
+    prompt = f"Translate this text to Romanian: {requestValue}"
     tokenizer = AutoTokenizer.from_pretrained("Gargaz/gemma-2b-romanian-better")
-    model = AutoModelForCausalLM.from_pretrained("Gargaz/gemma-2b-romanian-better")
+    model = AutoModelForCausalLM.from_pretrained("Gargaz/gemma-2b-romanian-better").to(device)
+
+    # Estimate output length from the input token count (e.g., 50% longer)
+    input_ids = tokenizer.encode(requestValue, add_special_tokens=True)
+    num_tokens = len(input_ids)
+    max_new_tokens = int(num_tokens * 1.5)
+    max_new_tokens += max_new_tokens % 2  # ensure it's even
+
+    # Character-count estimate, kept commented out as a fallback
+    # max_new_tokens = int(len(requestValue) * 1.5)
+    # max_new_tokens += max_new_tokens % 2  # ensure it's even
+
+    messages = [{"role": "user", "content": prompt}]
 
     inputs = tokenizer.apply_chat_template(
         messages,
@@ -91,6 +100,6 @@ def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
         return_tensors="pt",
     ).to(device)
 
-    outputs = model.generate(**inputs, max_new_tokens=max_tokens)
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
     response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
     return response, model
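
For reference, here is the post-commit gemma_direct consolidated into one self-contained script. This is a minimal sketch rather than the repository's exact file: the apply_chat_template arguments between messages and return_tensors fall outside both hunks, so add_generation_prompt=True, tokenize=True, and return_dict=True are assumptions (the last implied by **inputs and inputs["input_ids"] in the diff), as is the torch device selection; the sketch also passes the normalized model argument to from_pretrained, where the commit hard-codes the repo id.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    # Fall back to the default checkpoint unless a namespaced repo id was passed
    model = model if '/' in model else 'Gargaz/gemma-2b-romanian-better'
    prompt = f"Translate this text to Romanian: {requestValue}"

    tokenizer = AutoTokenizer.from_pretrained(model)  # the commit hard-codes the repo id here
    model = AutoModelForCausalLM.from_pretrained(model).to(device)

    # Budget the output at ~150% of the input token count, rounded up to an even number
    num_tokens = len(tokenizer.encode(requestValue, add_special_tokens=True))
    max_new_tokens = int(num_tokens * 1.5)
    max_new_tokens += max_new_tokens % 2

    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # assumption: elided by the hunk boundary
        tokenize=True,               # assumption: elided by the hunk boundary
        return_dict=True,            # assumption: implied by **inputs below
        return_tensors="pt",
    ).to(device)

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    return response, model

if __name__ == "__main__":
    translation, _ = gemma_direct("The weather is nice today.")
    print(translation)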