diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/cextension/utils.c | 249 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/util.py | 104 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/profiling.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/requirements.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/_collections.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 11 |
7 files changed, 80 insertions, 324 deletions
diff --git a/lib/sqlalchemy/cextension/utils.c b/lib/sqlalchemy/cextension/utils.c deleted file mode 100644 index e06843c9d..000000000 --- a/lib/sqlalchemy/cextension/utils.c +++ /dev/null @@ -1,249 +0,0 @@ -/* -utils.c -Copyright (C) 2012-2021 the SQLAlchemy authors and contributors <see AUTHORS file> - -This module is part of SQLAlchemy and is released under -the MIT License: http://www.opensource.org/licenses/mit-license.php -*/ - -#include <Python.h> - -#define MODULE_NAME "cutils" -#define MODULE_DOC "Module containing C versions of utility functions." - -/* - Given arguments from the calling form *multiparams, **params, - return a list of bind parameter structures, usually a list of - dictionaries. - - In the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists. - - */ -static PyObject * -distill_params(PyObject *self, PyObject *args) -{ - // TODO: pass the Connection in so that there can be a standard - // method for warning on parameter format - - PyObject *connection, *multiparams, *params; - PyObject *enclosing_list, *double_enclosing_list; - PyObject *zero_element, *zero_element_item; - PyObject *tmp; - Py_ssize_t multiparam_size, zero_element_length; - - if (!PyArg_UnpackTuple(args, "_distill_params", 3, 3, &connection, &multiparams, ¶ms)) { - return NULL; - } - - if (multiparams != Py_None) { - multiparam_size = PyTuple_Size(multiparams); - if (multiparam_size < 0) { - return NULL; - } - } - else { - multiparam_size = 0; - } - - if (multiparam_size == 0) { - if (params != Py_None && PyMapping_Size(params) != 0) { - - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(params); - if (PyList_SetItem(enclosing_list, 0, params) == -1) { - Py_DECREF(params); - Py_DECREF(enclosing_list); - return NULL; - } - } - else { - enclosing_list = PyList_New(0); - if (enclosing_list == NULL) { - return NULL; - } - } - return enclosing_list; - } - else if (multiparam_size == 1) { - zero_element = PyTuple_GetItem(multiparams, 0); - if (PyTuple_Check(zero_element) || PyList_Check(zero_element)) { - zero_element_length = PySequence_Length(zero_element); - - if (zero_element_length != 0) { - zero_element_item = PySequence_GetItem(zero_element, 0); - if (zero_element_item == NULL) { - return NULL; - } - } - else { - zero_element_item = NULL; - } - - if (zero_element_length == 0 || - ( - PyObject_HasAttrString(zero_element_item, "__iter__") && - !PyObject_HasAttrString(zero_element_item, "strip") - ) - ) { - /* - * execute(stmt, [{}, {}, {}, ...]) - * execute(stmt, [(), (), (), ...]) - */ - Py_XDECREF(zero_element_item); - Py_INCREF(zero_element); - return zero_element; - } - else { - /* - * execute(stmt, ("value", "value")) - */ - Py_XDECREF(zero_element_item); - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } - } - else if (PyObject_HasAttrString(zero_element, "keys")) { - /* - * execute(stmt, {"key":"value"}) - */ - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } else { - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - double_enclosing_list = PyList_New(1); - if (double_enclosing_list == NULL) { - Py_DECREF(enclosing_list); - return NULL; - } - Py_INCREF(zero_element); - if (PyList_SetItem(enclosing_list, 0, zero_element) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - Py_DECREF(double_enclosing_list); - return NULL; - } - if (PyList_SetItem(double_enclosing_list, 0, enclosing_list) == -1) { - Py_DECREF(zero_element); - Py_DECREF(enclosing_list); - Py_DECREF(double_enclosing_list); - return NULL; - } - return double_enclosing_list; - } - } - else { - - tmp = PyObject_CallMethod(connection, "_warn_for_legacy_exec_format", ""); - if (tmp == NULL) { - return NULL; - } - - zero_element = PyTuple_GetItem(multiparams, 0); - if (PyObject_HasAttrString(zero_element, "__iter__") && - !PyObject_HasAttrString(zero_element, "strip") - ) { - Py_INCREF(multiparams); - return multiparams; - } - else { - enclosing_list = PyList_New(1); - if (enclosing_list == NULL) { - return NULL; - } - Py_INCREF(multiparams); - if (PyList_SetItem(enclosing_list, 0, multiparams) == -1) { - Py_DECREF(multiparams); - Py_DECREF(enclosing_list); - return NULL; - } - return enclosing_list; - } - } -} - -static PyMethodDef module_methods[] = { - {"_distill_params", distill_params, METH_VARARGS, - "Distill an execute() parameter structure."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif - -#if PY_MAJOR_VERSION >= 3 - -#define INITERROR return NULL - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - MODULE_NAME, - MODULE_DOC, - -1, - module_methods - }; - -PyMODINIT_FUNC -PyInit_cutils(void) - -#else - -#define INITERROR return - -PyMODINIT_FUNC -initcutils(void) - -#endif - -{ - PyObject *m; - -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&module_def); -#else - m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); -#endif - if (m == NULL) - INITERROR; - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} - diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index 4e302f464..ede263198 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -34,63 +34,55 @@ _no_tuple = () _no_kw = util.immutabledict() -def py_fallback(): - # TODO: pass the Connection in so that there can be a standard - # method for warning on parameter format - def _distill_params(connection, multiparams, params): # noqa - r"""Given arguments from the calling form \*multiparams, \**params, - return a list of bind parameter structures, usually a list of - dictionaries. - - In the case of 'raw' execution which accepts positional parameters, - it may be a list of tuples or lists. - - """ - - # C version will fail if this assertion is not true. - # assert isinstance(multiparams, tuple) - - if not multiparams: - if params: - connection._warn_for_legacy_exec_format() - return [params] +def _distill_params(connection, multiparams, params): + r"""Given arguments from the calling form \*multiparams, \**params, + return a list of bind parameter structures, usually a list of + dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + """ + + if not multiparams: + if params: + connection._warn_for_legacy_exec_format() + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero else: - return [] - elif len(multiparams) == 1: - zero = multiparams[0] - if isinstance(zero, (list, tuple)): - if ( - not zero - or hasattr(zero[0], "__iter__") - and not hasattr(zero[0], "strip") - ): - # execute(stmt, [{}, {}, {}, ...]) - # execute(stmt, [(), (), (), ...]) - return zero - else: - # this is used by exec_driver_sql only, so a deprecation - # warning would already be coming from passing a plain - # textual statement with positional parameters to - # execute(). - # execute(stmt, ("value", "value")) - return [zero] - elif hasattr(zero, "keys"): - # execute(stmt, {"key":"value"}) + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). + # execute(stmt, ("value", "value")) return [zero] - else: - connection._warn_for_legacy_exec_format() - # execute(stmt, "value") - return [[zero]] + elif hasattr(zero, "keys"): + # execute(stmt, {"key":"value"}) + return [zero] else: connection._warn_for_legacy_exec_format() - if hasattr(multiparams[0], "__iter__") and not hasattr( - multiparams[0], "strip" - ): - return multiparams - else: - return [multiparams] - - return locals() + # execute(stmt, "value") + return [[zero]] + else: + connection._warn_for_legacy_exec_format() + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): + return multiparams + else: + return [multiparams] def _distill_cursor_params(connection, multiparams, params): @@ -161,9 +153,3 @@ def _distill_params_20(params): return (params,), _no_kw else: raise exc.ArgumentError("mapping or sequence expected for parameters") - - -try: - from sqlalchemy.cutils import _distill_params # noqa -except ImportError: - globals().update(py_fallback()) diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py index 16c6d458c..5e4f19273 100644 --- a/lib/sqlalchemy/testing/profiling.py +++ b/lib/sqlalchemy/testing/profiling.py @@ -22,6 +22,7 @@ import sys from . import config from .util import gc_collect +from ..util import has_compiled_ext try: @@ -109,7 +110,7 @@ class ProfileStatsFile(object): if config.db.dialect.convert_unicode else "dbapiunicode" ) - _has_cext = config.requirements._has_cextensions() + _has_cext = has_compiled_ext() platform_tokens.append(_has_cext and "cextensions" or "nocextensions") return "_".join(platform_tokens) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index d8da9c818..f16ba326c 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1258,7 +1258,7 @@ class SuiteRequirements(Requirements): @property def cextensions(self): return exclusions.skip_if( - lambda: not self._has_cextensions(), "C extensions not installed" + lambda: not util.has_compiled_ext(), "C extensions not installed" ) def _has_sqlite(self): @@ -1270,14 +1270,6 @@ class SuiteRequirements(Requirements): except ImportError: return False - def _has_cextensions(self): - try: - from sqlalchemy import cresultproxy, cprocessors # noqa - - return True - except ImportError: - return False - @property def async_dialect(self): """dialect makes use of await_() to invoke operations on the DBAPI.""" diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 2d86b8b63..4b61658b2 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -140,6 +140,7 @@ from .langhelpers import get_callable_argspec # noqa from .langhelpers import get_cls_kwargs # noqa from .langhelpers import get_func_kwargs # noqa from .langhelpers import getargspec_init # noqa +from .langhelpers import has_compiled_ext # noqa from .langhelpers import HasMemoized # noqa from .langhelpers import hybridmethod # noqa from .langhelpers import hybridproperty # noqa diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index b18cc13de..7484a8f1a 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -363,7 +363,6 @@ else: class OrderedSet(set): def __init__(self, d=None): set.__init__(self) - self._list = [] if d is not None: self._list = unique_list(d) set.update(self, self._list) @@ -521,7 +520,10 @@ class IdentitySet(object): return True def issubset(self, iterable): - other = self.__class__(iterable) + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) if len(self) > len(other): return False @@ -542,7 +544,10 @@ class IdentitySet(object): return len(self) < len(other) and self.issubset(other) def issuperset(self, iterable): - other = self.__class__(iterable) + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) if len(self) < len(other): return False @@ -587,7 +592,10 @@ class IdentitySet(object): def difference(self, iterable): result = self.__class__() members = self._members - other = {id(obj) for obj in iterable} + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} result._members.update( ((k, v) for k, v in members.items() if k not in other) ) @@ -610,7 +618,10 @@ class IdentitySet(object): def intersection(self, iterable): result = self.__class__() members = self._members - other = {id(obj) for obj in iterable} + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} result._members.update( (k, v) for k, v in members.items() if k in other ) @@ -633,7 +644,10 @@ class IdentitySet(object): def symmetric_difference(self, iterable): result = self.__class__() members = self._members - other = {id(obj): obj for obj in iterable} + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = {id(obj): obj for obj in iterable} result._members.update( ((k, v) for k, v in members.items() if k not in other) ) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 457d2875d..eb582b528 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1879,3 +1879,14 @@ def repr_tuple_names(names): return ", ".join(res) else: return "%s, ..., %s" % (", ".join(res[0:3]), res[-1]) + + +def has_compiled_ext(): + try: + from sqlalchemy import cimmutabledict # noqa F401 + from sqlalchemy import cprocessors # noqa F401 + from sqlalchemy import cresultproxy # noqa F401 + + return True + except ImportError: + return False |
