summaryrefslogtreecommitdiff
path: root/examples/nested_sets/nested_sets.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/nested_sets/nested_sets.py')
-rw-r--r--examples/nested_sets/nested_sets.py108
1 files changed, 66 insertions, 42 deletions
diff --git a/examples/nested_sets/nested_sets.py b/examples/nested_sets/nested_sets.py
index c64b15b61..705a3d279 100644
--- a/examples/nested_sets/nested_sets.py
+++ b/examples/nested_sets/nested_sets.py
@@ -4,19 +4,27 @@ http://www.intelligententerprise.com/001020/celko.jhtml
"""
-from sqlalchemy import (create_engine, Column, Integer, String, select, case,
- func)
+from sqlalchemy import (
+ create_engine,
+ Column,
+ Integer,
+ String,
+ select,
+ case,
+ func,
+)
from sqlalchemy.orm import Session, aliased
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import event
Base = declarative_base()
+
class Employee(Base):
- __tablename__ = 'personnel'
+ __tablename__ = "personnel"
__mapper_args__ = {
- 'batch': False # allows extension to fire for each
- # instance before going to the next.
+ "batch": False # allows extension to fire for each
+ # instance before going to the next.
}
parent = None
@@ -29,6 +37,7 @@ class Employee(Base):
def __repr__(self):
return "Employee(%s, %d, %d)" % (self.emp, self.left, self.right)
+
@event.listens_for(Employee, "before_insert")
def before_insert(mapper, connection, instance):
if not instance.parent:
@@ -37,23 +46,31 @@ def before_insert(mapper, connection, instance):
else:
personnel = mapper.mapped_table
right_most_sibling = connection.scalar(
- select([personnel.c.rgt]).
- where(personnel.c.emp == instance.parent.emp)
+ select([personnel.c.rgt]).where(
+ personnel.c.emp == instance.parent.emp
+ )
)
connection.execute(
- personnel.update(
- personnel.c.rgt >= right_most_sibling).values(
- lft=case(
- [(personnel.c.lft > right_most_sibling,
- personnel.c.lft + 2)],
- else_=personnel.c.lft
- ),
- rgt=case(
- [(personnel.c.rgt >= right_most_sibling,
- personnel.c.rgt + 2)],
- else_=personnel.c.rgt
- )
+ personnel.update(personnel.c.rgt >= right_most_sibling).values(
+ lft=case(
+ [
+ (
+ personnel.c.lft > right_most_sibling,
+ personnel.c.lft + 2,
+ )
+ ],
+ else_=personnel.c.lft,
+ ),
+ rgt=case(
+ [
+ (
+ personnel.c.rgt >= right_most_sibling,
+ personnel.c.rgt + 2,
+ )
+ ],
+ else_=personnel.c.rgt,
+ ),
)
)
instance.left = right_most_sibling
@@ -62,18 +79,19 @@ def before_insert(mapper, connection, instance):
# before_update() would be needed to support moving of nodes
# after_delete() would be needed to support removal of nodes.
-engine = create_engine('sqlite://', echo=True)
+
+engine = create_engine("sqlite://", echo=True)
Base.metadata.create_all(engine)
session = Session(bind=engine)
-albert = Employee(emp='Albert')
-bert = Employee(emp='Bert')
-chuck = Employee(emp='Chuck')
-donna = Employee(emp='Donna')
-eddie = Employee(emp='Eddie')
-fred = Employee(emp='Fred')
+albert = Employee(emp="Albert")
+bert = Employee(emp="Bert")
+chuck = Employee(emp="Chuck")
+donna = Employee(emp="Donna")
+eddie = Employee(emp="Eddie")
+fred = Employee(emp="Fred")
bert.parent = albert
chuck.parent = albert
@@ -90,22 +108,28 @@ print(session.query(Employee).all())
# 1. Find an employee and all their supervisors, no matter how deep the tree.
ealias = aliased(Employee)
-print(session.query(Employee).\
- filter(ealias.left.between(Employee.left, Employee.right)).\
- filter(ealias.emp == 'Eddie').all())
-
-#2. Find the employee and all their subordinates.
+print(
+ session.query(Employee)
+ .filter(ealias.left.between(Employee.left, Employee.right))
+ .filter(ealias.emp == "Eddie")
+ .all()
+)
+
+# 2. Find the employee and all their subordinates.
# (This query has a nice symmetry with the first query.)
-print(session.query(Employee).\
- filter(Employee.left.between(ealias.left, ealias.right)).\
- filter(ealias.emp == 'Chuck').all())
-
-#3. Find the level of each node, so you can print the tree
+print(
+ session.query(Employee)
+ .filter(Employee.left.between(ealias.left, ealias.right))
+ .filter(ealias.emp == "Chuck")
+ .all()
+)
+
+# 3. Find the level of each node, so you can print the tree
# as an indented listing.
-for indentation, employee in session.query(
- func.count(Employee.emp).label('indentation') - 1, ealias).\
- filter(ealias.left.between(Employee.left, Employee.right)).\
- group_by(ealias.emp).\
- order_by(ealias.left):
+for indentation, employee in (
+ session.query(func.count(Employee.emp).label("indentation") - 1, ealias)
+ .filter(ealias.left.between(Employee.left, Employee.right))
+ .group_by(ealias.emp)
+ .order_by(ealias.left)
+):
print(" " * indentation + str(employee))
-