summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-07-17 16:08:11 -0700
committerGitHub <noreply@github.com>2020-07-17 16:08:11 -0700
commitac3f05d0d5b61ba46444cb9e66deb7357ff8522e (patch)
treeb7e7b2f4b124e478d24edea6143e4e02b286ccac
parent65d201ea8c53e03f8c7b732b21ea47f7e9188cab (diff)
parent147924d0a433968c639f75630009eff8a872a4d3 (diff)
downloaddnspython-ac3f05d0d5b61ba46444cb9e66deb7357ff8522e.tar.gz
Merge pull request #538 from bwelling/selectors
Use the selectors module.
-rw-r--r--dns/query.py97
-rw-r--r--tests/test_query.py14
-rw-r--r--tests/test_resolver.py18
3 files changed, 35 insertions, 94 deletions
diff --git a/dns/query.py b/dns/query.py
index 7df565d..eb82771 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -20,7 +20,7 @@
import contextlib
import errno
import os
-import select
+import selectors
import socket
import struct
import time
@@ -94,91 +94,46 @@ def _compute_times(timeout):
else:
return (now, now + timeout)
-# This module can use either poll() or select() as the "polling backend".
-#
-# A backend function takes an fd, bools for readability, writablity, and
-# error detection, and a timeout.
-
-def _poll_for(fd, readable, writable, error, timeout):
- """Poll polling backend."""
-
- event_mask = 0
- if readable:
- event_mask |= select.POLLIN
- if writable:
- event_mask |= select.POLLOUT
- if error:
- event_mask |= select.POLLERR
-
- pollable = select.poll()
- pollable.register(fd, event_mask)
-
- if timeout:
- event_list = pollable.poll(timeout * 1000)
- else:
- event_list = pollable.poll()
-
- return bool(event_list)
-
-
-def _select_for(fd, readable, writable, error, timeout):
- """Select polling backend."""
-
- rset, wset, xset = [], [], []
-
- if readable:
- rset = [fd]
- if writable:
- wset = [fd]
- if error:
- xset = [fd]
-
- if timeout is None:
- (rcount, wcount, xcount) = select.select(rset, wset, xset)
- else:
- (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
-
- return bool((rcount or wcount or xcount))
-
def _wait_for(fd, readable, writable, error, expiration):
- # Use the selected polling backend to wait for any of the specified
+ # Use the selected selector class to wait for any of the specified
# events. An "expiration" absolute time is converted into a relative
# timeout.
- done = False
- while not done:
- if expiration is None:
- timeout = None
- else:
- timeout = expiration - time.time()
- if timeout <= 0.0:
- raise dns.exception.Timeout
- try:
- if isinstance(fd, ssl.SSLSocket) and readable and fd.pending() > 0:
- return True
- if not _polling_backend(fd, readable, writable, error, timeout):
- raise dns.exception.Timeout
- except OSError as e: # pragma: no cover
- if e.args[0] != errno.EINTR:
- raise e
- done = True
+ if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
+ return True
+ sel = _selector_class()
+ events = 0
+ if readable:
+ events |= selectors.EVENT_READ
+ if writable:
+ events |= selectors.EVENT_WRITE
+ if events:
+ sel.register(fd, events)
+ if expiration is None:
+ timeout = None
+ else:
+ timeout = expiration - time.time()
+ if timeout <= 0.0:
+ raise dns.exception.Timeout
+ if not sel.select(timeout):
+ raise dns.exception.Timeout
-def _set_polling_backend(fn):
+def _set_selector_class(selector_class):
# Internal API. Do not use.
- global _polling_backend
+ global _selector_class
- _polling_backend = fn
+ _selector_class = selector_class
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
# Prefer poll() on platforms that support it because it has no
# limits on the maximum value of a file descriptor (plus it will
# be more efficient for high values).
- _polling_backend = _poll_for
+ _selector_class = selectors.PollSelector
else:
- _polling_backend = _select_for # pragma: no cover
+ _selector_class = selectors.SelectSelector # pragma: no cover
def _wait_for_readable(s, expiration):
diff --git a/tests/test_query.py b/tests/test_query.py
index 498128d..a13833e 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -540,17 +540,3 @@ class LowLevelWaitTests(unittest.TestCase):
finally:
l.close()
r.close()
-
- def test_select_for(self):
- # we test this explicitly in case _wait_for didn't test it (i.e.
- # if the default polling backing is _poll_for)
- try:
- (l, r) = socket.socketpair()
- # simple timeout
- self.assertFalse(dns.query._select_for(l, False, False, False,
- 0.05))
- # writable no timeout
- self.assertTrue(dns.query._select_for(l, False, True, False, None))
- finally:
- l.close()
- r.close()
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index a6ab473..cadf224 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -16,7 +16,7 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from io import StringIO
-import select
+import selectors
import sys
import socket
import time
@@ -477,26 +477,26 @@ class LiveResolverTests(unittest.TestCase):
class PollingMonkeyPatchMixin(object):
def setUp(self):
- self.__native_polling_backend = dns.query._polling_backend
- dns.query._set_polling_backend(self.polling_backend())
+ self.__native_selector_class = dns.query._selector_class
+ dns.query._set_selector_class(self.selector_class())
unittest.TestCase.setUp(self)
def tearDown(self):
- dns.query._set_polling_backend(self.__native_polling_backend)
+ dns.query._set_selector_class(self.__native_selector_class)
unittest.TestCase.tearDown(self)
class SelectResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
- def polling_backend(self):
- return dns.query._select_for
+ def selector_class(self):
+ return selectors.SelectSelector
-if hasattr(select, 'poll'):
+if hasattr(selectors, 'PollSelector'):
class PollResolverTestCase(PollingMonkeyPatchMixin, LiveResolverTests, unittest.TestCase):
- def polling_backend(self):
- return dns.query._poll_for
+ def selector_class(self):
+ return selectors.PollSelector
class NXDOMAINExceptionTestCase(unittest.TestCase):