From 73bff2a400a2453bc387f2ed882c9ded88d89c12 Mon Sep 17 00:00:00 2001 From: Tim Hochberg Date: Thu, 20 Apr 2006 21:04:52 +0000 Subject: Tweaks to Travis's recent addition of thread local semantics to setting of errors. Most important is that I disabled the optimization in the "default" case since it appeared broken when there were multiple threads and it didn't seem to have a significant performance impact. Added comments on how to fix it if that turns out to be desirable. --- numpy/core/numeric.py | 23 +++++++++------ numpy/core/src/ufuncobject.c | 60 +++++++++++++++++++++++++++------------- numpy/core/tests/test_numeric.py | 29 ++++++++++++++++++- 3 files changed, 83 insertions(+), 29 deletions(-) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index af4f3e6f8..f57527990 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -413,19 +413,24 @@ for key in _errdict.keys(): _errdict_rev[_errdict[key]] = key del key -def seterr(divide="ignore", over="ignore", under="ignore", - invalid="ignore"): +def seterr(divide=None, over=None, under=None, invalid=None): + + pyvals = umath.geterrobj() + old = geterr() + + if divide is None: divide = old['divide'] + if over is None: over = old['over'] + if under is None: under = old['under'] + if invalid is None: invalid = old['invalid'] maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) + (_errdict[over] << SHIFT_OVERFLOW ) + (_errdict[under] << SHIFT_UNDERFLOW) + (_errdict[invalid] << SHIFT_INVALID)) - pyvals = umath.geterrobj() - old = pyvals[:] pyvals[1] = maskvalue umath.seterrobj(pyvals) - return ufunc_values_obj(old) + return old def geterr(): maskvalue = umath.geterrobj()[1] @@ -446,10 +451,10 @@ def setbufsize(size): raise ValueError, "Very big buffers.. %s" % size pyvals = umath.geterrobj() - old = pyvals[:] + old = getbufsize() pyvals[0] = size umath.seterrobj(pyvals) - return ufunc_values_obj(old) + return old def getbufsize(): return umath.geterrobj()[0] @@ -458,10 +463,10 @@ def seterrcall(func): if not callable(func): raise ValueError, "Only callable can be used as callback" pyvals = umath.geterrobj() - old = pyvals[:] + old = geterrcall() pyvals[2] = func umath.seterrobj(pyvals) - return ufunc_values_obj(old) + return old def geterrcall(): return umath.geterrobj()[2] diff --git a/numpy/core/src/ufuncobject.c b/numpy/core/src/ufuncobject.c index 721ee2f42..bc439b61c 100644 --- a/numpy/core/src/ufuncobject.c +++ b/numpy/core/src/ufuncobject.c @@ -35,6 +35,9 @@ typedef void (CdoubleBinaryFunc)(cdouble *x, cdouble *y, cdouble *res); typedef void (CfloatBinaryFunc)(cfloat *x, cfloat *y, cfloat *res); typedef void (ClongdoubleBinaryFunc)(clongdouble *x, clongdouble *y, \ clongdouble *res); + +#define USE_USE_DEFAULTS 0 + /*UFUNC_API*/ static void @@ -696,9 +699,12 @@ select_types(PyUFuncObject *self, int *arg_types, return 0; } + +#if USE_USE_DEFAULTS +static int PyUFunc_USEDEFAULTS=0; +#endif +static PyObject *PyUFunc_PYVALS_NAME=NULL; -static int PyUFunc_USEDEFAULTS=0; -static PyObject *PyUFunc_PYVALS_NAME=NULL; /*UFUNC_API*/ static int @@ -707,8 +713,10 @@ PyUFunc_GetPyValues(char *name, int *bufsize, int *errmask, PyObject **errobj) PyObject *thedict; PyObject *ref=NULL; PyObject *retval; - - if (!PyUFunc_USEDEFAULTS) { + + #if USE_USE_DEFAULTS + if (!PyUFunc_USEDEFAULTS) { + #endif if (PyUFunc_PYVALS_NAME == NULL) { PyUFunc_PYVALS_NAME = PyString_InternFromString(UFUNC_PYVALS_NAME); } @@ -716,18 +724,10 @@ PyUFunc_GetPyValues(char *name, int *bufsize, int *errmask, PyObject **errobj) if (thedict == NULL) { thedict = PyEval_GetBuiltins(); } - ref = PyDict_GetItem(thedict, PyUFunc_PYVALS_NAME); -/* thedict = PyEval_GetLocals(); */ -/* ref = PyDict_GetItem(thedict, thestring); */ -/* if (ref == NULL) { */ -/* thedict = PyEval_GetGlobals(); */ -/* ref = PyDict_GetItem(thedict, thestring); */ -/* } */ -/* if (ref == NULL) { */ -/* thedict = PyEval_GetBuiltins(); */ -/* ref = PyDict_GetItem(thedict, thestring); */ -/* } */ - } + ref = PyDict_GetItem(thedict, PyUFunc_PYVALS_NAME); + #if USE_USE_DEFAULTS + } + #endif if (ref == NULL) { *errmask = UFUNC_ERR_DEFAULT; *errobj = Py_BuildValue("NO", @@ -2776,7 +2776,26 @@ ufunc_geterr(PyObject *dummy, PyObject *args) PyList_SET_ITEM(res, 2, Py_None); Py_INCREF(Py_None); return res; } - + +#if USE_USE_DEFAULTS +/* +This doesn't look it will work in the presence of threads. It updates +PyUFunc_USEDEFAULTS based on the current thread. If some other thread is +around, it will see an incorrect value for use_defaults. + +I think the following strategy would fix this: + 1. Change PyUFunc_USEDEFAULTS to PyUFunc_NONDEFAULTCOUNT or similar + 2. Increment PyUFunc_NONDEFAULTCOUNT whenever a value is set to a nondefault + value + 3. Only use defaults when PyUFunc_NONDEFAULTCOUNT is nonzero. + +However, I'm not sure that it's worth the trouble. I've done a few small +benchmarks and I see at most marginal speed improvements with the +default values. So, for the time being, I'm simply ifdefing out the +nonworking code and not worrying about it. If those benchmarks hold up, we +should go ahead and rip the code out so as not to confuse future generations. + +*/ static int ufunc_update_use_defaults(void) { @@ -2793,6 +2812,7 @@ ufunc_update_use_defaults(void) } return 0; } +#endif static PyObject * ufunc_seterr(PyObject *dummy, PyObject *args) @@ -2829,8 +2849,10 @@ ufunc_seterr(PyObject *dummy, PyObject *args) } res = PyDict_SetItem(thedict, PyUFunc_PYVALS_NAME, val); Py_DECREF(val); - if (res < 0) return NULL; - if (ufunc_update_use_defaults() < 0) return NULL; + if (res < 0) return NULL; +#if USE_USE_DEFAULTS + if (ufunc_update_use_defaults() < 0) return NULL; +#endif Py_INCREF(Py_None); return Py_None; } diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 2edf58399..45204e3bf 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -134,4 +134,31 @@ class test_bool_scalar(ScipyTestCase): self.failUnless((t ^ t) is f) self.failUnless((f ^ t) is t) self.failUnless((t ^ f) is t) - self.failUnless((f ^ f) is f) + self.failUnless((f ^ f) is f) + + +class test_seterr(ScipyTestCase): + def test_set(self): + err = seterr() + old = seterr(divide='warn') + self.failUnless(err == old) + new = seterr() + self.failUnless(new['divide'] == 'warn') + seterr(over='raise') + self.failUnless(geterr()['over'] == 'raise') + self.failUnless(new['divide'] == 'warn') + seterr(**old) + self.failUnless(geterr() == old) + def test_divideerr(self): + seterr(divide='raise') + try: + array([1.]) / array([0.]) + except FloatingPointError: + pass + else: + self.fail() + seterr(divide='ignore') + array([1.]) / array([0.]) + +if __name__ == '__main__': + NumpyTest().run() \ No newline at end of file -- cgit v1.2.1