from __future__ import print_function

import os
import pickle
import pprint
import sys

import numpy as np
import pyqtgraph as pg

# AUDIT_TESTS is a package-level flag (enabled via the test runner's --audit
# option) that turns on interactive review of mismatching results.
from .. import AUDIT_TESTS


class UserTester(object):
    """
    Base class for testing when a human is required to verify the results.

    When a test is passed by the user, its output is saved and used as a basis
    for future tests. If future test results do not match the stored results,
    then the user is asked to decide whether to fail the test, or pass the
    test and store new results.

    Subclasses must reimplement run_test() to return a dictionary of results
    to store. Optionally, compare_results() and audit_result() may also be
    reimplemented to customize the testing behavior.

    By default, test results are stored in a 'test_data' directory relative
    to the file that defines the UserTester subclass in use.
    """

    data_dir = "test_data"

    def __init__(self, key, *args, **kwds):
        """Initialize with a string *key* that provides a short, unique
        description of this test. All other arguments are passed to run_test().

        *key* is used to determine the file name for storing test results.
        """
        self.audit = AUDIT_TESTS
        self.key = key
        self.rtol = 1e-3
        self.args = args
        self.assert_test_info(*args, **kwds)

    def run_test(self, *args, **kwds):
        """
        Execute the test. All arguments are taken from __init__.
        Return a picklable dictionary of test results.
        """
        raise NotImplementedError()

    def compare_results(self, key, info, expect):
        """
        Compare the current test result *info* against the previously stored
        result *expect*. If *expect* is None, then no previous result was
        stored.

        If *info* and *expect* do not match, then raise an exception.
        """
        # Check that the test structures are the same
        assert type(info) is type(expect)
        if hasattr(info, "__len__"):
            assert len(info) == len(expect)

        if isinstance(info, dict):
            # Keys must match exactly in both directions; values are
            # compared recursively.
            for k in info:
                assert k in expect
            for k in expect:
                assert k in info
                self.compare_results(k, info[k], expect[k])
        elif isinstance(info, list):
            for i in range(len(info)):
                self.compare_results(key, info[i], expect[i])
        elif isinstance(info, np.ndarray):
            assert info.shape == expect.shape
            if len(info) == 0:
                return
            # assert info.dtype == expect.dtype
            if info.dtype.fields is None:
                intnan = -9223372036854775808  # happens when np.nan is cast to int
                inans = np.isnan(info) | (info == intnan)
                enans = np.isnan(expect) | (expect == intnan)
                assert np.all(inans == enans)
                mask = ~inans
                if not np.allclose(info[mask], expect[mask], rtol=self.rtol):
                    print(
                        "\nComparing data array, shapes match: ",
                        info.shape == expect.shape,
                    )
                    print("Model tested: %s, measure: %s" % (self.key, key))
                    # print('args: ', dir(self.args[0]))
                    print("Array expected: ", expect[mask])
                    print("Array received: ", info[mask])
                    try:
                        self.args[0].print_all_mechs()
                    except AttributeError:
                        # args[0] has no print_all_mechs (e.g., it is a string)
                        print("args[0] is string: ", self.args[0])
                assert np.allclose(info[mask], expect[mask], rtol=self.rtol)
            else:
                # Structured array: compare each named field recursively.
                for k in info.dtype.fields.keys():
                    self.compare_results(k, info[k], expect[k])
        elif np.isscalar(info):
            if not np.allclose(info, expect, rtol=self.rtol):
                print("Comparing scalar data, model: %s, measure: %s" % (self.key, key))
                print(
                    "Expected: ", expect,
                    ", received: ", info,
                    " relative tolerance: ", self.rtol,
                )
                if not isinstance(self.args[0], str):
                    self.args[0].print_all_mechs()
            assert np.allclose(info, expect, rtol=self.rtol)
        else:
            try:
                assert info == expect
            except AssertionError:
                raise
            except Exception:
                raise NotImplementedError(
                    "Cannot compare objects of type %s" % type(info)
                )
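
    # Illustrative note (the example values below are hypothetical, not from
    # the original source): compare_results() is first entered from
    # assert_test_info() as self.compare_results(None, result, expect), where
    # *result* is the dict returned by run_test(), e.g.
    #
    #     {"rin": 58.2, "vm": np.array([-63.9, -64.1])}
    #
    # It then recurses through dicts, lists, structured arrays, plain arrays,
    # and scalars, applying np.allclose with self.rtol at each numeric leaf.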

    def audit_result(self, info, expect):
        """Display results and ask the user to decide whether the test passed.
        Return True for pass, False for fail.

        If *expect* is None, then no previous test results were stored.
        """
        app = pg.mkQApp()
        print("\n=== New test results for %s: ===\n" % self.key)
        pprint.pprint(info)

        # We use DiffTreeWidget to display differences between large data
        # structures, but this is not present in mainline pyqtgraph yet.
        if hasattr(pg, "DiffTreeWidget"):
            win = pg.DiffTreeWidget()
        else:
            from cnmodel.util.difftreewidget import DiffTreeWidget

            win = DiffTreeWidget()

        win.resize(800, 800)
        win.setData(expect, info)
        win.show()
        print("Store new test results? [y/n]")
        yn = input()
        win.hide()
        return yn.lower().startswith("y")

    def assert_test_info(self, *args, **kwds):
        """
        Run the test and raise an exception if the results do not match prior
        data.
        """
        result = self.run_test(*args, **kwds)
        expect = self.load_test_result()
        try:
            assert expect is not None
            self.compare_results(None, result, expect)
        except Exception:
            if not self.audit:
                if expect is None:
                    raise Exception(
                        "No prior test results for test '%s'. "
                        "Run test.py --audit to store new test data." % self.key
                    )
                else:
                    raise

            store = self.audit_result(result, expect)
            if store:
                self.save_test_result(result)
            else:
                raise Exception("Rejected test results for '%s'" % self.key)

    def result_file(self):
        """
        Return a file name to be used for storing / retrieving test results
        given *self.key*.
        """
        modfile = sys.modules[self.__class__.__module__].__file__
        path = os.path.dirname(modfile)
        return os.path.join(path, self.data_dir, self.key + ".pk")

    def load_test_result(self):
        """
        Load prior test results for *self.key*.
        If there are no prior results, return None.
        """
        fn = self.result_file()
        if os.path.isfile(fn):
            # latin1 encoding allows loading pickles written under Python 2
            with open(fn, "rb") as fh:
                return pickle.load(fh, encoding="latin1")
        return None

    def save_test_result(self, result):
        """
        Store test results for *self.key*.
        The *result* argument must be picklable.
        """
        fn = self.result_file()
        dirname = os.path.dirname(fn)
        if not os.path.isdir(dirname):
            os.mkdir(dirname)
        with open(fn, "wb") as fh:
            pickle.dump(result, fh)
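

# Illustrative sketch (not part of the original module): a minimal UserTester
# subclass. The class name, test key, and synthetic "measurement" below are
# hypothetical; a real subclass would run a model and return its measurements
# as a picklable dict.
#
#     class NoiseTester(UserTester):
#         def run_test(self, mean, sigma):
#             rng = np.random.RandomState(0)  # fixed seed for reproducibility
#             data = rng.normal(mean, sigma, size=1000)
#             return {"mean": data.mean(), "std": data.std()}
#
# Instantiating NoiseTester("noise-basic", 0.0, 1.0) runs run_test() via
# __init__() -> assert_test_info(), compares the returned dict against
# test_data/noise-basic.pk, and (with AUDIT_TESTS enabled) asks the user to
# accept or reject any mismatch.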
|