PoseyATX commited on
Commit
dfdcf80
·
1 Parent(s): 99bed5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -2,6 +2,19 @@ import sys
2
  import numpy as np
3
  import gradio as gr
4
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def greet(payload):
 
2
  import numpy as np
3
  import gradio as gr
4
  import requests
5
+ from accelerate import Accelerator
6
+
7
+ accelerator = Accelerator(gradient_accumulation_steps=2)
8
+ dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
9
+
10
+ with accelerator.accumulate():
11
+ for input, output in dataloader:
12
+ outputs = model(input)
13
+ loss = loss_func(outputs)
14
+ loss.backward()
15
+ optimizer.step()
16
+ scheduler.step()
17
+ optimizer.zero_grad()
18
 
19
 
20
  def greet(payload):