trysem commited on
Commit
c529ee8
·
verified ·
1 Parent(s): 92fd8d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ from IndicTransToolkit import IndicProcessor
4
+ # recommended to run this on a gpu with flash_attn installed
5
+ # don't set attn_implemetation if you don't have flash_attn
6
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ src_lang, tgt_lang = "eng_Latn", "hin_Deva"
9
+ model_name = "ai4bharat/indictrans2-en-indic-1B"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
11
+
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(
13
+ model_name,
14
+ trust_remote_code=True,
15
+ torch_dtype=torch.float16, # performance might slightly vary for bfloat16
16
+ attn_implementation="flash_attention_2"
17
+ ).to(DEVICE)
18
+
19
+ ip = IndicProcessor(inference=True)
20
+
21
+ input_sentences = [
22
+ "When I was young, I used to go to the park every day.",
23
+ "We watched a new movie last week, which was very inspiring.",
24
+ "If you had met me at that time, we would have gone out to eat.",
25
+ "My friend has invited me to his birthday party, and I will give him a gift.",
26
+ ]
27
+
28
+ batch = ip.preprocess_batch(
29
+ input_sentences,
30
+ src_lang=src_lang,
31
+ tgt_lang=tgt_lang,
32
+ )
33
+
34
+ # Tokenize the sentences and generate input encodings
35
+ inputs = tokenizer(
36
+ batch,
37
+ truncation=True,
38
+ padding="longest",
39
+ return_tensors="pt",
40
+ return_attention_mask=True,
41
+ ).to(DEVICE)
42
+
43
+ # Generate translations using the model
44
+ with torch.no_grad():
45
+ generated_tokens = model.generate(
46
+ **inputs,
47
+ use_cache=True,
48
+ min_length=0,
49
+ max_length=256,
50
+ num_beams=5,
51
+ num_return_sequences=1,
52
+ )
53
+
54
+ # Decode the generated tokens into text
55
+ with tokenizer.as_target_tokenizer():
56
+ generated_tokens = tokenizer.batch_decode(
57
+ generated_tokens.detach().cpu().tolist(),
58
+ skip_special_tokens=True,
59
+ clean_up_tokenization_spaces=True,
60
+ )
61
+
62
+ # Postprocess the translations, including entity replacement
63
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
64
+
65
+ for input_sentence, translation in zip(input_sentences, translations):
66
+ print(f"{src_lang}: {input_sentence}")
67
+ print(f"{tgt_lang}: {translation}")