summaryrefslogtreecommitdiff
path: root/Modules
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2019-06-01 22:09:02 +0300
committerGitHub <noreply@github.com>2019-06-01 22:09:02 +0300
commit2b843ac0ae745026ce39514573c5d075137bef65 (patch)
tree8e176372e55d171590b4c798d6deaf9311cbef8c /Modules
parent9843bc110dc4241ba7cb05f3d3ef74ac6c77caf2 (diff)
downloadcpython-git-2b843ac0ae745026ce39514573c5d075137bef65.tar.gz
bpo-35431: Refactor math.comb() implementation. (GH-13725)
* Fixed some bugs. * Added support for index-likes objects. * Improved error messages. * Cleaned up and optimized the code. * Added more tests.
Diffstat (limited to 'Modules')
-rw-r--r--Modules/clinic/mathmodule.c.h24
-rw-r--r--Modules/mathmodule.c155
2 files changed, 87 insertions, 92 deletions
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h
index cba791e209..92ec4bec9b 100644
--- a/Modules/clinic/mathmodule.c.h
+++ b/Modules/clinic/mathmodule.c.h
@@ -639,10 +639,10 @@ exit:
}
PyDoc_STRVAR(math_comb__doc__,
-"comb($module, /, n, k)\n"
+"comb($module, n, k, /)\n"
"--\n"
"\n"
-"Number of ways to choose *k* items from *n* items without repetition and without order.\n"
+"Number of ways to choose k items from n items without repetition and without order.\n"
"\n"
"Also called the binomial coefficient. It is mathematically equal to the expression\n"
"n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n"
@@ -652,38 +652,26 @@ PyDoc_STRVAR(math_comb__doc__,
"Raises ValueError if the arguments are negative or if k > n.");
#define MATH_COMB_METHODDEF \
- {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL|METH_KEYWORDS, math_comb__doc__},
+ {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__},
static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k);
static PyObject *
-math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
{
PyObject *return_value = NULL;
- static const char * const _keywords[] = {"n", "k", NULL};
- static _PyArg_Parser _parser = {NULL, _keywords, "comb", 0};
- PyObject *argsbuf[2];
PyObject *n;
PyObject *k;
- args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf);
- if (!args) {
- goto exit;
- }
- if (!PyLong_Check(args[0])) {
- _PyArg_BadArgument("comb", 1, "int", args[0]);
+ if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) {
goto exit;
}
n = args[0];
- if (!PyLong_Check(args[1])) {
- _PyArg_BadArgument("comb", 2, "int", args[1]);
- goto exit;
- }
k = args[1];
return_value = math_comb_impl(module, n, k);
exit:
return return_value;
}
-/*[clinic end generated code: output=00aa76356759617a input=a9049054013a1b77]*/
+/*[clinic end generated code: output=6709521e5e1d90ec input=a9049054013a1b77]*/
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 007a880142..bea4607b9b 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
/*[clinic input]
math.comb
- n: object(subclass_of='&PyLong_Type')
- k: object(subclass_of='&PyLong_Type')
+ n: object
+ k: object
+ /
-Number of ways to choose *k* items from *n* items without repetition and without order.
+Number of ways to choose k items from n items without repetition and without order.
Also called the binomial coefficient. It is mathematically equal to the expression
n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
@@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n.
static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
-/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
+/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
{
- PyObject *val = NULL,
- *temp_obj1 = NULL,
- *temp_obj2 = NULL,
- *dump_var = NULL;
+ PyObject *result = NULL, *factor = NULL, *temp;
int overflow, cmp;
- long long i, terms;
+ long long i, factors;
- cmp = PyObject_RichCompareBool(n, k, Py_LT);
- if (cmp < 0) {
- goto fail_comb;
+ n = PyNumber_Index(n);
+ if (n == NULL) {
+ return NULL;
}
- else if (cmp > 0) {
- PyErr_Format(PyExc_ValueError,
- "n must be an integer greater than or equal to k");
- goto fail_comb;
+ k = PyNumber_Index(k);
+ if (k == NULL) {
+ Py_DECREF(n);
+ return NULL;
}
- /* b = min(b, a - b) */
- dump_var = PyNumber_Subtract(n, k);
- if (dump_var == NULL) {
- goto fail_comb;
+ if (Py_SIZE(n) < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "n must be a non-negative integer");
+ goto error;
}
- cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
- if (cmp < 0) {
- goto fail_comb;
+ /* k = min(k, n - k) */
+ temp = PyNumber_Subtract(n, k);
+ if (temp == NULL) {
+ goto error;
}
- else if (cmp > 0) {
- k = dump_var;
- dump_var = NULL;
+ if (Py_SIZE(temp) < 0) {
+ Py_DECREF(temp);
+ PyErr_SetString(PyExc_ValueError,
+ "k must be an integer less than or equal to n");
+ goto error;
+ }
+ cmp = PyObject_RichCompareBool(k, temp, Py_GT);
+ if (cmp > 0) {
+ Py_SETREF(k, temp);
}
else {
- Py_DECREF(dump_var);
- dump_var = NULL;
+ Py_DECREF(temp);
+ if (cmp < 0) {
+ goto error;
+ }
}
- terms = PyLong_AsLongLongAndOverflow(k, &overflow);
- if (terms < 0 && PyErr_Occurred()) {
- goto fail_comb;
- }
- else if (overflow > 0) {
+ factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+ if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
- "minimum(n - k, k) must not exceed %lld",
+ "min(n - k, k) must not exceed %lld",
LLONG_MAX);
- goto fail_comb;
+ goto error;
}
- else if (overflow < 0 || terms < 0) {
- PyErr_Format(PyExc_ValueError,
- "k must be a positive integer");
- goto fail_comb;
+ else if (overflow < 0 || factors < 0) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_ValueError,
+ "k must be a non-negative integer");
+ }
+ goto error;
}
- if (terms == 0) {
- return PyNumber_Long(_PyLong_One);
+ if (factors == 0) {
+ result = PyLong_FromLong(1);
+ goto done;
}
- val = PyNumber_Long(n);
- for (i = 1; i < terms; ++i) {
- temp_obj1 = PyLong_FromSsize_t(i);
- if (temp_obj1 == NULL) {
- goto fail_comb;
- }
- temp_obj2 = PyNumber_Subtract(n, temp_obj1);
- if (temp_obj2 == NULL) {
- goto fail_comb;
+ result = n;
+ Py_INCREF(result);
+ if (factors == 1) {
+ goto done;
+ }
+
+ factor = n;
+ Py_INCREF(factor);
+ for (i = 1; i < factors; ++i) {
+ Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
+ if (factor == NULL) {
+ goto error;
}
- dump_var = val;
- val = PyNumber_Multiply(val, temp_obj2);
- if (val == NULL) {
- goto fail_comb;
+ Py_SETREF(result, PyNumber_Multiply(result, factor));
+ if (result == NULL) {
+ goto error;
}
- Py_DECREF(dump_var);
- dump_var = NULL;
- Py_DECREF(temp_obj2);
- temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
- if (temp_obj2 == NULL) {
- goto fail_comb;
+
+ temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
+ if (temp == NULL) {
+ goto error;
}
- dump_var = val;
- val = PyNumber_FloorDivide(val, temp_obj2);
- if (val == NULL) {
- goto fail_comb;
+ Py_SETREF(result, PyNumber_FloorDivide(result, temp));
+ Py_DECREF(temp);
+ if (result == NULL) {
+ goto error;
}
- Py_DECREF(dump_var);
- Py_DECREF(temp_obj1);
- Py_DECREF(temp_obj2);
}
+ Py_DECREF(factor);
- return val;
-
-fail_comb:
- Py_XDECREF(val);
- Py_XDECREF(dump_var);
- Py_XDECREF(temp_obj1);
- Py_XDECREF(temp_obj2);
+done:
+ Py_DECREF(n);
+ Py_DECREF(k);
+ return result;
+error:
+ Py_XDECREF(factor);
+ Py_XDECREF(result);
+ Py_DECREF(n);
+ Py_DECREF(k);
return NULL;
}