diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2019-06-01 22:09:02 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-01 22:09:02 +0300 |
commit | 2b843ac0ae745026ce39514573c5d075137bef65 (patch) | |
tree | 8e176372e55d171590b4c798d6deaf9311cbef8c /Modules | |
parent | 9843bc110dc4241ba7cb05f3d3ef74ac6c77caf2 (diff) | |
download | cpython-git-2b843ac0ae745026ce39514573c5d075137bef65.tar.gz |
bpo-35431: Refactor math.comb() implementation. (GH-13725)
* Fixed some bugs.
* Added support for index-likes objects.
* Improved error messages.
* Cleaned up and optimized the code.
* Added more tests.
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/clinic/mathmodule.c.h | 24 | ||||
-rw-r--r-- | Modules/mathmodule.c | 155 |
2 files changed, 87 insertions, 92 deletions
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index cba791e209..92ec4bec9b 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -639,10 +639,10 @@ exit: } PyDoc_STRVAR(math_comb__doc__, -"comb($module, /, n, k)\n" +"comb($module, n, k, /)\n" "--\n" "\n" -"Number of ways to choose *k* items from *n* items without repetition and without order.\n" +"Number of ways to choose k items from n items without repetition and without order.\n" "\n" "Also called the binomial coefficient. It is mathematically equal to the expression\n" "n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n" @@ -652,38 +652,26 @@ PyDoc_STRVAR(math_comb__doc__, "Raises ValueError if the arguments are negative or if k > n."); #define MATH_COMB_METHODDEF \ - {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL|METH_KEYWORDS, math_comb__doc__}, + {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__}, static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k); static PyObject * -math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames) +math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs) { PyObject *return_value = NULL; - static const char * const _keywords[] = {"n", "k", NULL}; - static _PyArg_Parser _parser = {NULL, _keywords, "comb", 0}; - PyObject *argsbuf[2]; PyObject *n; PyObject *k; - args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf); - if (!args) { - goto exit; - } - if (!PyLong_Check(args[0])) { - _PyArg_BadArgument("comb", 1, "int", args[0]); + if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) { goto exit; } n = args[0]; - if (!PyLong_Check(args[1])) { - _PyArg_BadArgument("comb", 2, "int", args[1]); - goto exit; - } k = args[1]; return_value = math_comb_impl(module, n, k); exit: return return_value; } -/*[clinic end generated code: output=00aa76356759617a input=a9049054013a1b77]*/ +/*[clinic end generated code: output=6709521e5e1d90ec input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 007a880142..bea4607b9b 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start) /*[clinic input] math.comb - n: object(subclass_of='&PyLong_Type') - k: object(subclass_of='&PyLong_Type') + n: object + k: object + / -Number of ways to choose *k* items from *n* items without repetition and without order. +Number of ways to choose k items from n items without repetition and without order. Also called the binomial coefficient. It is mathematically equal to the expression n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in @@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n. static PyObject * math_comb_impl(PyObject *module, PyObject *n, PyObject *k) -/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/ +/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/ { - PyObject *val = NULL, - *temp_obj1 = NULL, - *temp_obj2 = NULL, - *dump_var = NULL; + PyObject *result = NULL, *factor = NULL, *temp; int overflow, cmp; - long long i, terms; + long long i, factors; - cmp = PyObject_RichCompareBool(n, k, Py_LT); - if (cmp < 0) { - goto fail_comb; + n = PyNumber_Index(n); + if (n == NULL) { + return NULL; } - else if (cmp > 0) { - PyErr_Format(PyExc_ValueError, - "n must be an integer greater than or equal to k"); - goto fail_comb; + k = PyNumber_Index(k); + if (k == NULL) { + Py_DECREF(n); + return NULL; } - /* b = min(b, a - b) */ - dump_var = PyNumber_Subtract(n, k); - if (dump_var == NULL) { - goto fail_comb; + if (Py_SIZE(n) < 0) { + PyErr_SetString(PyExc_ValueError, + "n must be a non-negative integer"); + goto error; } - cmp = PyObject_RichCompareBool(k, dump_var, Py_GT); - if (cmp < 0) { - goto fail_comb; + /* k = min(k, n - k) */ + temp = PyNumber_Subtract(n, k); + if (temp == NULL) { + goto error; } - else if (cmp > 0) { - k = dump_var; - dump_var = NULL; + if (Py_SIZE(temp) < 0) { + Py_DECREF(temp); + PyErr_SetString(PyExc_ValueError, + "k must be an integer less than or equal to n"); + goto error; + } + cmp = PyObject_RichCompareBool(k, temp, Py_GT); + if (cmp > 0) { + Py_SETREF(k, temp); } else { - Py_DECREF(dump_var); - dump_var = NULL; + Py_DECREF(temp); + if (cmp < 0) { + goto error; + } } - terms = PyLong_AsLongLongAndOverflow(k, &overflow); - if (terms < 0 && PyErr_Occurred()) { - goto fail_comb; - } - else if (overflow > 0) { + factors = PyLong_AsLongLongAndOverflow(k, &overflow); + if (overflow > 0) { PyErr_Format(PyExc_OverflowError, - "minimum(n - k, k) must not exceed %lld", + "min(n - k, k) must not exceed %lld", LLONG_MAX); - goto fail_comb; + goto error; } - else if (overflow < 0 || terms < 0) { - PyErr_Format(PyExc_ValueError, - "k must be a positive integer"); - goto fail_comb; + else if (overflow < 0 || factors < 0) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ValueError, + "k must be a non-negative integer"); + } + goto error; } - if (terms == 0) { - return PyNumber_Long(_PyLong_One); + if (factors == 0) { + result = PyLong_FromLong(1); + goto done; } - val = PyNumber_Long(n); - for (i = 1; i < terms; ++i) { - temp_obj1 = PyLong_FromSsize_t(i); - if (temp_obj1 == NULL) { - goto fail_comb; - } - temp_obj2 = PyNumber_Subtract(n, temp_obj1); - if (temp_obj2 == NULL) { - goto fail_comb; + result = n; + Py_INCREF(result); + if (factors == 1) { + goto done; + } + + factor = n; + Py_INCREF(factor); + for (i = 1; i < factors; ++i) { + Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One)); + if (factor == NULL) { + goto error; } - dump_var = val; - val = PyNumber_Multiply(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_Multiply(result, factor)); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - dump_var = NULL; - Py_DECREF(temp_obj2); - temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1)); - if (temp_obj2 == NULL) { - goto fail_comb; + + temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1); + if (temp == NULL) { + goto error; } - dump_var = val; - val = PyNumber_FloorDivide(val, temp_obj2); - if (val == NULL) { - goto fail_comb; + Py_SETREF(result, PyNumber_FloorDivide(result, temp)); + Py_DECREF(temp); + if (result == NULL) { + goto error; } - Py_DECREF(dump_var); - Py_DECREF(temp_obj1); - Py_DECREF(temp_obj2); } + Py_DECREF(factor); - return val; - -fail_comb: - Py_XDECREF(val); - Py_XDECREF(dump_var); - Py_XDECREF(temp_obj1); - Py_XDECREF(temp_obj2); +done: + Py_DECREF(n); + Py_DECREF(k); + return result; +error: + Py_XDECREF(factor); + Py_XDECREF(result); + Py_DECREF(n); + Py_DECREF(k); return NULL; } |