summaryrefslogtreecommitdiff
path: root/kombu/pools.py
blob: f2d92678c4855e64d5724d5efe86665961604807 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from kombu.connection import Resource
from kombu.messaging import Producer

from itertools import chain

# Public names exported by a star-import of this module.
__all__ = ["ProducerPool", "connections", "producers", "set_limit", "reset"]
# Current pool size limit. Kept in a one-element list so ``set_limit`` can
# change it in place (``_limit[0] = ...``) without a ``global`` statement.
_limit = [200]
# Every group handed to ``register_group``; ``_all_pools`` and ``reset``
# walk this list.
_groups = []


def register_group(group):
    """Add *group* to the module-level registry and hand it back.

    Returning the argument allows the call to be used inline, e.g. on the
    right-hand side of an assignment.
    """
    _groups.append(group)
    return group


class ProducerPool(Resource):
    """Resource pool of producers backed by a pool of connections.

    A producer only holds a connection between :meth:`prepare` and
    :meth:`release`; the connection itself is borrowed from
    ``self.connections``.
    """

    #: Class used to instantiate producers; subclasses may override it.
    Producer = Producer

    def __init__(self, connections, *args, **kwargs):
        """Remember the connection pool and initialize the base resource.

        :param connections: pool that producer connections are acquired from.

        Any further arguments are passed through to ``Resource.__init__``.
        """
        self.connections = connections
        super(ProducerPool, self).__init__(*args, **kwargs)

    def create_producer(self):
        """Acquire a connection and return a new producer bound to it.

        If constructing the producer fails, the connection is released
        again before the error propagates, so it is not leaked.
        """
        conn = self.connections.acquire(block=True)
        try:
            producer = self.Producer(conn)
        except BaseException:
            conn.release()
            raise
        producer.connection = conn
        return producer

    def new(self):
        """Return a zero-argument callable that creates a producer on demand."""
        return lambda: self.create_producer()

    def setup(self):
        """Pre-fill the resource queue with ``limit`` lazy producer factories."""
        if self.limit:
            for _ in xrange(self.limit):
                self._resource.put_nowait(self.new())

    def prepare(self, p):
        """Turn a queue entry into a producer that has a connection.

        :param p: either a producer or a factory callable produced by
            :meth:`new`; a callable is invoked to obtain the producer.
        :returns: the producer, with ``connection`` set.

        If reviving the producer on a freshly acquired connection fails,
        the connection is detached and released before re-raising.
        """
        if callable(p):
            p = p()
        if not p.connection:
            conn = self.connections.acquire(block=True)
            p.connection = conn
            try:
                p.revive(conn.default_channel)
            except BaseException:
                # Do not leave a half-initialized producer holding a
                # connection nobody will ever release.
                p.connection = None
                conn.release()
                raise
        return p

    def release(self, resource):
        """Give the producer's connection back, then release the producer.

        A producer without a connection is tolerated instead of failing
        with ``AttributeError`` on ``None``.
        """
        conn = resource.connection
        if conn is not None:
            conn.release()
        resource.connection = None
        super(ProducerPool, self).release(resource)


class HashingDict(dict):
    """Dict that stores entries under ``hash(key)`` rather than the key.

    Only item access, assignment and deletion are redirected through the
    hash; the stored keys (as seen by iteration) are the integer hashes.
    Distinct keys with equal hashes therefore share a single entry.
    """

    def __getitem__(self, key):
        """Return the value stored for *key*.

        When no entry exists, ``__missing__(key)`` is called with the
        original key if the instance provides it; otherwise ``KeyError``
        is raised, as a regular dict would.
        """
        h = hash(key)
        if h not in self:
            # ``dict`` itself has no ``__missing__``; only call the hook
            # when a subclass actually provides one.
            missing = getattr(self, "__missing__", None)
            if missing is None:
                raise KeyError(key)
            return missing(key)
        return dict.__getitem__(self, h)

    def __setitem__(self, key, value):
        """Store *value* under the hash of *key*."""
        return dict.__setitem__(self, hash(key), value)

    def __delitem__(self, key):
        """Delete the entry stored under the hash of *key*."""
        return dict.__delitem__(self, hash(key))


class PoolGroup(HashingDict):
    """Mapping that builds the pool for a resource on first lookup."""

    def create(self, resource, limit):
        """Build the pool for *resource* using *limit*; subclasses override."""
        raise NotImplementedError("PoolGroups must define ``create``")

    def __missing__(self, resource):
        """Create, store and return the pool for an unseen *resource*.

        The pool is created with the current module-wide limit.
        """
        pool = self.create(resource, _limit[0])
        self[resource] = pool
        return pool


class _Connections(PoolGroup):
    """Pool group whose entries are created from ``connection.Pool``."""

    def create(self, connection, limit):
        """Return a new pool for *connection* built with *limit*."""
        pool = connection.Pool(limit=limit)
        return pool


# Module-wide group of connection pools; registering it makes it visible to
# ``_all_pools`` and therefore to ``set_limit`` and ``reset``.
connections = register_group(_Connections())


class _Producers(PoolGroup):
    """Pool group holding one :class:`ProducerPool` per connection.

    Derives from ``PoolGroup`` (not bare ``HashingDict``) so that looking up
    an unseen connection goes through ``__missing__`` and calls ``create``.
    """

    def create(self, connection, limit):
        """Return a producer pool drawing on *connection*'s connection pool."""
        return ProducerPool(connections[connection], limit=limit)


# Module-wide group of producer pools, registered like ``connections``.
producers = register_group(_Producers())


def _all_pools():
    """Return a single iterator over the pools of all registered groups.

    The per-group value iterators are created immediately, when this
    function is called; only the traversal itself is lazy.
    """
    per_group = []
    for group in _groups:
        if group:
            per_group.append(group.itervalues())
        else:
            per_group.append(iter([]))
    return chain(*per_group)


def set_limit(limit):
    """Change the module-wide pool limit and reset all pools.

    The new value is recorded for pools created later, written to the
    ``limit`` attribute of every existing pool, and then ``reset`` is
    called. Returns *limit* unchanged.
    """
    _limit[0] = limit
    for existing_pool in _all_pools():
        existing_pool.limit = limit
    reset()
    return limit


def reset(*args, **kwargs):
    """Force-close every known pool and empty all registered groups.

    Positional and keyword arguments are accepted and ignored, which lets
    the function be used as a callback regardless of what it is passed.
    """
    for open_pool in _all_pools():
        try:
            open_pool.force_close_all()
        except Exception:
            # Best effort: one pool failing to close must not prevent the
            # remaining pools and groups from being reset.
            pass
    for registered in _groups:
        registered.clear()


# Ask multiprocessing to call ``reset`` after a fork, registered against the
# ``connections`` group.
# NOTE(review): presumably so a child process does not keep using pools
# inherited from its parent -- confirm.
try:
    from multiprocessing.util import register_after_fork
    register_after_fork(connections, reset)
except ImportError:
    # multiprocessing could not be imported; run without the after-fork hook.
    pass