"""The base class for many SciUnit objects."""
import os
import sys
import json
import pickle
import hashlib
import numpy as np
import pandas as pd
import git
from git.exc import GitCommandError, InvalidGitRepositoryError
from git.cmd import Git
PYTHON_MAJOR_VERSION = sys.version_info.major
PLATFORM = sys.platform
if PYTHON_MAJOR_VERSION < 3: # Python 2
from StringIO import StringIO
try:
import Tkinter as tkinter
except ImportError:
pass # Handled in the importing modules's fix_display()
FileNotFoundError = OSError
json.JSONDecodeError = ValueError
else:
from io import StringIO
import tkinter
FileNotFoundError = FileNotFoundError
KERNEL = ('ipykernel' in sys.modules)
LOGGING = True
HERE = os.path.dirname(os.path.realpath(__file__))
[docs]class Versioned(object):
"""A Mixin class for SciUnit objects.
Provides a version string based on the Git repository where the model
is tracked. Provided in part by Andrew Davison in issue #53.
"""
[docs] def get_repo(self, cached=True):
"""Get a git repository object for this instance."""
module = sys.modules[self.__module__]
# We use module.__file__ instead of module.__path__[0]
# to include modules without a __path__ attribute.
if hasattr(self.__class__, '_repo') and cached:
repo = self.__class__._repo
elif hasattr(module, '__file__'):
path = os.path.realpath(module.__file__)
try:
repo = git.Repo(path, search_parent_directories=True)
except InvalidGitRepositoryError:
repo = None
else:
repo = None
self.__class__._repo = repo
return repo
[docs] def get_version(self, cached=True):
"""Get a git version (i.e. a git commit hash) for this instance."""
if hasattr(self.__class__, '_version') and cached:
version = self.__class__._version
else:
repo = self.get_repo()
if repo is not None:
head = repo.head
version = head.commit.hexsha
if repo.is_dirty():
version += "*"
else:
version = None
self.__class__._version = version
return version
version = property(get_version)
[docs] def get_remote(self, remote='origin'):
"""Get a git remote object for this instance."""
repo = self.get_repo()
if repo is not None:
remotes = {r.name: r for r in repo.remotes}
r = repo.remotes[0] if remote not in remotes else remotes[remote]
else:
r = None
return r
[docs] def get_remote_url(self, remote='origin', cached=True):
"""Get a git remote URL for this instance."""
if hasattr(self.__class__, '_remote_url') and cached:
url = self.__class__._remote_url
else:
r = self.get_remote(remote)
try:
url = list(r.urls)[0]
except GitCommandError as ex:
if 'correct access rights' in str(ex):
# If ssh is not setup to access this repository
cmd = ['git', 'config', '--get', 'remote.%s.url' % r.name]
url = Git().execute(cmd)
else:
raise ex
except AttributeError:
url = None
if url is not None and url.startswith('git@'):
domain = url.split('@')[1].split(':')[0]
path = url.split(':')[1]
url = "http://%s/%s" % (domain, path)
self.__class__._remote_url = url
return url
remote_url = property(get_remote_url)
[docs]class SciUnit(Versioned):
"""Abstract base class for models, tests, and scores."""
[docs] def __init__(self):
"""Instantiate a SciUnit object."""
self.unpicklable = []
#: A list of attributes that cannot or should not be pickled.
unpicklable = []
#: A URL where the code for this object can be found.
_url = None
#: A verbosity level for printing information.
verbose = 1
[docs] def __getstate__(self):
"""Copy the object's state from self.__dict__.
Contains all of the instance attributes. Always uses the dict.copy()
method to avoid modifying the original state.
"""
state = self.__dict__.copy()
# Remove the unpicklable entries.
if hasattr(self, 'unpicklable'):
for key in set(self.unpicklable).intersection(state):
del state[key]
return state
[docs] def _state(self, state=None, keys=None, exclude=None):
if state is None:
state = self.__getstate__()
if keys:
state = {key: state[key] for key in keys if key in state.keys()}
if exclude:
state = {key: state[key] for key in state.keys()
if key not in exclude}
state = deep_exclude(state, exclude)
return state
[docs] def _properties(self, keys=None, exclude=None):
result = {}
props = self.raw_props()
exclude = exclude if exclude else []
exclude += ['state', 'id']
for prop in set(props).difference(exclude):
if prop == 'properties':
pass # Avoid infinite recursion
elif not keys or prop in keys:
result[prop] = getattr(self, prop)
return result
[docs] def raw_props(self):
class_attrs = dir(self.__class__)
return [p for p in class_attrs
if isinstance(getattr(self.__class__, p, None), property)]
@property
def state(self):
return self._state()
@property
def properties(self):
return self._properties()
[docs] @classmethod
def dict_hash(cls, d):
od = [(key, d[key]) for key in sorted(d)]
try:
s = pickle.dumps(od)
except AttributeError:
s = json.dumps(od, cls=SciUnitEncoder).encode('utf-8')
return hashlib.sha224(s).hexdigest()
@property
def hash(self):
"""A unique numeric identifier of the current model state"""
return self.dict_hash(self.state)
[docs] def json(self, add_props=False, keys=None, exclude=None, string=True,
indent=None):
result = json.dumps(self, cls=SciUnitEncoder,
add_props=add_props, keys=keys, exclude=exclude,
indent=indent)
if not string:
result = json.loads(result)
return result
@property
def _id(self):
return id(self)
@property
def _class(self):
url = '' if self.url is None else self.url
import_path = '{}.{}'.format(
self.__class__.__module__,
self.__class__.__name__
)
return {'name': self.__class__.__name__,
'import_path': import_path,
'url': url}
@property
def id(self):
return str(self.json)
@property
def url(self):
return self._url if self._url else self.remote_url
[docs]class SciUnitEncoder(json.JSONEncoder):
"""Custom JSON encoder for SciUnit objects"""
[docs] def __init__(self, *args, **kwargs):
for key in ['add_props', 'keys', 'exclude']:
if key in kwargs:
setattr(self.__class__, key, kwargs[key])
kwargs.pop(key)
super(SciUnitEncoder, self).__init__(*args, **kwargs)
[docs] def default(self, obj):
try:
if isinstance(obj, pd.DataFrame):
o = obj.to_dict(orient='split')
if isinstance(obj, SciUnit):
for old, new in [('data', 'scores'),
('columns', 'tests'),
('index', 'models')]:
o[new] = o.pop(old)
elif isinstance(obj, np.ndarray) and len(obj.shape):
o = obj.tolist()
elif isinstance(obj, SciUnit):
state = obj.state
if self.add_props:
state.update(obj.properties)
o = obj._state(state=state, keys=self.keys,
exclude=self.exclude)
elif isinstance(obj, (dict, list, tuple, str, type(None), bool,
float, int)):
o = json.JSONEncoder.default(self, obj)
else: # Something we don't know how to serialize;
# just represent it as truncated string
o = "%.20s..." % obj
except Exception as e:
print("Could not JSON encode object %s" % obj)
raise e
return o
[docs]class TestWeighted(object):
"""Base class for objects with test weights."""
@property
def weights(self):
"""Returns a normalized list of test weights."""
n = len(self.tests)
if self.weights_:
assert all([x >= 0 for x in self.weights_]),\
"All test weights must be >=0"
summ = sum(self.weights_) # Sum of test weights
assert summ > 0, "Sum of test weights must be > 0"
weights = [x/summ for x in self.weights_] # Normalize to sum
else:
weights = [1.0/n for i in range(n)]
return weights
[docs]def deep_exclude(state, exclude):
tuples = [key for key in exclude if isinstance(key, tuple)]
s = state
for loc in tuples:
for key in loc:
try:
s[key]
except Exception:
pass
else:
if key == loc[-1]:
s[key] = '*removed*'
else:
s = s[key]
return state