File size: 5,159 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
import shutil
import json
from . import code_generator
from . import search_space_generator
from . import specific_code_generator
__all__ = ['generate_search_space', 'expand_annotations']
# Path separator used to turn directory paths into dotted package names:
# a backslash on Windows, a forward slash everywhere else.
slash = '\\' if sys.platform == "win32" else '/'
def generate_search_space(code_dir):
    """Generate search space from Python source code.

    Walk ``code_dir`` recursively and merge the search space produced for
    every ``*.py`` file found. Each such file is handed to
    ``_generate_file_search_space``, which rewrites the file in place.

    Return a serializable search space object.
    code_dir: directory path of source files (str)
    """
    # Strip *all* trailing separators: with only one removed, an input such
    # as "dir//" made the sub-directory prefix computation below fail.
    separators = os.sep + (os.altsep or '')
    code_dir = str(code_dir).rstrip(separators)
    search_space = {}

    for subdir, _, files in os.walk(code_dir):
        # generate module name from path: "<code_dir>/a/b" -> "a.b."
        relative = subdir[len(code_dir):].lstrip(separators)
        package = (relative.replace(os.sep, '.') + '.') if relative else ''

        for file_name in files:
            if file_name.endswith('.py'):
                path = os.path.join(subdir, file_name)
                module = package + file_name[:-3]
                search_space.update(_generate_file_search_space(path, module))

    return search_space
def _generate_file_search_space(path, module):
    """Generate the search space of one source file and rewrite the file.

    The file content is passed to ``search_space_generator.generate`` together
    with ``module``; the code it returns is written back to ``path`` and the
    search space it returns is handed to the caller.

    path: path of the Python source file (str)
    module: dotted module name of the file (str)
    Raise RuntimeError (chained to the original exception) if generation fails.
    """
    with open(path) as src:
        try:
            search_space, code = search_space_generator.generate(module, src.read())
        except Exception as exc:  # pylint: disable=broad-except
            if exc.args:
                # Exception args are not guaranteed to be strings (SyntaxError
                # and OSError carry tuples/ints), so convert before joining.
                message = path + ' ' + '\n'.join(str(arg) for arg in exc.args)
            else:
                message = 'Failed to generate search space for %s: %r' % (path, exc)
            raise RuntimeError(message) from exc
    # The source handle is closed before the same path is reopened for writing.
    with open(path, 'w') as dst:
        dst.write(code)
    return search_space
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
    """Expand annotations in user code.

    Mirror the directory tree of ``src_dir`` into ``dst_dir``. Files that do
    not end in ``.py`` are copied unchanged. Each ``.py`` file is processed by
    ``_expand_file_annotations`` when ``trial_id`` is empty, otherwise by
    ``_generate_specific_file`` using ``exp_id`` and ``trial_id``.

    Return dst_dir if annotation detected; return src_dir if not. Both are
    returned without trailing path separators. Files are written to dst_dir
    in either case.
    src_dir: directory path of user code (str)
    dst_dir: directory to place generated files (str)
    exp_id: experiment id, only used when trial_id is given (str)
    trial_id: trial id; empty string selects plain annotation expansion (str)
    nas_mode: the mode of NAS given that NAS interface is used
    """
    # rstrip instead of indexing [-1]: safe on empty strings and also removes
    # repeated trailing separators, which previously broke the prefix logic.
    separators = os.sep + (os.altsep or '')
    src_dir = str(src_dir).rstrip(separators)
    dst_dir = str(dst_dir).rstrip(separators)

    annotated = False

    for src_subdir, dirs, files in os.walk(src_dir):
        # os.walk always yields paths that start with the top directory, so
        # the remainder (empty, or separator + relative path) can be re-rooted.
        relative = src_subdir[len(src_dir):]
        dst_subdir = dst_dir + relative
        os.makedirs(dst_subdir, exist_ok=True)

        # generate module name from path: "<src_dir>/a/b" -> "a.b."
        module_prefix = relative.lstrip(separators)
        package = (module_prefix.replace(os.sep, '.') + '.') if module_prefix else ''

        for file_name in files:
            src_path = os.path.join(src_subdir, file_name)
            dst_path = os.path.join(dst_subdir, file_name)
            if not file_name.endswith('.py'):
                shutil.copyfile(src_path, dst_path)
            elif trial_id == '':
                annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
            else:
                module = package + file_name[:-3]
                annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)

        # Also create entries listed in ``dirs`` right away, so directories
        # os.walk does not descend into still exist in the destination.
        for dir_name in dirs:
            os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)

    return dst_dir if annotated else src_dir
def _expand_file_annotations(src_path, dst_path, nas_mode):
    """Expand the annotations of one source file into ``dst_path``.

    The source is read and passed to ``code_generator.parse`` before the
    destination is touched, so a failure never leaves a truncated destination
    file behind. When the parser returns None the source file is copied
    verbatim.

    src_path: path of the user source file (str)
    dst_path: path of the file to generate (str)
    nas_mode: forwarded unchanged to ``code_generator.parse``
    Return True if annotated code was written, False if the file was copied.
    Raise RuntimeError (chained to the original exception) on any failure
    after the source file has been opened.
    """
    def failure(exc):
        # Build the RuntimeError reported for any processing error.
        if exc.args:
            return RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
        return RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))

    with open(src_path) as src:
        try:
            annotated_code = code_generator.parse(src.read(), nas_mode)
        except Exception as exc:  # pylint: disable=broad-except
            raise failure(exc) from exc

    # Only now create/overwrite the destination; the source handle is closed.
    try:
        if annotated_code is None:
            shutil.copyfile(src_path, dst_path)
            return False
        with open(dst_path, 'w') as dst:
            dst.write(annotated_code)
        return True
    except Exception as exc:  # pylint: disable=broad-except
        raise failure(exc) from exc
def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
    """Generate the trial-specific version of one source file.

    Load ``~/nni-experiments/<exp_id>/trials/<trial_id>/parameter.cfg`` as
    JSON and pass its "parameters" entry, the source text and ``module`` to
    ``specific_code_generator.parse``. The destination is written only after
    parsing succeeded, so a failure never leaves a truncated destination file
    behind. When the parser returns None the source file is copied verbatim.

    src_path: path of the user source file (str)
    dst_path: path of the file to generate (str)
    exp_id: experiment id used to locate parameter.cfg (str)
    trial_id: trial id used to locate parameter.cfg (str)
    module: dotted module name of the file (str)
    Return True if generated code was written, False if the file was copied.
    Raise RuntimeError (chained to the original exception) on any failure
    after the source file has been opened.
    """
    def failure(exc):
        # Build the RuntimeError reported for any processing error.
        if exc.args:
            return RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
        return RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))

    with open(src_path) as src:
        try:
            with open(os.path.expanduser('~/nni-experiments/%s/trials/%s/parameter.cfg'%(exp_id, trial_id))) as fd:
                para_cfg = json.load(fd)
            annotated_code = specific_code_generator.parse(src.read(), para_cfg["parameters"], module)
        except Exception as exc:  # pylint: disable=broad-except
            raise failure(exc) from exc

    # Only now create/overwrite the destination; the source handle is closed.
    try:
        if annotated_code is None:
            shutil.copyfile(src_path, dst_path)
            return False
        with open(dst_path, 'w') as dst:
            dst.write(annotated_code)
        return True
    except Exception as exc:  # pylint: disable=broad-except
        raise failure(exc) from exc
|