diff options
author | Bob Halley <halley@dnspython.org> | 2020-07-17 16:08:11 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-17 16:08:11 -0700 |
commit | ac3f05d0d5b61ba46444cb9e66deb7357ff8522e (patch) | |
tree | b7e7b2f4b124e478d24edea6143e4e02b286ccac | |
parent | 65d201ea8c53e03f8c7b732b21ea47f7e9188cab (diff) | |
parent | 147924d0a433968c639f75630009eff8a872a4d3 (diff) | |
download | dnspython-ac3f05d0d5b61ba46444cb9e66deb7357ff8522e.tar.gz |
Merge pull request #538 from bwelling/selectors
Use the selectors module.
-rw-r--r-- | dns/query.py | 97 | ||||
-rw-r--r-- | tests/test_query.py | 14 | ||||
-rw-r--r-- | tests/test_resolver.py | 18 |
3 files changed, 35 insertions, 94 deletions
diff --git a/dns/query.py b/dns/query.py index 7df565d..eb82771 100644 --- a/dns/query.py +++ b/dns/query.py @@ -20,7 +20,7 @@ import contextlib import errno import os -import select +import selectors import socket import struct import time @@ -94,91 +94,46 @@ def _compute_times(timeout): else: return (now, now + timeout) -# This module can use either poll() or select() as the "polling backend". -# -# A backend function takes an fd, bools for readability, writablity, and -# error detection, and a timeout. - -def _poll_for(fd, readable, writable, error, timeout): - """Poll polling backend.""" - - event_mask = 0 - if readable: - event_mask |= select.POLLIN - if writable: - event_mask |= select.POLLOUT - if error: - event_mask |= select.POLLERR - - pollable = select.poll() - pollable.register(fd, event_mask) - - if timeout: - event_list = pollable.poll(timeout * 1000) - else: - event_list = pollable.poll() - - return bool(event_list) - - -def _select_for(fd, readable, writable, error, timeout): - """Select polling backend.""" - - rset, wset, xset = [], [], [] - - if readable: - rset = [fd] - if writable: - wset = [fd] - if error: - xset = [fd] - - if timeout is None: - (rcount, wcount, xcount) = select.select(rset, wset, xset) - else: - (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout) - - return bool((rcount or wcount or xcount)) - def _wait_for(fd, readable, writable, error, expiration): - # Use the selected polling backend to wait for any of the specified + # Use the selected selector class to wait for any of the specified # events. An "expiration" absolute time is converted into a relative # timeout. - done = False - while not done: - if expiration is None: - timeout = None - else: - timeout = expiration - time.time() - if timeout <= 0.0: - raise dns.exception.Timeout - try: - if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0: - return True - if not _polling_backend(fd, readable, writable, error, timeout): - raise dns.exception.Timeout - except OSError as e: # pragma: no cover - if e.args[0] != errno.EINTR: - raise e - done = True + if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0: + return True + sel = _selector_class() + events = 0 + if readable: + events |= selectors.EVENT_READ + if writable: + events |= selectors.EVENT_WRITE + if events: + sel.register(fd, events) + if expiration is None: + timeout = None + else: + timeout = expiration - time.time() + if timeout <= 0.0: + raise dns.exception.Timeout + if not sel.select(timeout): + raise dns.exception.Timeout -def _set_polling_backend(fn): +def _set_selector_class(selector_class): # Internal API. Do not use. - global _polling_backend + global _selector_class - _polling_backend = fn + _selector_class = selector_class -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): # Prefer poll() on platforms that support it because it has no # limits on the maximum value of a file descriptor (plus it will # be more efficient for high values). - _polling_backend = _poll_for + _selector_class = selectors.PollSelector else: - _polling_backend = _select_for # pragma: no cover + _selector_class = selectors.SelectSelector # pragma: no cover def _wait_for_readable(s, expiration): diff --git a/tests/test_query.py b/tests/test_query.py index 498128d..a13833e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -540,17 +540,3 @@ class LowLevelWaitTests(unittest.TestCase): finally: l.close() r.close() - - def test_select_for(self): - # we test this explicitly in case _wait_for didn't test it (i.e. - # if the default polling backing is _poll_for) - try: - (l, r) = socket.socketpair() - # simple timeout - self.assertFalse(dns.query._select_for(l, False, False, False, - 0.05)) - # writable no timeout - self.assertTrue(dns.query._select_for(l, False, True, False, None)) - finally: - l.close() - r.close() diff --git a/tests/test_resolver.py b/tests/test_resolver.py index a6ab473..cadf224 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -16,7 +16,7 @@ # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. from io import StringIO -import select +import selectors import sys import socket import time @@ -477,26 +477,26 @@ class LiveResolverTests(unittest.TestCase): class PollingMonkeyPatchMixin(object): def setUp(self): - self.__native_polling_backend = dns.query._polling_backend - dns.query._set_polling_backend(self.polling_backend()) + self.__native_selector_class = dns.query._selector_class + dns.query._set_selector_class(self.selector_class()) unittest.TestCase.setUp(self) def tearDown(self): - dns.query._set_polling_backend(self.__native_polling_backend) + dns.query._set_selector_class(self.__native_selector_class) unittest.TestCase.tearDown(self) class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): - def polling_backend(self): - return dns.query._select_for + def selector_class(self): + return selectors.SelectSelector -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase): - def polling_backend(self): - return dns.query._poll_for + def selector_class(self): + return selectors.PollSelector class NXDOMAINExceptionTestCase(unittest.TestCase): |