# cgit view: summary | refs | log | tree | commit | diff
# path: root/examples/adjacency_list/adjacency_list.py
# blob: a0683ea0c70000fde92937109131298f34a3f0d3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from sqlalchemy import MetaData, Table, Column, Sequence, ForeignKey,\
                        Integer, String, create_engine

from sqlalchemy.orm import sessionmaker, relationship, backref,\
                                joinedload_all
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.collections import attribute_mapped_collection


# Declarative base class; mapped classes in this module derive from it, and
# its ``metadata`` is used further down to create the tables.
Base = declarative_base()

class TreeNode(Base):
    """A node of a self-referential tree stored in the ``tree`` table.

    Every row refers to its parent row through ``parent_id``.  The
    ``children`` relationship exposes the child nodes as a dictionary
    keyed on each child's ``name``; its backref supplies ``parent``.
    """

    __tablename__ = 'tree'

    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey(id))
    name = Column(String(50), nullable=False)

    children = relationship(
        "TreeNode",
        # children are held in a dictionary keyed on their "name"
        # attribute.
        collection_class=attribute_mapped_collection('name'),
        # the reverse, many-to-one side of an adjacency list needs
        # remote_side to identify the 'remote' column of the join
        # condition.
        backref=backref("parent", remote_side=id),
        # cascade operations, deletions included, down to the children.
        cascade="all",
    )

    def __init__(self, name, parent=None):
        """Create a node called *name*, optionally attached to *parent*."""
        # The name has to be assigned first: setting ``parent`` places this
        # node into the parent's ``children`` dictionary, which is keyed on
        # the ``name`` attribute.
        self.name = name
        self.parent = parent

    def __repr__(self):
        """Return a debugging string with the name, id and parent_id."""
        fields = (self.name, self.id, self.parent_id)
        return "TreeNode(name=%r, id=%r, parent_id=%r)" % fields

    def dump(self, _indent=0):
        """Return this subtree as text, one node per line.

        Each nesting level is indented by three more spaces; ``_indent``
        is the current depth and is only meant to be passed by the
        recursive calls.
        """
        pieces = ["   " * _indent + repr(self) + "\n"]
        for child in self.children.values():
            pieces.append(child.dump(_indent + 1))
        return "".join(pieces)

if __name__ == '__main__':
    # Demonstration script: build a tree, persist it, modify it, reload it
    # with eager loading and finally delete it.  echo=True makes the engine
    # log the SQL it emits along the way.
    engine = create_engine('sqlite://', echo=True)

    def msg(template, *args):
        """Print ``template % args`` framed by two dashed rules."""
        text = template % args
        # the rules are exactly as wide as the first line of the text.
        rule = "-" * len(text.split("\n")[0])
        print("\n\n\n" + rule)
        print(text)
        print(rule)

    msg("Creating Tree Table:")
    Base.metadata.create_all(engine)

    # expire_on_commit=False keeps the session's objects from being
    # expired after each transaction commit.
    Session = sessionmaker(engine, expire_on_commit=False)
    session = Session()

    # build the initial tree purely in memory; passing parent= attaches
    # each new node to its parent.
    node = TreeNode('rootnode')
    for child_name in ('node1', 'node3'):
        TreeNode(child_name, parent=node)

    node2 = TreeNode('node2')
    TreeNode('subnode1', parent=node2)
    # children can also be attached through the dictionary directly.
    node.children['node2'] = node2
    TreeNode('subnode2', parent=node.children['node2'])

    msg("Created new tree structure:\n%s", node.dump())

    msg("flush + commit:")
    session.add(node)
    session.commit()

    msg("Tree After Save:\n %s", node.dump())

    # grow a new branch below the already persisted root.
    TreeNode('node4', parent=node)
    for child_name in ('subnode3', 'subnode4'):
        TreeNode(child_name, parent=node.children['node4'])
    TreeNode('subsubnode1', parent=node.children['node4'].children['subnode3'])

    # mark node1 as deleted; the commit below removes it.
    session.delete(node.children['node1'])

    msg("Removed node1.  flush + commit:")
    session.commit()

    # expire the "children" collection so that the next access
    # reflects the deletion of "node1".
    session.expire(node, ['children'])

    msg("Tree after save:\n %s", node.dump())

    msg("Emptying out the session entirely, "
        "selecting tree on root, using eager loading to join four levels deep.")
    session.expunge_all()

    # one "children" entry per level of eager loading.
    four_levels = joinedload_all("children", "children",
                                 "children", "children")
    node = (
        session.query(TreeNode)
        .options(four_levels)
        .filter(TreeNode.name == "rootnode")
        .first()
    )

    msg("Full Tree:\n%s", node.dump())

    msg("Marking root node as deleted, flush + commit:")
    session.delete(node)
    session.commit()