Source code for sciunit.utils
"""
Utility functions for SciUnit.
"""
from __future__ import print_function
import os
import sys
import warnings
import inspect
import pkgutil
import importlib
import json
import re
import contextlib
import traceback
from io import TextIOWrapper, StringIO
from datetime import datetime
try:
from tempfile import TemporaryDirectory
except ImportError:
from tempfile import TemporaryDirectory
import bs4
import nbformat
import nbconvert
from nbconvert.preprocessors import ExecutePreprocessor
try:
from nbconvert.preprocessors.execute import CellExecutionError
except:
from nbconvert.preprocessors import CellExecutionError
from quantities.dimensionality import Dimensionality
from quantities.quantity import Quantity
import cypy
from IPython.display import HTML, display
import sciunit
from sciunit.errors import Error
from .base import SciUnit, FileNotFoundError, tkinter
from .base import PLATFORM, PYTHON_MAJOR_VERSION
# Record whether ``unittest.mock`` can be imported.
try:
    import unittest.mock
except ImportError:
    mock = False
else:
    mock = True
# The flag is then forced off regardless of the import result:
# mock is probably obviated by the unittest -b flag.
mock = False

# Module-wide settings read by the helpers defined in this file.
settings = dict(
    PRINT_DEBUG_STATE=False,  # printd does nothing by default.
    LOGGING=True,
    PREVALIDATE=False,
    # True when ipykernel has already been imported into this process.
    KERNEL=('ipykernel' in sys.modules),
    # Resolved location of the installed sciunit package.
    CWD=os.path.realpath(sciunit.__path__[0]),
)
def warn_with_traceback(message, category, filename, lineno,
                        file=None, line=None):
    """Write the current call stack followed by the formatted warning.

    The signature matches `warnings.showwarning`, so this function can be
    installed in its place.  Output goes to `file` when that object has a
    ``write`` attribute and to ``sys.stderr`` otherwise.
    """
    stream = sys.stderr
    if hasattr(file, 'write'):
        stream = file
    traceback.print_stack(file=stream)
    formatted = warnings.formatwarning(message, category, filename,
                                       lineno, line)
    stream.write(formatted)
def set_warnings_traceback(tb=True):
    """Turn tracebacks for warnings on or off.

    With `tb` true, `warnings.showwarning` is replaced by
    `warn_with_traceback` and every warning is always shown.  With `tb`
    false, the previously saved hook is put back and the "default" filter
    is installed.

    The replaced hook is kept in `warnings._showwarning`.
    """
    if tb:
        # Save the hook only if it is not already ours; otherwise a second
        # call would overwrite the saved original with warn_with_traceback
        # and it could never be restored.
        if warnings.showwarning is not warn_with_traceback:
            warnings._showwarning = warnings.showwarning
        warnings.showwarning = warn_with_traceback
        warnings.simplefilter("always")
    else:
        # Nothing has been saved if tracebacks were never enabled; in that
        # case keep the current hook instead of raising AttributeError.
        warnings.showwarning = getattr(warnings, '_showwarning',
                                       warnings.showwarning)
        warnings.simplefilter("default")
def dict_combine(*dict_list):
    """Merge any number of dictionaries into a new one.

    When the same key occurs in more than one dictionary, the value from
    the dictionary that appears later in the argument list wins.

    In Python 3 this can simply be {**d1, **d2, ...}
    but Python 2 does not support this dict unpacking syntax.
    """
    combined = {}
    for mapping in dict_list:
        for key, value in mapping.items():
            combined[key] = value
    return combined
def rec_apply(func, n):
    """Return a callable that applies `func` to its argument `n` times.

    Used to determine a parent directory n levels up by repeatedly
    applying os.path.dirname.  For `n` of 1 or less, `func` itself is
    returned.
    """
    if not n > 1:
        return func
    inner = rec_apply(func, n - 1)

    def composed(x):
        return func(inner(x))

    return composed
def printd_set(state):
    """Switch the printd function on or off.

    Pass True to have all subsequent printd calls forwarded to print.
    Any other value (including other truthy values) makes subsequent
    printd calls do nothing.
    """
    global settings
    enabled = state is True
    settings['PRINT_DEBUG_STATE'] = enabled
def printd(*args, **kwargs):
    """Print the arguments only when PRINT_DEBUG_STATE is set.

    Returns True if something was printed and False otherwise.
    """
    global settings
    if not settings['PRINT_DEBUG_STATE']:
        return False
    print(*args, **kwargs)
    return True
# Context manager that temporarily replaces sys.stdout with `target`.
if PYTHON_MAJOR_VERSION == 3:
    redirect_stdout = contextlib.redirect_stdout
else:  # Python 2
    @contextlib.contextmanager
    def redirect_stdout(target):
        """Fallback for interpreters without contextlib.redirect_stdout.

        Yields `target`, as the standard-library version does, and always
        puts the previous sys.stdout back, even if the body raises.
        """
        original = sys.stdout
        sys.stdout = target
        try:
            yield target
        finally:
            sys.stdout = original
def assert_dimensionless(value):
    """Check that the input is dimensionless.

    If the input is a Quantity that simplifies to a dimensionless value,
    the bare value is returned.  If it is a Quantity with units left after
    simplification, a TypeError is raised.  Anything that is not a Quantity
    is returned unchanged.
    """
    if not isinstance(value, Quantity):
        return value
    simplified = value.simplified
    is_bare = simplified.dimensionality == Dimensionality({})
    if not is_bare:
        raise TypeError("Score value %s must be dimensionless" % simplified)
    return simplified.base.item()
class NotebookTools(object):
    """A class for manipulating and executing Jupyter notebooks.

    Some methods call ``self.assertTrue``, which is not defined here, so
    this class is presumably meant to be combined with
    ``unittest.TestCase`` or another class that provides it.
    """

    # Relative path to the parent directory of the notebook.
    path = ''
    # Name of directory where files generated by do_notebook are stored
    gen_dir_name = 'GeneratedFiles'
    # Number of levels up from notebook directory
    # where generated files are stored
    gen_file_level = 2

    def __init__(self, *args, **kwargs):
        """Initialize the parent class(es) and select a usable
        matplotlib backend (see fix_display)."""
        super(NotebookTools, self).__init__(*args, **kwargs)
        self.fix_display()

    @classmethod
    def convert_path(cls, file):
        """
        Check to see if an extended path is given and convert appropriately.

        A string is returned unchanged and a list of strings is joined
        with "/".  For anything else a message is printed and -1 is
        returned.
        """
        if isinstance(file, str):
            return file
        if isinstance(file, list) and all(isinstance(x, str) for x in file):
            return "/".join(file)
        print("Incorrect path specified")
        return -1

    def get_path(self, file):
        """Get the full path of the notebook found in the directory
        specified by self.path.

        The path is resolved relative to the directory of the source file
        that defines this instance's class.
        """
        class_path = inspect.getfile(self.__class__)
        parent_path = os.path.dirname(class_path)
        path = os.path.join(parent_path, self.path, file)
        return os.path.realpath(path)

    def fix_display(self):
        """If this is being run on a headless system the Matplotlib
        backend must be changed to one that doesn't need a display.
        """
        try:
            tkinter.Tk()
        except (tkinter.TclError, NameError):  # If there is no display.
            try:
                import matplotlib as mpl
            except ImportError:
                pass
            else:
                printd("Setting matplotlib backend to Agg")
                mpl.use('Agg')

    def load_notebook(self, name):
        """Loads a notebook file into memory.

        Returns the parsed notebook and the file object it was read from.
        The file object is already closed when it is returned; it is
        passed on so that callers can use its ``name``.
        """
        with open(self.get_path('%s.ipynb' % name)) as f:
            nb = nbformat.read(f, as_version=4)
        return nb, f

    def run_notebook(self, nb, f):
        """Runs a loaded notebook file.

        `nb` is the notebook to execute.  `f` identifies where the
        executed notebook is written afterwards: either a file object (as
        returned by load_notebook) or a path.  The notebook is written
        whether or not execution succeeds; a CellExecutionError is
        reported and re-raised.
        """
        if PYTHON_MAJOR_VERSION == 3:
            kernel_name = 'python3'
        elif PYTHON_MAJOR_VERSION == 2:
            kernel_name = 'python2'
        else:
            raise Exception('Only Python 2 and 3 are supported')
        notebook_name = getattr(f, 'name', f)
        ep = ExecutePreprocessor(timeout=600, kernel_name=kernel_name)
        try:
            ep.preprocess(nb, {'metadata': {'path': '.'}})
        except CellExecutionError:
            msg = 'Error executing the notebook "%s".\n\n' % notebook_name
            msg += 'See notebook "%s" for the traceback.' % notebook_name
            print(msg)
            raise
        finally:
            if hasattr(f, 'write') and not getattr(f, 'closed', False):
                # Caller supplied a handle that is still open.
                nbformat.write(nb, f)
            else:
                # load_notebook returns a handle that is already closed,
                # and writing to it raises ValueError; reopen the file by
                # name instead.
                with open(notebook_name, 'w') as out:
                    nbformat.write(nb, out)

    def execute_notebook(self, name):
        """Loads and then runs a notebook file."""
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        nb, f = self.load_notebook(name)
        self.run_notebook(nb, f)
        self.assertTrue(True)

    def convert_notebook(self, name):
        """Converts a notebook into a python file."""
        exporter = nbconvert.exporters.python.PythonExporter()
        relative_path = self.convert_path(name)
        file_path = self.get_path("%s.ipynb" % relative_path)
        code = exporter.from_filename(file_path)[0]
        self.write_code(name, code)
        self.clean_code(name, [])

    def convert_and_execute_notebook(self, name):
        """Converts a notebook into a python file and then runs it.

        The generated code is executed with this module's globals.
        """
        self.convert_notebook(name)
        code = self.read_code(name)
        exec(code, globals())

    def gen_file_path(self, name):
        """
        Returns full path to generated files.  Checks to see if directory
        exists where generated files are stored and creates one otherwise.
        """
        relative_path = self.convert_path(name)
        file_path = self.get_path("%s.ipynb" % relative_path)
        parent_path = rec_apply(os.path.dirname,
                                self.gen_file_level)(file_path)
        # Name of generated file
        gen_file_name = name if isinstance(name, str) else name[1]
        gen_dir_path = self.get_path(os.path.join(parent_path,
                                                  self.gen_dir_name))
        # Create folder for generated files if needed
        if not os.path.exists(gen_dir_path):
            os.makedirs(gen_dir_path)
        new_file_path = self.get_path('%s.py' % os.path.join(gen_dir_path,
                                                             gen_file_name))
        return new_file_path

    def read_code(self, name):
        """Reads code from a python file called 'name'"""
        file_path = self.gen_file_path(name)
        with open(file_path) as f:
            code = f.read()
        return code

    def write_code(self, name, code):
        """
        Writes code to a python file called 'name',
        erasing the previous contents.
        Files are created in a directory specified by gen_dir_name
        (see function gen_file_path)
        File name is second argument of path
        """
        file_path = self.gen_file_path(name)
        with open(file_path, 'w') as f:
            f.write(code)

    def clean_code(self, name, forbidden):
        """
        Remove lines containing items in 'forbidden' from the code.
        Helpful for executing converted notebooks that still retain IPython
        magic commands.

        The cleaned code is written back to the generated file and also
        returned.
        """
        # Magics where we want to keep the command
        allowed = ['time', 'timeit']
        new_code = []
        for line in self.read_code(name).split('\n'):
            if any(bad in line for bad in forbidden):
                continue
            line = self.strip_line_magic(line, allowed)
            if isinstance(line, list):
                # The Python 2 implementation returns the words as a list.
                line = ' '.join(line)
            new_code.append(line)
        new_code = '\n'.join(new_code)
        self.write_code(name, new_code)
        return new_code

    @classmethod
    def strip_line_magic(cls, line, magics_allowed):
        """Handles lines that contain get_ipython.run_line_magic() commands

        Returns the line with the magic call replaced by the command that
        followed the magic, or an empty string when the magic is not in
        `magics_allowed`.  Lines without a magic are returned unchanged.
        """
        if PYTHON_MAJOR_VERSION == 2:  # Python 2
            stripped, magic_kind = cls.strip_line_magic_v2(line)
        else:  # Python 3+
            stripped, magic_kind = cls.strip_line_magic_v3(line)
        if line == stripped:
            printd("No line magic pattern match in '%s'" % line)
        if magic_kind and magic_kind not in magics_allowed:
            # If the part after the magic won't work, just get rid of it
            stripped = ""
        return stripped

    @classmethod
    def strip_line_magic_v3(cls, line):
        """strip_line_magic() implementation for Python 3

        Returns a ``(stripped, magic_kind)`` pair; `magic_kind` is an empty
        string when the line contains no magic call.
        """
        matches = re.findall(r"run_line_magic\(([^]]+)", line)
        if matches and matches[0]:  # This line contains the pattern
            match = matches[0]
            if match[-1] == ')':
                match = match[:-1]  # Just because the re way is hard
            # NOTE(review): eval executes text taken from the converted
            # notebook; only run this on notebooks you trust.
            magic_kind, stripped = eval(match)
        else:
            stripped = line
            magic_kind = ""
        return stripped, magic_kind

    @classmethod
    def strip_line_magic_v2(cls, line):
        """strip_line_magic() implementation for Python 2

        Returns a ``(stripped, magic_kind)`` pair.  When the magic is
        followed by a command, `stripped` is the list of its
        space-separated words; `magic_kind` is an empty string when the
        line contains no magic call.
        """
        matches = re.findall(r"magic\(([^]]+)", line)
        if matches and matches[0]:  # This line contains the pattern
            match = matches[0]
            if match[-1] == ')':
                match = match[:-1]  # Just because the re way is hard
            # NOTE(review): eval executes text taken from the converted
            # notebook; only run this on notebooks you trust.
            stripped = eval(match)
            words = stripped.split(' ')
            magic_kind = words[0]
            stripped = words[1:] if len(words) > 1 else ""
        else:
            stripped = line
            magic_kind = ""
        return stripped, magic_kind

    def do_notebook(self, name):
        """Run a notebook file after optionally
        converting it to a python file.

        Conversion is controlled by the CONVERT_NOTEBOOKS environment
        variable, which is read as an integer and defaults to enabled.
        """
        CONVERT_NOTEBOOKS = int(os.getenv('CONVERT_NOTEBOOKS', True))
        if mock:
            s = StringIO()
            # Patchers only take effect while started; using them as
            # context managers starts them and guarantees they are undone.
            with unittest.mock.patch('sys.stdout', new=MockDevice(s)), \
                    unittest.mock.patch('sys.stderr', new=MockDevice(s)):
                self._do_notebook(name, CONVERT_NOTEBOOKS)
        else:
            self._do_notebook(name, CONVERT_NOTEBOOKS)
        self.assertTrue(True)

    def _do_notebook(self, name, convert_notebooks=False):
        """Called by do_notebook to actually run the notebook."""
        if convert_notebooks:
            self.convert_and_execute_notebook(name)
        else:
            self.execute_notebook(name)
class MockDevice(TextIOWrapper):
    """A stand-in output device used to temporarily silence stdout.

    Similar to UNIX /dev/null, except that text wrapped in square
    brackets is still passed through.
    """

    def write(self, s):
        """Forward `s` to the underlying stream only if it begins with
        '[' and ends with ']'; drop everything else."""
        bracketed = s.startswith('[') and s.endswith(']')
        if not bracketed:
            return
        super(MockDevice, self).write(s)
def import_all_modules(package, skip=None, verbose=False, prefix="", depth=0):
    """Recursively imports all subpackages, modules, and submodules of a
    given package.

    'package' should be an imported package, not a string.
    'skip' is a list of modules or subpackages not to import.
    'verbose' prints each name found, indented by recursion depth.
    'prefix' is prepended to every name that is found.
    'depth' is the current recursion level, used for the indentation.
    """
    skip = [] if skip is None else skip
    indent = '\t' * depth
    # iter_modules lists only the direct children of the package; the
    # recursion below descends into subpackages with their qualified names.
    # pkgutil.walk_packages would additionally import every subpackage by
    # the bare name it yields, which can pull in an unrelated top-level
    # module that happens to share that name.
    for _, modname, ispkg in pkgutil.iter_modules(package.__path__, prefix):
        if verbose:
            print(indent, modname)
        if modname in skip:
            if verbose:
                print(indent, '*Skipping*')
            continue
        module = '%s.%s' % (package.__name__, modname)
        subpackage = importlib.import_module(module)
        if ispkg:
            import_all_modules(subpackage, skip=skip,
                               verbose=verbose, depth=depth+1)
def import_module_from_path(module_path, name=None):
    """Import the Python source file at `module_path` and return the module.

    `name` is the name given to the module.  When omitted, the file name
    without its ``.py`` suffix is used, or the name of the containing
    directory if the file is an ``__init__.py``.

    If loading through SourceFileLoader raises ImportError, the file's
    directory is temporarily appended to sys.path and the module is
    imported by its file name instead.
    """
    directory, file_name = os.path.split(module_path)
    # Remove only a literal ".py" suffix.  str.rstrip('.py') strips any
    # trailing '.', 'p' or 'y' characters, turning "happy.py" into "ha".
    if file_name.endswith('.py'):
        stem = file_name[:-len('.py')]
    else:
        stem = file_name
    if name is None:
        name = stem
        if name == '__init__':
            name = os.path.split(directory)[1]
    try:
        from importlib.machinery import SourceFileLoader
        sfl = SourceFileLoader(name, module_path)
        # NOTE(review): load_module() is deprecated in recent Python
        # versions; kept here to preserve the existing loading behavior.
        module = sfl.load_module()
    except ImportError:
        from importlib import import_module
        sys.path.append(directory)
        try:
            module = import_module(stem)
        finally:
            # Remove the directory that was just added, even if the import
            # failed or the imported code appended further entries.
            for index in range(len(sys.path) - 1, -1, -1):
                if sys.path[index] == directory:
                    del sys.path[index]
                    break
    return module
def method_cache(by='value', method='run'):
    """A decorator used on any model method which calls the model's 'method'
    method if that latter method has not been called using the current
    arguments or simply sets model attributes to match the run results if
    it has.

    `by` selects how a cache entry is keyed: 'value' hashes the model's
    public attributes (those not starting with an underscore) together
    with the arguments, while 'instance' hashes the model's id together
    with the arguments.  Any other value raises ValueError when the
    decorated method is called.

    `method` is the name of the model method whose results are cached.
    The cache is stored on the model's class as `cached_runs`.
    """
    import functools  # Imported here because the module does not import it.

    def decorate_(func):
        # wraps keeps the decorated method's __name__ and __doc__ intact.
        @functools.wraps(func)
        def decorate(*args, **kwargs):
            model = args[0]  # Assumed to be self.
            assert hasattr(model, method), \
                "Model must have a '%s' method." % method
            if func.__name__ == method:  # Run itself.
                method_args = kwargs
            else:  # Any other method.
                method_args = kwargs[method] if method in kwargs else {}
            # If there is no run cache.
            if not hasattr(model.__class__, 'cached_runs'):
                # Create the method cache.
                model.__class__.cached_runs = {}
            cache = model.__class__.cached_runs
            if by == 'value':
                model_dict = {key: value
                              for key, value in model.__dict__.items()
                              if key[0] != '_'}
                method_signature = SciUnit.dict_hash(
                    {'attrs': model_dict, 'args': method_args})  # Hash key.
            elif by == 'instance':
                method_signature = SciUnit.dict_hash(
                    {'id': id(model), 'args': method_args})  # Hash key.
            else:
                raise ValueError("Cache type must be 'value' or 'instance'")
            if method_signature not in cache:
                print(("Method with this signature not found in the cache. "
                       "Running..."))
                f = getattr(model, method)
                f(**method_args)
                # Remember when the method ran and the resulting attributes.
                cache[method_signature] = (datetime.now(),
                                           model.__dict__.copy())
            else:
                print(("Method with this signature found in the cache. "
                       "Restoring..."))
                _, attrs = cache[method_signature]
                model.__dict__.update(attrs)
            return func(*args, **kwargs)
        return decorate
    return decorate_
# Aliases for two helpers from the third-party ``cypy`` package, so that
# they can be referenced through this module under these names.
class_intern = cypy.intern
method_memoize = cypy.memoize
def log(*args, **kwargs):
    """Emit a log message if the LOGGING setting is enabled.

    The message is passed to kernel_log when the KERNEL setting is true
    and to non_kernel_log otherwise.
    """
    if not settings['LOGGING']:
        return
    emit = kernel_log if settings['KERNEL'] else non_kernel_log
    emit(*args, **kwargs)
def non_kernel_log(*args, **kwargs):
    """Print the arguments as plain text.

    Every argument that is not an Exception is parsed with BeautifulSoup
    and replaced by its text content, which removes any markup; Exception
    instances are printed as they are.
    """
    plain = []
    for item in args:
        if isinstance(item, Exception):
            plain.append(item)
        else:
            plain.append(bs4.BeautifulSoup(item, "lxml").text)
    try:
        print(*plain, **kwargs)
    except SyntaxError:  # Python 2
        print(plain)
def kernel_log(*args, **kwargs):
    """Render the arguments as HTML output using IPython's display.

    The arguments are converted to text and printed into an in-memory
    buffer; any 'file' keyword supplied by the caller is replaced by that
    buffer.  The buffer's content is then displayed as HTML.
    """
    with StringIO() as buffer:
        kwargs['file'] = buffer
        rendered = [u'%s' % item for item in args]
        try:
            print(*rendered, **kwargs)
        except SyntaxError:  # Python 2
            print(rendered)
        html = buffer.getvalue()
    display(HTML(html))
def config_get_from_path(config_path, key):
    """Return the value stored under `key` in the JSON file at `config_path`.

    Raises Error if the file does not exist, if it does not contain valid
    JSON, or if it has no entry for `key`.
    """
    try:
        with open(config_path) as f:
            config = json.load(f)
        value = config[key]
    except FileNotFoundError:
        raise Error("Config file not found at '%s'" % config_path)
    except json.JSONDecodeError:
        # The file exists but could not be parsed, so report exactly that
        # rather than claiming the file is missing.
        message = "Config file JSON at '%s' was invalid" % config_path
        log(message)
        raise Error(message)
    except KeyError:
        raise Error("Config file does not contain key '%s'" % key)
    return value
def config_get(key, default=None):
    """Look up `key` in the 'config.json' file located in settings['CWD'].

    If the lookup fails for any reason and `default` is not None, the
    failure is logged and `default` is returned.  With no default, the
    exception is raised to the caller.
    """
    try:
        assert isinstance(key, str), "Config key must be a string"
        config_path = os.path.join(settings['CWD'], 'config.json')
        return config_get_from_path(config_path, key)
    except Exception as e:
        if default is None:
            raise e
        log(e)
        log("Using default value of %s" % default)
        return default
def path_escape(path):
    """Escape a path by placing backslashes in front of disallowed characters.

    The characters escaped are the space and the two parentheses.
    """
    for char in (' ', '(', ')'):
        # '\\' is a single backslash; the previous '\%s' literal relied on
        # an invalid escape sequence, which newer Pythons warn about.
        path = path.replace(char, '\\' + char)
    return path