Jaane commited on
Commit
4d40b03
·
verified ·
1 Parent(s): 04cb90a

gpu update

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -10,16 +10,16 @@ warnings.filterwarnings("ignore")
10
  os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
11
 
12
  # Set random state for reproducibility
13
- def random_state(seed):
14
- torch.manual_seed(seed)
15
- if torch.cuda.is_available():
16
- torch.cuda.manual_seed_all(seed)
 
17
 
18
- random_state(1234)
19
 
20
  # Initialize Parrot model
21
  parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")
22
-
23
  # Function to paraphrase a single sentence
24
  def paraphrase_sentence(sentence):
25
  paraphrases = parrot.augment(input_phrase=sentence, max_return_phrases=10, max_length=100, adequacy_threshold=0.75, fluency_threshold=0.75)
 
10
  os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
11
 
12
  # Set random state for reproducibility
13
+ torch.manual_seed(1234)
14
+ if torch.cuda.is_available():
15
+ torch.cuda.manual_seed_all(1234)
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ print(f"Using device: {device}")
18
 
 
19
 
20
  # Initialize Parrot model
21
  parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")
22
+ parrot.model.to(device)
23
  # Function to paraphrase a single sentence
24
  def paraphrase_sentence(sentence):
25
  paraphrases = parrot.augment(input_phrase=sentence, max_return_phrases=10, max_length=100, adequacy_threshold=0.75, fluency_threshold=0.75)