summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-09-23 04:45:08 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-09-23 04:45:08 +0000
commit65f0d8a7ae024b595a019043945ca46028998c69 (patch)
tree7b1528786d1bab1c2c486c11534317a805c21511 /numpy/core
parentdff8d9e31223e31a0544ee743b7be80f47ce9ac3 (diff)
downloadnumpy-65f0d8a7ae024b595a019043945ca46028998c69.tar.gz
Add test for default axis in method and functions.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/tests/test_regression.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index 7d1ed992b..2ee7a9cbb 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -411,6 +411,45 @@ class test_regression(NumpyTestCase):
x = N.array((1,2), dtype=dt)
x = x.byteswap()
assert(x['one'] > 1 and x['two'] > 2)
+
+ def check_method_args(self, level=rlevel):
+ # Make sure methods and functions have same default axis
+ # keyword and arguments
+ funcs1= ['argmax', 'argmin', 'sum', ('product', 'prod'),
+ ('sometrue', 'any'),
+ ('alltrue', 'all'), 'cumsum', ('cumproduct', 'cumprod'),
+ 'ptp', 'cumprod', 'prod', 'std', 'var', 'mean',
+ 'round', 'min', 'max', 'argsort', 'sort']
+ funcs2 = ['compress', 'take', 'repeat']
+
+ for func in funcs1:
+ arr = N.random.rand(8,7)
+ arr2 = arr.copy()
+ if isinstance(func, tuple):
+ func_meth = func[1]
+ func = func[0]
+ else:
+ func_meth = func
+ res1 = getattr(arr, func_meth)()
+ res2 = getattr(N, func)(arr2)
+ if res1 is None:
+ assert abs(arr-res2).max() < 1e-8, func
+ else:
+ assert abs(res1-res2).max() < 1e-8, func
+
+ for func in funcs2:
+ arr1 = N.random.rand(8,7)
+ arr2 = N.random.rand(8,7)
+ res1 = None
+ if func == 'compress':
+ arr1 = arr1.ravel()
+ res1 = getattr(arr2, func)(arr1)
+ else:
+ arr2 = (15*arr2).astype(int).ravel()
+ if res1 is None:
+ res1 = getattr(arr1, func)(arr2)
+ res2 = getattr(N, func)(arr1, arr2)
+ assert abs(res1-res2).max() < 1e-8, func
if __name__ == "__main__":
NumpyTest().run()