diff options
Diffstat (limited to 'examples/nested_sets/nested_sets.py')
| -rw-r--r-- | examples/nested_sets/nested_sets.py | 108 |
1 files changed, 66 insertions, 42 deletions
diff --git a/examples/nested_sets/nested_sets.py b/examples/nested_sets/nested_sets.py index c64b15b61..705a3d279 100644 --- a/examples/nested_sets/nested_sets.py +++ b/examples/nested_sets/nested_sets.py @@ -4,19 +4,27 @@ http://www.intelligententerprise.com/001020/celko.jhtml """ -from sqlalchemy import (create_engine, Column, Integer, String, select, case, - func) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + select, + case, + func, +) from sqlalchemy.orm import Session, aliased from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import event Base = declarative_base() + class Employee(Base): - __tablename__ = 'personnel' + __tablename__ = "personnel" __mapper_args__ = { - 'batch': False # allows extension to fire for each - # instance before going to the next. + "batch": False # allows extension to fire for each + # instance before going to the next. } parent = None @@ -29,6 +37,7 @@ class Employee(Base): def __repr__(self): return "Employee(%s, %d, %d)" % (self.emp, self.left, self.right) + @event.listens_for(Employee, "before_insert") def before_insert(mapper, connection, instance): if not instance.parent: @@ -37,23 +46,31 @@ def before_insert(mapper, connection, instance): else: personnel = mapper.mapped_table right_most_sibling = connection.scalar( - select([personnel.c.rgt]). - where(personnel.c.emp == instance.parent.emp) + select([personnel.c.rgt]).where( + personnel.c.emp == instance.parent.emp + ) ) connection.execute( - personnel.update( - personnel.c.rgt >= right_most_sibling).values( - lft=case( - [(personnel.c.lft > right_most_sibling, - personnel.c.lft + 2)], - else_=personnel.c.lft - ), - rgt=case( - [(personnel.c.rgt >= right_most_sibling, - personnel.c.rgt + 2)], - else_=personnel.c.rgt - ) + personnel.update(personnel.c.rgt >= right_most_sibling).values( + lft=case( + [ + ( + personnel.c.lft > right_most_sibling, + personnel.c.lft + 2, + ) + ], + else_=personnel.c.lft, + ), + rgt=case( + [ + ( + personnel.c.rgt >= right_most_sibling, + personnel.c.rgt + 2, + ) + ], + else_=personnel.c.rgt, + ), ) ) instance.left = right_most_sibling @@ -62,18 +79,19 @@ def before_insert(mapper, connection, instance): # before_update() would be needed to support moving of nodes # after_delete() would be needed to support removal of nodes. -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(bind=engine) -albert = Employee(emp='Albert') -bert = Employee(emp='Bert') -chuck = Employee(emp='Chuck') -donna = Employee(emp='Donna') -eddie = Employee(emp='Eddie') -fred = Employee(emp='Fred') +albert = Employee(emp="Albert") +bert = Employee(emp="Bert") +chuck = Employee(emp="Chuck") +donna = Employee(emp="Donna") +eddie = Employee(emp="Eddie") +fred = Employee(emp="Fred") bert.parent = albert chuck.parent = albert @@ -90,22 +108,28 @@ print(session.query(Employee).all()) # 1. Find an employee and all their supervisors, no matter how deep the tree. ealias = aliased(Employee) -print(session.query(Employee).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - filter(ealias.emp == 'Eddie').all()) - -#2. Find the employee and all their subordinates. +print( + session.query(Employee) + .filter(ealias.left.between(Employee.left, Employee.right)) + .filter(ealias.emp == "Eddie") + .all() +) + +# 2. Find the employee and all their subordinates. # (This query has a nice symmetry with the first query.) -print(session.query(Employee).\ - filter(Employee.left.between(ealias.left, ealias.right)).\ - filter(ealias.emp == 'Chuck').all()) - -#3. Find the level of each node, so you can print the tree +print( + session.query(Employee) + .filter(Employee.left.between(ealias.left, ealias.right)) + .filter(ealias.emp == "Chuck") + .all() +) + +# 3. Find the level of each node, so you can print the tree # as an indented listing. -for indentation, employee in session.query( - func.count(Employee.emp).label('indentation') - 1, ealias).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - group_by(ealias.emp).\ - order_by(ealias.left): +for indentation, employee in ( + session.query(func.count(Employee.emp).label("indentation") - 1, ealias) + .filter(ealias.left.between(Employee.left, Employee.right)) + .group_by(ealias.emp) + .order_by(ealias.left) +): print(" " * indentation + str(employee)) - |
