TiberiuCristianLeon commited on
Commit
49ae858
·
verified ·
1 Parent(s): 2130cf2

Update src/translate/Translate.py

Browse files
Files changed (1) hide show
  1. src/translate/Translate.py +24 -30
src/translate/Translate.py CHANGED
@@ -17,7 +17,7 @@ modelROMENG.to(device)
17
  modelENGROM.to(device)
18
 
19
 
20
- def paraphraseTranslateMethod(requestValue : str):
21
 
22
  exception = ""
23
  result_value = ""
@@ -30,33 +30,27 @@ def paraphraseTranslateMethod(requestValue : str):
30
  tokenized_sent_list = sent_tokenize(requestValue)
31
 
32
  for SENTENCE in tokenized_sent_list:
33
-
34
- input_ids1 = tokenizerROMENG(SENTENCE, return_tensors='pt').to(device)
35
-
36
- output1 = modelROMENG.generate(
37
- input_ids=input_ids1.input_ids,
38
- do_sample=True,
39
- max_length=256,
40
- top_k=90,
41
- top_p=0.97,
42
- early_stopping=False
43
- )
44
-
45
- result1 = tokenizerROMENG.batch_decode(output1, skip_special_tokens=True)[0]
46
-
47
- input_ids = tokenizerENGROM(result1, return_tensors='pt').to(device)
48
-
49
- output = modelENGROM.generate(
50
- input_ids=input_ids.input_ids,
51
- do_sample=True,
52
- max_length=256,
53
- top_k=90,
54
- top_p=0.97,
55
- early_stopping=False
56
- )
57
-
58
  result = tokenizerENGROM.batch_decode(output, skip_special_tokens=True)[0]
59
-
60
- result_value += result + " "
61
-
62
- return result_value, ""
 
17
  modelENGROM.to(device)
18
 
19
 
20
+ def paraphraseTranslateMethod(requestValue : str, model: str):
21
 
22
  exception = ""
23
  result_value = ""
 
30
  tokenized_sent_list = sent_tokenize(requestValue)
31
 
32
  for SENTENCE in tokenized_sent_list:
33
+ if model == 'roen'
34
+ input_ids = tokenizerROMENG(SENTENCE, return_tensors='pt').to(device)
35
+ output = modelROMENG.generate(
36
+ input_ids=input_ids1.input_ids,
37
+ do_sample=True,
38
+ max_length=512,
39
+ top_k=90,
40
+ top_p=0.97,
41
+ early_stopping=False
42
+ )
43
+ result = tokenizerROMENG.batch_decode(output1, skip_special_tokens=True)[0]
44
+ else:
45
+ input_ids = tokenizerENGROM(SENTENCE, return_tensors='pt').to(device)
46
+
47
+ output = modelENGROM.generate(
48
+ input_ids=input_ids.input_ids,
49
+ do_sample=True,
50
+ max_length=512,
51
+ top_k=90,
52
+ top_p=0.97,
53
+ early_stopping=False
54
+ )
 
 
 
55
  result = tokenizerENGROM.batch_decode(output, skip_special_tokens=True)[0]
56
+ return result.strip(), ""