nikshep01 committed
Commit b42d6cf · verified · 1 Parent(s): 0413754

Update app.py

Files changed (1): app.py +47 -20
app.py CHANGED
@@ -1,28 +1,55 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
 from transformers import MarianMTModel, MarianTokenizer
-from flask import Flask, request, jsonify
-
-app = Flask(__name__)
-
-# Load pre-trained model and tokenizer
-model_name = 'Helsinki-NLP/opus-mt-en-fr'
-model = MarianMTModel.from_pretrained(model_name)
-tokenizer = MarianTokenizer.from_pretrained(model_name)
-
-@app.route('/translate', methods=['POST'])
-def translate():
-    input_text = request.json['text']
-
-    # Tokenize input text
-    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)
-
-    # Perform translation
-    with torch.no_grad():
-        translated = model.generate(**inputs)
-
-    # Decode translated text
-    translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
-
-    return jsonify({'translation': translated_text})
-
-if __name__ == '__main__':
-    app.run(debug=True)
+
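Note on the removed version: it called torch.no_grad() without ever importing torch, so the /translate endpoint would have raised a NameError on its first request. A minimal sketch of a corrected inference helper, stripped of the Flask routing (translate_text is an illustrative name, not part of the commit):

    import torch
    from transformers import MarianMTModel, MarianTokenizer

    model_name = 'Helsinki-NLP/opus-mt-en-fr'
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    def translate_text(text):
        # Tokenize, generate, decode; no_grad() skips autograd bookkeeping at inference
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            generated = model.generate(**inputs)
        return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]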
+# Define dataset class
+class TranslationDataset(Dataset):
+    def __init__(self, source_sentences, target_sentences, tokenizer):
+        self.source_sentences = source_sentences
+        self.target_sentences = target_sentences
+        self.tokenizer = tokenizer
+
+    def __len__(self):
+        return len(self.source_sentences)
+
+    def __getitem__(self, idx):
+        source_text = self.source_sentences[idx]
+        target_text = self.target_sentences[idx]
+        source_tokens = self.tokenizer(source_text, return_tensors='pt', padding=True, truncation=True)
+        target_tokens = self.tokenizer(target_text, return_tensors='pt', padding=True, truncation=True)
+        return {'input_ids': source_tokens['input_ids'], 'labels': target_tokens['input_ids']}
+
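Two caveats about this dataset as committed. First, padding=True pads within each single-sentence call, so it is a no-op here: every item comes back as a (1, seq_len) tensor of a different length, and the default DataLoader collate cannot stack them into a batch. Second, MarianTokenizer should tokenize the target side as target text so the target-language SentencePiece rules apply. A sketch of a __getitem__ that batches cleanly, assuming a transformers version with the text_target argument (older versions used the tokenizer.as_target_tokenizer() context manager) and an arbitrary max_length of 128:

    def __getitem__(self, idx):
        # Pad every example to the same length so default collation can stack them
        source = self.tokenizer(self.source_sentences[idx], return_tensors='pt',
                                padding='max_length', max_length=128, truncation=True)
        target = self.tokenizer(text_target=self.target_sentences[idx], return_tensors='pt',
                                padding='max_length', max_length=128, truncation=True)
        labels = target['input_ids'].squeeze(0)
        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore pad positions in the loss
        return {'input_ids': source['input_ids'].squeeze(0),
                'attention_mask': source['attention_mask'].squeeze(0),
                'labels': labels}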
+# Define training function
+def train(model, dataloader, optimizer, criterion, num_epochs):
+    model.train()
+    for epoch in range(num_epochs):
+        total_loss = 0.0
+        for batch in dataloader:
+            input_ids = batch['input_ids'].to(device)
+            labels = batch['labels'].to(device)
+            optimizer.zero_grad()
+            outputs = model(input_ids=input_ids, labels=labels)
+            loss = outputs.loss
+            loss.backward()
+            optimizer.step()
+            total_loss += loss.item()
+        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')
+
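Two things stand out in this loop: criterion is accepted but never used, because model(input_ids=..., labels=...) computes its own cross-entropy loss internally, and device is a free variable that must already exist at call time. A sketch of a tightened loop that drops the unused argument, takes device explicitly, and forwards the attention mask (assuming the dataset returns one, as in the sketch above):

    def train(model, dataloader, optimizer, device, num_epochs):
        model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            for batch in dataloader:
                optimizer.zero_grad()
                outputs = model(input_ids=batch['input_ids'].to(device),
                                attention_mask=batch['attention_mask'].to(device),
                                labels=batch['labels'].to(device))
                outputs.loss.backward()  # the model supplies the loss itself
                optimizer.step()
                total_loss += outputs.loss.item()
            print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader):.4f}')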
+# Load tokenizer and model
+tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-fr')
+model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-fr').to(device)
+
+# Prepare dataset and dataloader
+dataset = TranslationDataset(source_sentences, target_sentences, tokenizer)
+dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
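As committed, the script stops here with a NameError: device, source_sentences, and target_sentences are referenced but never defined. A minimal sketch of the missing definitions; the two sentence lists are placeholders standing in for a real English-French parallel corpus:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Placeholder parallel data; substitute a real corpus before training
    source_sentences = ["Hello, how are you?", "The weather is nice today."]
    target_sentences = ["Bonjour, comment allez-vous ?", "Il fait beau aujourd'hui."]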
+# Define optimizer and criterion
+optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
+criterion = nn.CrossEntropyLoss()
+
+# Train the model
+train(model, dataloader, optimizer, criterion, num_epochs=10)
+
+# Save the trained model
+torch.save(model.state_dict(), 'translation_model.pth')
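One last note: torch.save(model.state_dict(), ...) stores raw weights only, so reloading requires instantiating the architecture first. For transformers models it is more idiomatic to use save_pretrained, which writes the config next to the weights so the checkpoint round-trips through from_pretrained. A sketch (the directory name is arbitrary):

    # Save weights and config together so the model round-trips cleanly
    model.save_pretrained('translation_model')
    tokenizer.save_pretrained('translation_model')

    # Reload later with:
    # model = MarianMTModel.from_pretrained('translation_model')
    # tokenizer = MarianTokenizer.from_pretrained('translation_model')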