diff options
Diffstat (limited to 'dns')
-rw-r--r-- | dns/asyncbackend.py | 74 |
1 files changed, 58 insertions, 16 deletions
diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py index 92a1ae3..069eaf0 100644 --- a/dns/asyncbackend.py +++ b/dns/asyncbackend.py @@ -6,16 +6,51 @@ from dns._asyncbackend import Socket, DatagramSocket, \ _default_backend = None +_trio_backend = None +_curio_backend = None +_asyncio_backend = None -def get_default_backend(): - if _default_backend: - return _default_backend +def get_backend(name): + """Get the specified asychronous backend. - return set_default_backend(sniff()) + *name*, a ``str``, the name of the backend. Currently the "trio", + "curio", and "asyncio" backends are available. + + Raises NotImplementError if an unknown backend name is specified. + """ + if name == 'trio': + global _trio_backend + if _trio_backend: + return _trio_backend + import dns._trio_backend + _trio_backend = dns._trio_backend.Backend() + return _trio_backend + elif name == 'curio': + global _curio_backend + if _curio_backend: + return _curio_backend + import dns._curio_backend + _curio_backend = dns._curio_backend.Backend() + return _curio_backend + elif name == 'asyncio': + global _asyncio_backend + if _asyncio_backend: + return _asyncio_backend + import dns._asyncio_backend + _asyncio_backend = dns._asyncio_backend.Backend() + return _asyncio_backend + else: + raise NotImplementedError(f'unimplemented async backend {name}') def sniff(): + """Attempt to determine the in-use asynchronous I/O library by using + the ``sniffio`` module if it is available. + + Returns the name of the library, defaulting to "asyncio" if no other + library appears to be in use. + """ name = 'asyncio' try: import sniffio @@ -25,19 +60,26 @@ def sniff(): return name +def get_default_backend(): + """Get the default backend, initializing it if necessary. + """ + if _default_backend: + return _default_backend + + return set_default_backend(sniff()) + + def set_default_backend(name): + """Set the default backend. + + It's not normally necessary to call this method, as + ``get_default_backend()`` will initialize the backend + appropriately in many cases. If ``sniffio`` is not installed, or + in testing situations, this function allows the backend to be set + explicitly. + """ global _default_backend + _default_backend = get_backend(name) + return _default_backend - if name == 'trio': - import dns._trio_backend - _default_backend = dns._trio_backend.Backend() - elif name == 'curio': - import dns._curio_backend - _default_backend = dns._curio_backend.Backend() - elif name == 'asyncio': - import dns._asyncio_backend - _default_backend = dns._asyncio_backend.Backend() - else: - raise NotImplementedException(f'unimplemented async backend {name}') - return _default_backend |