erasmopurif's picture
First commit
d2a8669
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Predict classes for logistic regression model
SYNOPSIS::
SCRIPT [options]
Description
===========
Columns of Outputs:
1. true sample class number
2. predicted class number
3. sensitive feature
4. class 0 probability
5. class 1 probability
Columns are delimited by a single space. Unless --hideinfo is given,
meta information lines of the form "#key=value" follow the predictions.
Options
=======
-i <INPUT>, --in <INPUT>
specify <INPUT> file name (default: standard input)
-o <OUTPUT>, --out <OUTPUT>
specify <OUTPUT> file name (default: standard output)
-m <MODEL>, --model <MODEL>
file containing the trained classifier (required; there is no default)
--ns
ignore sensitive features
--hideinfo
suppress output meta information
-q, --quiet
set logging level to ERROR, no messages unless errors
--rseed <RSEED>
random number seed. if None, use /dev/urandom (default None)
Attributes
==========
N_NS : int
the number of sensitive feature columns; they are located just before
the final (class) column of the input data
"""
#==============================================================================
# Module metadata variables
#==============================================================================
__author__ = "Toshihiro Kamishima ( http://www.kamishima.net/ )"
__date__ = "2012/08/26"
__version__ = "3.0.0"
__copyright__ = "Copyright (c) 2012 Toshihiro Kamishima all rights reserved."
__license__ = "MIT License: http://www.opensource.org/licenses/mit-license.php"
__docformat__ = "restructuredtext en"
#==============================================================================
# Imports
#==============================================================================
import sys
import argparse
import os
import platform
from subprocess import getoutput
import logging
import datetime
import pickle
import numpy as np
# private modules -------------------------------------------------------------
from fadm.util import fill_missing_with_mean
#==============================================================================
# Public symbols
#==============================================================================
# Empty list: a wildcard import of this module exports no names.
__all__ = []
#==============================================================================
# Constants
#==============================================================================
# Number of sensitive-feature columns. main() reads them from the columns
# just before the final (class) column and drops them from the classifier
# input when the --ns option is set.
N_NS = 1
#==============================================================================
# Module variables
#==============================================================================
#==============================================================================
# Classes
#==============================================================================
#==============================================================================
# Functions
#==============================================================================
#==============================================================================
# Main routine
#==============================================================================
def _close_unless_standard_stream(stream):
    """ Close `stream` unless it is a standard stream of the interpreter

    argparse.FileType maps the file name "-" to sys.stdin / sys.stdout, or
    to their underlying binary buffers for binary modes; those must stay
    open.
    """
    for std in (sys.stdin, sys.stdout, sys.stderr):
        if stream is std or stream is getattr(std, 'buffer', None):
            return
    stream.close()


def main(opt):
    """ Predict classes, write the results, and exit with status code 0

    Two objects are unpickled from `opt.model`: the classifier and a
    dictionary of its meta information. The data matrix is read from
    `opt.infile`; its last column is the true class and the `N_NS` columns
    before it are the sensitive features. One line per sample is written
    to `opt.outfile`::

        <true class> <predicted class> <sensitive feature(s)> <p0> <p1>

    Only the first two columns of the predicted probabilities are written.
    If `opt.info` is true, the classifier meta information and every
    attribute of `opt` are appended as "#key=value" lines.

    Timing and accuracy statistics are stored as attributes of `opt` and
    logged through the module-level `logger`, which the script section of
    this file creates before calling this function.

    Parameters
    ----------
    opt : argparse.Namespace
        parsed command-line options; the attributes `model`, `infile`,
        `outfile`, `ns`, and `info` are read

    Raises
    ------
    SystemExit
        always raised with status code 0 after successful completion
    """
    try:
        ### pre process

        # load model file
        # NOTE(security): unpickling can execute arbitrary code; only load
        # model files that come from a trusted source.
        clr = pickle.load(opt.model)
        clr_info = pickle.load(opt.model)

        # read data; ndmin=2 keeps a file containing a single sample as a
        # 1 x n matrix so that the column slicing below still works
        D = np.loadtxt(opt.infile, ndmin=2)

        # split data and process missing values
        y = np.array(D[:, -1])
        # with --ns the sensitive columns are excluded from the input
        feature_end = -(1 + N_NS) if opt.ns else -1
        X = fill_missing_with_mean(D[:, :feature_end])
        S = D[:, -(1 + N_NS):-1]

        ### main process

        # set starting time
        start_time = datetime.datetime.now()
        start_utime = os.times()[0]
        opt.start_time = start_time.isoformat()
        logger.info("start time = %s", opt.start_time)

        # prediction and write results
        p = clr.predict_proba(X)

        # output prediction
        n = 0
        m = 0
        for true_class, sensitive, prob in zip(y, S, p):
            c = np.argmax(prob)
            opt.outfile.write(
                "%d %d %s %s %s\n" %
                (true_class, c, " ".join(sensitive.astype(str)),
                 str(prob[0]), str(prob[1])))
            n += 1
            m += int(c == true_class)

        # set end and elapsed time
        end_time = datetime.datetime.now()
        end_utime = os.times()[0]
        opt.end_time = end_time.isoformat()
        logger.info("end time = %s", opt.end_time)
        opt.elapsed_time = str(end_time - start_time)
        logger.info("elapsed_time = %s", opt.elapsed_time)
        opt.elapsed_utime = str(end_utime - start_utime)
        logger.info("elapsed_utime = %s", opt.elapsed_utime)

        ### output

        # add meta info
        opt.nos_samples = n
        logger.info('nos_samples = %s', opt.nos_samples)
        opt.nos_correct_samples = m
        logger.info('nos_correct_samples = %s', opt.nos_correct_samples)
        # accuracy is undefined when no sample was predicted
        opt.accuracy = m / float(n) if n else float('nan')
        logger.info('accuracy = %s', opt.accuracy)
        opt.negative_mean_prob = np.mean(p[:, 0])
        logger.info('negative_mean_prob = %s', opt.negative_mean_prob)
        opt.positive_mean_prob = np.mean(p[:, 1])
        logger.info('positive_mean_prob = %s', opt.positive_mean_prob)

        # output meta information
        if opt.info:
            for key, info_val in clr_info.items():
                opt.outfile.write("#classifier_%s=%s\n" %
                                  (key, str(info_val)))
            for key, key_val in vars(opt).items():
                opt.outfile.write("#%s=%s\n" % (key, str(key_val)))
    finally:
        ### post process

        # close files, also when an error interrupted the processing
        for stream in (opt.infile, opt.outfile, opt.model):
            _close_unless_standard_stream(stream)

    sys.exit(0)
### Preliminary processes before executing a main routine
if __name__ == '__main__':
    ### set script name
    # os.path.basename strips the directory part with every path separator
    # the running platform accepts, not only '/'
    script_name = os.path.basename(sys.argv[0])

    ### init logging system
    # main() emits its messages through this module-level logger
    logger = logging.getLogger(script_name)
    logging.basicConfig(level=logging.INFO,
                        format='[%(name)s: %(levelname)s'
                               ' @ %(asctime)s] %(message)s')

    ### command-line option parsing
    ap = argparse.ArgumentParser(
        description='pydoc is useful for learning the details.')

    # common options
    ap.add_argument('--version', action='version',
                    version='%(prog)s ' + __version__)

    apg = ap.add_mutually_exclusive_group()
    apg.set_defaults(verbose=True)
    apg.add_argument('--verbose', action='store_true')
    apg.add_argument('-q', '--quiet', action='store_false', dest='verbose')

    ap.add_argument("--rseed", type=int, default=None)

    # basic file i/o
    # each file can be given as an option or as a positional argument;
    # the option takes precedence (resolved after parsing)
    ap.add_argument('-i', '--in', dest='infile',
                    default=None, type=argparse.FileType('r'))
    ap.add_argument('infilep', nargs='?', metavar='INFILE',
                    default=sys.stdin, type=argparse.FileType('r'))
    ap.add_argument('-o', '--out', dest='outfile',
                    default=None, type=argparse.FileType('w'))
    ap.add_argument('outfilep', nargs='?', metavar='OUTFILE',
                    default=sys.stdout, type=argparse.FileType('w'))

    # script specific options
    ap.add_argument('-m', '--model', type=argparse.FileType('rb'),
                    required=True)
    ap.set_defaults(ns=False)
    ap.add_argument("--ns", dest="ns", action="store_true")
    ap.set_defaults(info=True)
    ap.add_argument('--hideinfo', dest='info', action='store_false')

    # parsing
    opt = ap.parse_args()

    # post-processing for command-line options
    # disable logging messages by changing logging level
    if not opt.verbose:
        logger.setLevel(logging.ERROR)

    # set random seed
    np.random.seed(opt.rseed)

    # basic file i/o
    # the positional entries are removed so that they do not show up in
    # the meta information that main() writes from vars(opt)
    if opt.infile is None:
        opt.infile = opt.infilep
    del vars(opt)['infilep']
    logger.info("input_file = " + opt.infile.name)
    if opt.outfile is None:
        opt.outfile = opt.outfilep
    del vars(opt)['outfilep']
    logger.info("output_file = " + opt.outfile.name)

    ### set meta-data of script and machine
    opt.script_name = script_name
    opt.script_version = __version__
    opt.python_version = platform.python_version()
    opt.sys_uname = platform.uname()
    # opt.sys_info is set only for the platforms listed below
    if platform.system() == 'Darwin':
        opt.sys_info = (
            getoutput('system_profiler'
                      ' -detailLevel mini SPHardwareDataType')
            .split('\n')[4:-1])
    elif platform.system() == 'FreeBSD':
        opt.sys_info = getoutput('sysctl hw').split('\n')
    elif platform.system() == 'Linux':
        opt.sys_info = getoutput('cat /proc/cpuinfo').split('\n')

    ### suppress warnings in numerical computation
    np.seterr(all='ignore')

    ### call main routine
    main(opt)