Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
""" | |
Predict classes for logistic regression model | |
SYNOPSIS:: | |
SCRIPT [options] | |
Description | |
=========== | |
Columns of Outputs: | |
1. true sample class number | |
2. predicted class number | |
3. sensitive feature | |
4. class 0 probability | |
5. class 1 probability | |
Columns are separated by a single space.
Options | |
======= | |
-i <INPUT>, --in <INPUT> | |
specify <INPUT> file name | |
-o <OUTPUT>, --out <OUTPUT> | |
specify <OUTPUT> file name | |
-m <MODEL>, --model <MODEL> | |
trained classifier (required)
--ns | |
ignore sensitive features | |
--hideinfo | |
suppress output meta information | |
-q, --quiet | |
set logging level to ERROR, no messages unless errors | |
--rseed <RSEED> | |
random number seed. if None, use /dev/urandom (default None) | |
Attributes | |
========== | |
N_NS : int | |
the number of sensitive features (the columns just before the class column)
""" | |
#==============================================================================
# Module metadata variables
#==============================================================================
# Provenance of this script. __version__ is also printed by ``--version`` and
# recorded as ``script_version`` in the output meta information.
__version__ = "3.0.0"
__date__ = "2012/08/26"
__author__ = "Toshihiro Kamishima ( http://www.kamishima.net/ )"
__copyright__ = "Copyright (c) 2012 Toshihiro Kamishima all rights reserved."
__license__ = "MIT License: http://www.opensource.org/licenses/mit-license.php"
# docstrings in this module are written in reStructuredText
__docformat__ = "restructuredtext en"
#============================================================================== | |
# Imports | |
#============================================================================== | |
import sys | |
import argparse | |
import os | |
import platform | |
from subprocess import getoutput | |
import logging | |
import datetime | |
import pickle | |
import numpy as np | |
# private modules --------------------------------------------------------------
from fadm.util import fill_missing_with_mean | |
#==============================================================================
# Public symbols
#==============================================================================
# nothing is exported by ``from ... import *``; this module is run as a script
__all__ = list()
#==============================================================================
# Constants
#==============================================================================
# Number of sensitive-feature columns. In the input data they sit directly
# before the last (class label) column; main() extracts them with
# D[:, -(1 + N_NS):-1] and drops them from the features when --ns is given.
N_NS = 1
#==============================================================================
# Module variables
#==============================================================================
#==============================================================================
# Classes
#==============================================================================
#==============================================================================
# Functions
#==============================================================================
#==============================================================================
# Main routine
#==============================================================================
def main(opt):
    """ Main routine that exits with status code 0

    Loads a pickled classifier, followed by its pickled meta-information
    dict, from ``opt.model``. It then predicts class probabilities for every
    row of ``opt.infile`` and writes one line per sample to ``opt.outfile``
    (see the module docstring for the column layout). Timing and accuracy
    statistics are stored back on ``opt``. When ``opt.info`` is true, they
    are appended as ``#key=value`` lines together with the classifier info.

    Parameters
    ----------
    opt : argparse.Namespace
        parsed command-line options (``model``, ``infile``, ``outfile``,
        ``ns``, ``info``); mutated in place with run statistics.

    Notes
    -----
    Uses the module-level ``logger`` created in the ``__main__`` section and
    always terminates the process with ``sys.exit(0)``.
    """
    ### pre process
    # load model file: the classifier and then its info dict were pickled
    # one after another into the same file.
    # NOTE(review): unpickling can execute arbitrary code; only load model
    # files from a trusted source.
    clr = pickle.load(opt.model)
    clr_info = pickle.load(opt.model)

    # read data; ndmin=2 keeps a single-sample file two-dimensional so the
    # column slicing below does not fail on a 1-D array
    D = np.loadtxt(opt.infile, ndmin=2)

    # split data and process missing values: the last column is the true
    # class and the N_NS columns before it are the sensitive features
    y = np.array(D[:, -1])
    if opt.ns:
        X = fill_missing_with_mean(D[:, :-(1 + N_NS)])
    else:
        X = fill_missing_with_mean(D[:, :-1])
    S = np.atleast_2d(D[:, -(1 + N_NS):-1])

    ### main process
    # set starting time
    start_time = datetime.datetime.now()
    start_utime = os.times()[0]
    opt.start_time = start_time.isoformat()
    logger.info("start time = " + start_time.isoformat())

    # prediction: the predicted class is the column with the largest
    # probability (the output format writes only columns 0 and 1)
    p = clr.predict_proba(X)
    c = np.argmax(p, axis=1)

    # output prediction, one line per sample
    for yi, ci, si, pi in zip(y, c, S, p):
        opt.outfile.write("%d %d " % (yi, ci))
        opt.outfile.write(" ".join(si.astype(str)) + " ")
        opt.outfile.write(str(pi[0]) + " " + str(pi[1]) + "\n")
    n = p.shape[0]
    m = int(np.sum(c == y))

    # set end and elapsed time
    end_time = datetime.datetime.now()
    end_utime = os.times()[0]
    logger.info("end time = " + end_time.isoformat())
    opt.end_time = end_time.isoformat()
    logger.info("elapsed_time = " + str((end_time - start_time)))
    opt.elapsed_time = str((end_time - start_time))
    logger.info("elapsed_utime = " + str((end_utime - start_utime)))
    opt.elapsed_utime = str((end_utime - start_utime))

    ### output
    # add meta info
    opt.nos_samples = n
    logger.info('nos_samples = ' + str(opt.nos_samples))
    opt.nos_correct_samples = m
    logger.info('nos_correct_samples = ' + str(opt.nos_correct_samples))
    # guard against empty input instead of raising ZeroDivisionError
    opt.accuracy = m / float(n) if n > 0 else float('nan')
    logger.info('accuracy = ' + str(opt.accuracy))
    opt.negative_mean_prob = np.mean(p[:, 0])
    logger.info('negative_mean_prob = ' + str(opt.negative_mean_prob))
    opt.positive_mean_prob = np.mean(p[:, 1])
    logger.info('positive_mean_prob = ' + str(opt.positive_mean_prob))

    # output meta information
    if opt.info:
        for key in clr_info.keys():
            opt.outfile.write("#classifier_%s=%s\n" %
                              (key, str(clr_info[key])))
        for key, key_val in vars(opt).items():
            opt.outfile.write("#%s=%s\n" % (key, str(key_val)))

    ### post process
    # close files, but never the process's standard streams
    if opt.infile != sys.stdin:
        opt.infile.close()
    if opt.outfile != sys.stdout:
        opt.outfile.close()
    # the model is an *input* opened in binary mode; for "-" argparse hands
    # back stdin's binary buffer, so that (not sys.stdout) must be spared
    if opt.model not in (sys.stdin, getattr(sys.stdin, 'buffer', None)):
        opt.model.close()
    sys.exit(0)
### Preliminary processes before executing a main routine
if __name__ == '__main__':
    ### set script name
    # basename is portable across platform path separators
    script_name = os.path.basename(sys.argv[0])

    ### init logging system
    # main() reads this module-level `logger`
    logger = logging.getLogger(script_name)
    logging.basicConfig(level=logging.INFO,
                        format='[%(name)s: %(levelname)s'
                               ' @ %(asctime)s] %(message)s')

    ### command-line option parsing
    ap = argparse.ArgumentParser(
        description='pydoc is useful for learning the details.')

    # common options
    ap.add_argument('--version', action='version',
                    version='%(prog)s ' + __version__)
    apg = ap.add_mutually_exclusive_group()
    apg.set_defaults(verbose=True)
    apg.add_argument('--verbose', action='store_true')
    apg.add_argument('-q', '--quiet', action='store_false', dest='verbose')
    ap.add_argument("--rseed", type=int, default=None)

    # basic file i/o: each stream may be given as an option or as a
    # positional argument; the option takes precedence
    ap.add_argument('-i', '--in', dest='infile',
                    default=None, type=argparse.FileType('r'))
    ap.add_argument('infilep', nargs='?', metavar='INFILE',
                    default=sys.stdin, type=argparse.FileType('r'))
    ap.add_argument('-o', '--out', dest='outfile',
                    default=None, type=argparse.FileType('w'))
    ap.add_argument('outfilep', nargs='?', metavar='OUTFILE',
                    default=sys.stdout, type=argparse.FileType('w'))

    # script specific options
    ap.add_argument('-m', '--model', type=argparse.FileType('rb'),
                    required=True)
    ap.set_defaults(ns=False)
    ap.add_argument("--ns", dest="ns", action="store_true")
    ap.set_defaults(info=True)
    ap.add_argument('--hideinfo', dest='info', action='store_false')

    # parsing
    opt = ap.parse_args()

    # post-processing for command-line options
    # disable logging messages by changing logging level
    if not opt.verbose:
        logger.setLevel(logging.ERROR)

    # set random seed
    np.random.seed(opt.rseed)

    # basic file i/o: if both the option and the positional form were given,
    # argparse has still opened the positional file; close it so the handle
    # is not leaked (standard streams are left alone)
    if opt.infile is None:
        opt.infile = opt.infilep
    elif opt.infilep is not sys.stdin:
        opt.infilep.close()
    del vars(opt)['infilep']
    logger.info("input_file = " + opt.infile.name)
    if opt.outfile is None:
        opt.outfile = opt.outfilep
    elif opt.outfilep is not sys.stdout:
        opt.outfilep.close()
    del vars(opt)['outfilep']
    logger.info("output_file = " + opt.outfile.name)

    ### set meta-data of script and machine
    opt.script_name = script_name
    opt.script_version = __version__
    opt.python_version = platform.python_version()
    opt.sys_uname = platform.uname()
    if platform.system() == 'Darwin':
        opt.sys_info = \
            getoutput('system_profiler'
                      ' -detailLevel mini SPHardwareDataType')\
            .split('\n')[4:-1]
    elif platform.system() == 'FreeBSD':
        opt.sys_info = getoutput('sysctl hw').split('\n')
    elif platform.system() == 'Linux':
        opt.sys_info = getoutput('cat /proc/cpuinfo').split('\n')

    ### suppress warnings in numerical computation
    np.seterr(all='ignore')

    ### call main routine
    main(opt)