diff options
author | Darsh P. Ranjan <darsh.ranjan@here.com> | 2014-11-09 23:20:16 -0800 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2015-01-22 11:33:55 -0700 |
commit | 1bce8d75a7dd8068adc1c1740c7ed03e71a9b5af (patch) | |
tree | d5741839f5a85c94c91831bbd1e1c5cc5d48b61e | |
parent | c9075faa59f823033dca449edce002a0a569a9a5 (diff) | |
download | numpy-1bce8d75a7dd8068adc1c1740c7ed03e71a9b5af.tar.gz |
BUG: Fix astype for structured array fields of different byte order.
The offending commit is c53b0e4, which introduced two regressions:
- using astype to cast a structured array to one with a different byte
order no longer works;
- comparing structured-array dtypes can give incorrect results if the
two dtypes have different byte orders.
This pull request should fix both.
One thing I wasn't sure about is reordering struct fields. In my
implementation, the `equiv`, `same_kind`, and `safe` rules are now
allowed to reorder fields. If that isn't desired, it would be easy to
disallow reordering under those rules.
-rw-r--r-- | numpy/core/src/multiarray/convert_datatype.c | 64 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 41 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 63 |
3 files changed, 139 insertions, 29 deletions
diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index 1db3bfe85..35503d1e2 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -634,6 +634,52 @@ static npy_bool PyArray_CanCastTypeTo_impl(PyArray_Descr *from, PyArray_Descr *to, NPY_CASTING casting); +/* + * Compare two field dictionaries for castability. + * + * Return 1 if 'field1' can be cast to 'field2' according to the rule + * 'casting', 0 if not. + * + * Castabiliy of field dictionaries is defined recursively: 'field1' and + * 'field2' must have the same field names (possibly in different + * orders), and the corresponding field types must be castable according + * to the given casting rule. + */ +static int +can_cast_fields(PyObject *field1, PyObject *field2, NPY_CASTING casting) +{ + Py_ssize_t ppos; + PyObject *key; + PyObject *tuple1, *tuple2; + + if (field1 == field2) { + return 1; + } + if (field1 == NULL || field2 == NULL) { + return 0; + } + if (PyDict_Size(field1) != PyDict_Size(field2)) { + return 0; + } + + /* Iterate over all the fields and compare for castability */ + ppos = 0; + while (PyDict_Next(field1, &ppos, &key, &tuple1)) { + if ((tuple2 = PyDict_GetItem(field2, key)) == NULL) { + return 0; + } + /* Compare the dtype of the field for castability */ + if (!PyArray_CanCastTypeTo( + (PyArray_Descr *)PyTuple_GET_ITEM(tuple1, 0), + (PyArray_Descr *)PyTuple_GET_ITEM(tuple2, 0), + casting)) { + return 0; + } + } + + return 1; +} + /*NUMPY_API * Returns true if data of type 'from' may be cast to data of type * 'to' according to the rule 'casting'. 
@@ -687,7 +733,6 @@ PyArray_CanCastTypeTo_impl(PyArray_Descr *from, PyArray_Descr *to, else if (PyArray_EquivTypenums(from->type_num, to->type_num)) { /* For complicated case, use EquivTypes (for now) */ if (PyTypeNum_ISUSERDEF(from->type_num) || - PyDataType_HASFIELDS(from) || from->subarray != NULL) { int ret; @@ -715,6 +760,23 @@ PyArray_CanCastTypeTo_impl(PyArray_Descr *from, PyArray_Descr *to, return ret; } + if (PyDataType_HASFIELDS(from)) { + switch (casting) { + case NPY_EQUIV_CASTING: + case NPY_SAFE_CASTING: + case NPY_SAME_KIND_CASTING: + /* + * `from' and `to' must have the same fields, and + * corresponding fields must be (recursively) castable. + */ + return can_cast_fields(from->fields, to->fields, casting); + + case NPY_NO_CASTING: + default: + return PyArray_EquivTypes(from, to); + } + } + switch (from->type_num) { case NPY_DATETIME: { PyArray_DatetimeMetaData *meta1, *meta2; diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 8446e1583..844daa82d 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1410,9 +1410,7 @@ array_putmask(PyObject *NPY_UNUSED(module), PyObject *args, PyObject *kwds) static int _equivalent_fields(PyObject *field1, PyObject *field2) { - Py_ssize_t ppos; - PyObject *key; - PyObject *tuple1, *tuple2; + int same, val; if (field1 == field2) { return 1; @@ -1420,33 +1418,20 @@ _equivalent_fields(PyObject *field1, PyObject *field2) { if (field1 == NULL || field2 == NULL) { return 0; } - - if (PyDict_Size(field1) != PyDict_Size(field2)) { - return 0; +#if defined(NPY_PY3K) + val = PyObject_RichCompareBool(field1, field2, Py_EQ); + if (val != 1 || PyErr_Occurred()) { +#else + val = PyObject_Compare(field1, field2); + if (val != 0 || PyErr_Occurred()) { +#endif + same = 0; } - - /* Iterate over all the fields and compare for equivalency */ - ppos = 0; - while (PyDict_Next(field1, &ppos, &key, &tuple1)) { - if 
((tuple2 = PyDict_GetItem(field2, key)) == NULL) { - return 0; - } - /* Compare the dtype of the field for equivalency */ - if (!PyArray_CanCastTypeTo((PyArray_Descr *)PyTuple_GET_ITEM(tuple1, 0), - (PyArray_Descr *)PyTuple_GET_ITEM(tuple2, 0), - NPY_EQUIV_CASTING)) { - return 0; - } - /* Compare the byte position of the field */ - if (PyObject_RichCompareBool(PyTuple_GET_ITEM(tuple1, 1), - PyTuple_GET_ITEM(tuple2, 1), - Py_EQ) != 1) { - PyErr_Clear(); - return 0; - } + else { + same = 1; } - - return 1; + PyErr_Clear(); + return same; } /* diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index d472a9569..4770faf9d 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -289,6 +289,10 @@ class TestDtypedescr(TestCase): d2 = dtype('f8') assert_equal(d2, dtype(float64)) + def test_byteorders(self): + self.assertNotEqual(dtype('<i4'), dtype('>i4')) + self.assertNotEqual(dtype([('a', '<i4')]), dtype([('a', '>i4')])) + class TestZeroRank(TestCase): def setUp(self): self.d = array(0), array('x', object) @@ -688,6 +692,65 @@ class TestStructured(TestCase): b = np.array([(5, 43), (10, 1)], dtype=[('a', '<i8'), ('b', '>f8')]) assert_equal(a == b, [False, True]) + def test_casting(self): + # Check that casting a structured array to change its byte order + # works + a = np.array([(1,)], dtype=[('a', '<i4')]) + assert_(np.can_cast(a.dtype, [('a', '>i4')], casting='unsafe')) + b = a.astype([('a', '>i4')]) + assert_equal(b, a.byteswap().newbyteorder()) + assert_equal(a['a'][0], b['a'][0]) + + # Check that equality comparison works on structured arrays if + # they are 'equiv'-castable + a = np.array([(5, 42), (10, 1)], dtype=[('a', '>i4'), ('b', '<f8')]) + b = np.array([(42, 5), (1, 10)], dtype=[('b', '>f8'), ('a', '<i4')]) + assert_(np.can_cast(a.dtype, b.dtype, casting='equiv')) + assert_equal(a == b, [True, True]) + + # Check that 'equiv' casting can reorder fields and change byte + # order + 
assert_(np.can_cast(a.dtype, b.dtype, casting='equiv')) + c = a.astype(b.dtype, casting='equiv') + assert_equal(a == c, [True, True]) + + # Check that 'safe' casting can change byte order and up-cast + # fields + t = [('a', '<i8'), ('b', '>f8')] + assert_(np.can_cast(a.dtype, t, casting='safe')) + c = a.astype(t, casting='safe') + assert_equal((c == np.array([(5, 42), (10, 1)], dtype=t)), + [True, True]) + + # Check that 'same_kind' casting can change byte order and + # change field widths within a "kind" + t = [('a', '<i4'), ('b', '>f4')] + assert_(np.can_cast(a.dtype, t, casting='same_kind')) + c = a.astype(t, casting='same_kind') + assert_equal((c == np.array([(5, 42), (10, 1)], dtype=t)), + [True, True]) + + # Check that casting fails if the casting rule should fail on + # any of the fields + t = [('a', '>i8'), ('b', '<f4')] + assert_(not np.can_cast(a.dtype, t, casting='safe')) + assert_raises(TypeError, a.astype, t, casting='safe') + t = [('a', '>i2'), ('b', '<f8')] + assert_(not np.can_cast(a.dtype, t, casting='equiv')) + assert_raises(TypeError, a.astype, t, casting='equiv') + t = [('a', '>i8'), ('b', '<i2')] + assert_(not np.can_cast(a.dtype, t, casting='same_kind')) + assert_raises(TypeError, a.astype, t, casting='same_kind') + assert_(not np.can_cast(a.dtype, b.dtype, casting='no')) + assert_raises(TypeError, a.astype, b.dtype, casting='no') + + # Check that non-'unsafe' casting can't change the set of field names + for casting in ['no', 'safe', 'equiv', 'same_kind']: + t = [('a', '>i4')] + assert_(not np.can_cast(a.dtype, t, casting=casting)) + t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')] + assert_(not np.can_cast(a.dtype, t, casting=casting)) + class TestBool(TestCase): def test_test_interning(self): |