jyh-zero / app.py
hiascend's picture
Update app.py
376c53e verified
raw
history blame
950 Bytes
import gradio as gr
import spaces
import torch
# Make CUDA the global default device for tensors/modules created without an
# explicit ``device=`` argument.
# NOTE(review): every tensor/module in the code shown here already passes
# device='cuda' explicitly, so this global default is largely redundant for it.
torch.set_default_device(device='cuda')
class LinearModel(torch.nn.Module):
    """A single ``torch.nn.Linear(1, 1)`` layer used to probe device placement.

    Args:
        device: Device on which the layer's parameters are created. Defaults
            to ``'cuda'``, the value that was previously hard-coded, so
            existing ``LinearModel()`` calls behave exactly as before.
    """

    def __init__(self, device='cuda'):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1, device=device)

    def forward(self, x):
        """Apply the linear layer to ``x`` and print the weight's device.

        Args:
            x: Input tensor whose last dimension must be 1 (the layer's
                ``in_features``).

        Returns:
            The output of ``self.linear(x)``.
        """
        out = self.linear(x)
        # Diagnostic: report where the weight actually lives at call time.
        print(f'weight device: {self.linear.weight.device}')
        return out
def process():
    """Print CUDA diagnostics, run a tiny model on CUDA, and print the results.

    Prints CUDA availability and device count, builds a ``LinearModel``,
    moves it to CUDA, feeds it a one-element tensor of ones on CUDA, and then
    prints the model, the input, and the output, in that order. Returns
    ``None``.
    """
    # NOTE(review): "avaliable" is a typo in the printed label; kept verbatim
    # so the output is unchanged.
    print(f'cuda avaliable: {torch.cuda.is_available()}')
    print(f'cuda device count: {torch.cuda.device_count()}')

    net = LinearModel().cuda()
    ones = torch.ones(1, device='cuda')
    result = net(ones)

    for item in (net, ones, result):
        print(item)
@spaces.GPU
def greet(n):
    """Run ``process()`` inside the ``spaces.GPU``-decorated scope, then greet.

    Args:
        n: Value interpolated into the greeting. ``func`` forwards it from the
            ``gr.Number`` input.

    Returns:
        The string ``"Hello {n} Tensor"``.
    """
    process()
    greeting = f"Hello {n} Tensor"
    return greeting
def func(n):
    """Gradio handler: run the diagnostics before, during, and after ``greet``.

    ``process()`` is called once before and once after ``greet(n)``, and
    ``greet`` itself calls ``process()`` inside its ``@spaces.GPU``
    decoration. The printed diagnostics can therefore be compared across the
    three phases.

    Args:
        n: Value from the ``gr.Number`` input.

    Returns:
        Whatever ``greet(n)`` returns.
    """
    # Phase 1: outside the GPU-decorated call (the author labels it "on cpu").
    process()
    # Phase 2: inside the @spaces.GPU-decorated function ("on gpu").
    reply = greet(n)
    # Phase 3: back outside the decorated call ("on cpu").
    process()
    return reply
# UI: a single Number input, handled by ``func``, rendered into a Text output.
demo = gr.Interface(fn=func, inputs=gr.Number(), outputs=gr.Text())
demo.launch()