gputest / app.py
feng2022's picture
Update app.py
02422ea
raw
history blame
1.05 kB
import numpy as np
import argparse
import functools
import os
import pickle
import sys
from datasets import Dataset
import gradio as gr
from pynvml import *
from transformers import pipeline
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
def predict(text):
return pipe(text)[0]["translation_text"]
def print_gpu_utilization():
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(0)
info = nvmlDeviceGetMemoryInfo(handle)
print(f"GPU memory occupied: {info.used//1024**2} MB.")
def print_summary(result):
print(f"Time: {result.metrics['train_runtime']:.2f}")
print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
print_gpu_utilization()
seq_len, dataset_size = 512, 512
dummy_data = {
"input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),
"labels": np.random.randint(0, 1, (dataset_size)),
}
ds = Dataset.from_dict(dummy_data)
ds.set_format("pt")
iface = gr.Interface(
fn=predict,
inputs='text',
outputs='text',
examples=[[f'result']]
)
iface.launch()