# NOTE: the file-size banner, per-line commit hashes and line-number gutter
# that the hosting page's blame view had prepended here were scrape residue,
# not part of the program; the script itself starts below.
import sys
#import subprocess
#from torch.utils.checkpoint import checkpoint
# implement pip as a subprocess:
#subprocess.check_call([sys.executable, '-m', 'pip', 'install','--quiet','sentencepiece==0.1.95'])

import gradio as gr
#from transformers import pipeline
from transformers import AutoTokenizer
import torch

# Identifier of the pretrained tokenizer and path of the locally stored model.
_TOKENIZER_NAME = "Helsinki-NLP/opus-mt-en-ar"
_MODEL_PATH = "helsinki_fineTuned.pt"
_CPU = torch.device('cpu')

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)

# torch.load unpickles the stored object, so the file must come from a
# trusted source. map_location forces every tensor onto the CPU.
model = torch.load(_MODEL_PATH, map_location=_CPU)

# Switch to evaluation mode before serving requests.
model.eval()
#translation_pipeline = pipeline(model)


def translate_gradio(input):
    """Translate one piece of English text with the loaded model.

    Args:
        input: Text typed into the Gradio textbox. The parameter name
            shadows the builtin but is kept so existing callers that pass
            it by keyword keep working.

    Returns:
        The first decoded sequence produced by ``model.generate`` with
        special tokens stripped, or ``""`` when ``input`` is empty or
        contains only whitespace.
    """
    # Nothing to translate: skip the tokenizer and model entirely.
    if not input or not input.strip():
        return ""

    # Call the tokenizer directly instead of the deprecated
    # ``prepare_seq2seq_batch`` helper. padding/truncation are spelled out
    # to mirror that helper's defaults (NOTE(review): confirm against the
    # installed transformers version).
    batch = tokenizer([input], return_tensors='pt', padding=True, truncation=True)

    # Generate the output token ids, then decode the single batch entry.
    generated = model.generate(**batch)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]





#description = 'Translating "English Data Science" content into Arabic'
# Web UI: a 7-line English textbox in, a labelled textbox out.
translate_interface = gr.Interface(
    fn=translate_gradio,
    title='Translating "English Data Science" content into Arabic',
    inputs=gr.inputs.Textbox(lines=7, label='english content'),
    # A string passed as ``outputs`` is looked up as a component shortcut
    # (e.g. "text"); "Translated Output" is not one, so it is used as the
    # label of an explicit output textbox instead. The legacy ``gr.outputs``
    # namespace is used to match ``gr.inputs`` above.
    outputs=gr.outputs.Textbox(label="Translated Output"),
)
# inline=False: do not try to embed the UI in a notebook cell.
translate_interface.launch(inline=False)