diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /examples | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'examples')
62 files changed, 2435 insertions, 1742 deletions
diff --git a/examples/adjacency_list/adjacency_list.py b/examples/adjacency_list/adjacency_list.py index 3c9144323..f1628a632 100644 --- a/examples/adjacency_list/adjacency_list.py +++ b/examples/adjacency_list/adjacency_list.py @@ -8,7 +8,7 @@ Base = declarative_base() class TreeNode(Base): - __tablename__ = 'tree' + __tablename__ = "tree" id = Column(Integer, primary_key=True) parent_id = Column(Integer, ForeignKey(id)) name = Column(String(50), nullable=False) @@ -17,15 +17,13 @@ class TreeNode(Base): "TreeNode", # cascade deletions cascade="all, delete-orphan", - # many to one + adjacency list - remote_side # is required to reference the 'remote' # column in the join condition. backref=backref("parent", remote_side=id), - # children will be represented as a dictionary # on the "name" attribute. - collection_class=attribute_mapped_collection('name'), + collection_class=attribute_mapped_collection("name"), ) def __init__(self, name, parent=None): @@ -36,20 +34,20 @@ class TreeNode(Base): return "TreeNode(name=%r, id=%r, parent_id=%r)" % ( self.name, self.id, - self.parent_id + self.parent_id, ) def dump(self, _indent=0): - return " " * _indent + repr(self) + \ - "\n" + \ - "".join([ - c.dump(_indent + 1) - for c in self.children.values() - ]) + return ( + " " * _indent + + repr(self) + + "\n" + + "".join([c.dump(_indent + 1) for c in self.children.values()]) + ) -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) def msg(msg, *args): msg = msg % args @@ -63,14 +61,14 @@ if __name__ == '__main__': session = Session(engine) - node = TreeNode('rootnode') - TreeNode('node1', parent=node) - TreeNode('node3', parent=node) + node = TreeNode("rootnode") + TreeNode("node1", parent=node) + TreeNode("node3", parent=node) - node2 = TreeNode('node2') - TreeNode('subnode1', parent=node2) - node.children['node2'] = node2 - TreeNode('subnode2', parent=node.children['node2']) + node2 = TreeNode("node2") + TreeNode("subnode1", parent=node2) + node.children["node2"] = node2 + TreeNode("subnode2", parent=node.children["node2"]) msg("Created new tree structure:\n%s", node.dump()) @@ -81,28 +79,33 @@ if __name__ == '__main__': msg("Tree After Save:\n %s", node.dump()) - TreeNode('node4', parent=node) - TreeNode('subnode3', parent=node.children['node4']) - TreeNode('subnode4', parent=node.children['node4']) - TreeNode('subsubnode1', parent=node.children['node4'].children['subnode3']) + TreeNode("node4", parent=node) + TreeNode("subnode3", parent=node.children["node4"]) + TreeNode("subnode4", parent=node.children["node4"]) + TreeNode("subsubnode1", parent=node.children["node4"].children["subnode3"]) # remove node1 from the parent, which will trigger a delete # via the delete-orphan cascade. - del node.children['node1'] + del node.children["node1"] msg("Removed node1. flush + commit:") session.commit() msg("Tree after save:\n %s", node.dump()) - msg("Emptying out the session entirely, selecting tree on root, using " - "eager loading to join four levels deep.") + msg( + "Emptying out the session entirely, selecting tree on root, using " + "eager loading to join four levels deep." + ) session.expunge_all() - node = session.query(TreeNode).\ - options(joinedload_all("children", "children", - "children", "children")).\ - filter(TreeNode.name == "rootnode").\ - first() + node = ( + session.query(TreeNode) + .options( + joinedload_all("children", "children", "children", "children") + ) + .filter(TreeNode.name == "rootnode") + .first() + ) msg("Full Tree:\n%s", node.dump()) diff --git a/examples/association/basic_association.py b/examples/association/basic_association.py index 6714aa681..52476f184 100644 --- a/examples/association/basic_association.py +++ b/examples/association/basic_association.py @@ -12,8 +12,16 @@ of "items", with a particular price paid associated with each "item". from datetime import datetime -from sqlalchemy import (create_engine, Column, Integer, String, DateTime, - Float, ForeignKey, and_) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + Float, + ForeignKey, + and_, +) from sqlalchemy.orm import relationship, Session from sqlalchemy.ext.declarative import declarative_base @@ -21,20 +29,21 @@ Base = declarative_base() class Order(Base): - __tablename__ = 'order' + __tablename__ = "order" order_id = Column(Integer, primary_key=True) customer_name = Column(String(30), nullable=False) order_date = Column(DateTime, nullable=False, default=datetime.now()) - order_items = relationship("OrderItem", cascade="all, delete-orphan", - backref='order') + order_items = relationship( + "OrderItem", cascade="all, delete-orphan", backref="order" + ) def __init__(self, customer_name): self.customer_name = customer_name class Item(Base): - __tablename__ = 'item' + __tablename__ = "item" item_id = Column(Integer, primary_key=True) description = Column(String(30), nullable=False) price = Column(Float, nullable=False) @@ -44,41 +53,40 @@ class Item(Base): self.price = price def __repr__(self): - return 'Item(%r, %r)' % ( - self.description, self.price - ) + return "Item(%r, %r)" % (self.description, self.price) class OrderItem(Base): - __tablename__ = 'orderitem' - order_id = Column(Integer, ForeignKey('order.order_id'), primary_key=True) - item_id = Column(Integer, ForeignKey('item.item_id'), primary_key=True) + __tablename__ = "orderitem" + order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True) + item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) price = Column(Float, nullable=False) def __init__(self, item, price=None): self.item = item self.price = price or item.price - item = relationship(Item, lazy='joined') + + item = relationship(Item, lazy="joined") -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") Base.metadata.create_all(engine) session = Session(engine) # create catalog tshirt, mug, hat, crowbar = ( - Item('SA T-Shirt', 10.99), - Item('SA Mug', 6.50), - Item('SA Hat', 8.99), - Item('MySQL Crowbar', 16.99) + Item("SA T-Shirt", 10.99), + Item("SA Mug", 6.50), + Item("SA Hat", 8.99), + Item("MySQL Crowbar", 16.99), ) session.add_all([tshirt, mug, hat, crowbar]) session.commit() # create an order - order = Order('john smith') + order = Order("john smith") # add three OrderItem associations to the Order and save order.order_items.append(OrderItem(mug)) @@ -88,13 +96,18 @@ if __name__ == '__main__': session.commit() # query the order, print items - order = session.query(Order).filter_by(customer_name='john smith').one() - print([(order_item.item.description, order_item.price) - for order_item in order.order_items]) + order = session.query(Order).filter_by(customer_name="john smith").one() + print( + [ + (order_item.item.description, order_item.price) + for order_item in order.order_items + ] + ) # print customers who bought 'MySQL Crowbar' on sale - q = session.query(Order).join('order_items', 'item') - q = q.filter(and_(Item.description == 'MySQL Crowbar', - Item.price > OrderItem.price)) + q = session.query(Order).join("order_items", "item") + q = q.filter( + and_(Item.description == "MySQL Crowbar", Item.price > OrderItem.price) + ) print([order.customer_name for order in q]) diff --git a/examples/association/dict_of_sets_with_default.py b/examples/association/dict_of_sets_with_default.py index fb9b6aa06..7f668c087 100644 --- a/examples/association/dict_of_sets_with_default.py +++ b/examples/association/dict_of_sets_with_default.py @@ -37,7 +37,9 @@ class A(Base): __tablename__ = "a" associations = relationship( "B", - collection_class=lambda: GenDefaultCollection(operator.attrgetter("key")) + collection_class=lambda: GenDefaultCollection( + operator.attrgetter("key") + ), ) collections = association_proxy("associations", "values") @@ -71,19 +73,15 @@ class C(Base): self.value = value -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) # only "A" is referenced explicitly. Using "collections", # we deal with a dict of key/sets of integers directly. - session.add_all([ - A(collections={ - "1": set([1, 2, 3]), - }) - ]) + session.add_all([A(collections={"1": set([1, 2, 3])})]) session.commit() a1 = session.query(A).first() diff --git a/examples/association/proxied_association.py b/examples/association/proxied_association.py index 3393fdd1d..46785c6e2 100644 --- a/examples/association/proxied_association.py +++ b/examples/association/proxied_association.py @@ -7,8 +7,15 @@ to ``OrderItem`` optional. from datetime import datetime -from sqlalchemy import (create_engine, Column, Integer, String, DateTime, - Float, ForeignKey) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + DateTime, + Float, + ForeignKey, +) from sqlalchemy.orm import relationship, Session from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.associationproxy import association_proxy @@ -17,13 +24,14 @@ Base = declarative_base() class Order(Base): - __tablename__ = 'order' + __tablename__ = "order" order_id = Column(Integer, primary_key=True) customer_name = Column(String(30), nullable=False) order_date = Column(DateTime, nullable=False, default=datetime.now()) - order_items = relationship("OrderItem", cascade="all, delete-orphan", - backref='order') + order_items = relationship( + "OrderItem", cascade="all, delete-orphan", backref="order" + ) items = association_proxy("order_items", "item") def __init__(self, customer_name): @@ -31,7 +39,7 @@ class Order(Base): class Item(Base): - __tablename__ = 'item' + __tablename__ = "item" item_id = Column(Integer, primary_key=True) description = Column(String(30), nullable=False) price = Column(Float, nullable=False) @@ -41,39 +49,40 @@ class Item(Base): self.price = price def __repr__(self): - return 'Item(%r, %r)' % (self.description, self.price) + return "Item(%r, %r)" % (self.description, self.price) class OrderItem(Base): - __tablename__ = 'orderitem' - order_id = Column(Integer, ForeignKey('order.order_id'), primary_key=True) - item_id = Column(Integer, ForeignKey('item.item_id'), primary_key=True) + __tablename__ = "orderitem" + order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True) + item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True) price = Column(Float, nullable=False) def __init__(self, item, price=None): self.item = item self.price = price or item.price - item = relationship(Item, lazy='joined') + + item = relationship(Item, lazy="joined") -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") Base.metadata.create_all(engine) session = Session(engine) # create catalog tshirt, mug, hat, crowbar = ( - Item('SA T-Shirt', 10.99), - Item('SA Mug', 6.50), - Item('SA Hat', 8.99), - Item('MySQL Crowbar', 16.99) + Item("SA T-Shirt", 10.99), + Item("SA Mug", 6.50), + Item("SA Hat", 8.99), + Item("MySQL Crowbar", 16.99), ) session.add_all([tshirt, mug, hat, crowbar]) session.commit() # create an order - order = Order('john smith') + order = Order("john smith") # add items via the association proxy. # the OrderItem is created automatically. @@ -87,19 +96,24 @@ if __name__ == '__main__': session.commit() # query the order, print items - order = session.query(Order).filter_by(customer_name='john smith').one() + order = session.query(Order).filter_by(customer_name="john smith").one() # print items based on the OrderItem collection directly - print([(assoc.item.description, assoc.price, assoc.item.price) - for assoc in order.order_items]) + print( + [ + (assoc.item.description, assoc.price, assoc.item.price) + for assoc in order.order_items + ] + ) # print items based on the "proxied" items collection - print([(item.description, item.price) - for item in order.items]) + print([(item.description, item.price) for item in order.items]) # print customers who bought 'MySQL Crowbar' on sale - orders = session.query(Order).\ - join('order_items', 'item').\ - filter(Item.description == 'MySQL Crowbar').\ - filter(Item.price > OrderItem.price) + orders = ( + session.query(Order) + .join("order_items", "item") + .filter(Item.description == "MySQL Crowbar") + .filter(Item.price > OrderItem.price) + ) print([o.customer_name for o in orders]) diff --git a/examples/custom_attributes/__init__.py b/examples/custom_attributes/__init__.py index cbc65dfed..8d73d27e3 100644 --- a/examples/custom_attributes/__init__.py +++ b/examples/custom_attributes/__init__.py @@ -4,4 +4,4 @@ system. .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/custom_attributes/active_column_defaults.py b/examples/custom_attributes/active_column_defaults.py index dd823e814..f05a53173 100644 --- a/examples/custom_attributes/active_column_defaults.py +++ b/examples/custom_attributes/active_column_defaults.py @@ -35,6 +35,7 @@ def default_listener(col_attr, default): user integrating this feature. """ + @event.listens_for(col_attr, "init_scalar", retval=True, propagate=True) def init_scalar(target, value, dict_): @@ -52,7 +53,8 @@ def default_listener(col_attr, default): # or can procure a connection from an Engine # or Session and actually run the SQL, if desired. raise NotImplementedError( - "Can't invoke pre-default for a SQL-level column default") + "Can't invoke pre-default for a SQL-level column default" + ) # set the value in the given dict_; this won't emit any further # attribute set events or create attribute "history", but the value @@ -63,7 +65,7 @@ def default_listener(col_attr, default): return value -if __name__ == '__main__': +if __name__ == "__main__": from sqlalchemy import Column, Integer, DateTime, create_engine from sqlalchemy.orm import Session @@ -72,10 +74,10 @@ if __name__ == '__main__': Base = declarative_base() - event.listen(Base, 'mapper_configured', configure_listener, propagate=True) + event.listen(Base, "mapper_configured", configure_listener, propagate=True) class Widget(Base): - __tablename__ = 'widget' + __tablename__ = "widget" id = Column(Integer, primary_key=True) @@ -96,8 +98,8 @@ if __name__ == '__main__': # Column-level default for the "timestamp" column will no longer fire # off. current_time = w1.timestamp - assert ( - current_time > datetime.datetime.now() - datetime.timedelta(seconds=5) + assert current_time > datetime.datetime.now() - datetime.timedelta( + seconds=5 ) # persist @@ -107,7 +109,7 @@ if __name__ == '__main__': # data is persisted. The timestamp is also the one we generated above; # e.g. the default wasn't re-invoked later. - assert ( - sess.query(Widget.radius, Widget.timestamp).first() == - (30, current_time) + assert sess.query(Widget.radius, Widget.timestamp).first() == ( + 30, + current_time, ) diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py index 2199e0138..812385906 100644 --- a/examples/custom_attributes/custom_management.py +++ b/examples/custom_attributes/custom_management.py @@ -9,31 +9,44 @@ descriptors with a user-defined system. """ -from sqlalchemy import create_engine, MetaData, Table, Column, Integer, Text,\ - ForeignKey +from sqlalchemy import ( + create_engine, + MetaData, + Table, + Column, + Integer, + Text, + ForeignKey, +) from sqlalchemy.orm import mapper, relationship, Session -from sqlalchemy.orm.attributes import set_attribute, get_attribute, \ - del_attribute +from sqlalchemy.orm.attributes import ( + set_attribute, + get_attribute, + del_attribute, +) from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.ext.instrumentation import InstrumentationManager + class MyClassState(InstrumentationManager): def get_instance_dict(self, class_, instance): return instance._goofy_dict def initialize_instance_dict(self, class_, instance): - instance.__dict__['_goofy_dict'] = {} + instance.__dict__["_goofy_dict"] = {} def install_state(self, class_, instance, state): - instance.__dict__['_goofy_dict']['state'] = state + instance.__dict__["_goofy_dict"]["state"] = state def state_getter(self, class_): def find(instance): - return instance.__dict__['_goofy_dict']['state'] + return instance.__dict__["_goofy_dict"]["state"] + return find + class MyClass(object): __sa_instrumentation_manager__ = MyClassState @@ -63,17 +76,23 @@ class MyClass(object): del self._goofy_dict[key] -if __name__ == '__main__': - engine = create_engine('sqlite://') +if __name__ == "__main__": + engine = create_engine("sqlite://") meta = MetaData() - table1 = Table('table1', meta, - Column('id', Integer, primary_key=True), - Column('name', Text)) - table2 = Table('table2', meta, - Column('id', Integer, primary_key=True), - Column('name', Text), - Column('t1id', Integer, ForeignKey('table1.id'))) + table1 = Table( + "table1", + meta, + Column("id", Integer, primary_key=True), + Column("name", Text), + ) + table2 = Table( + "table2", + meta, + Column("id", Integer, primary_key=True), + Column("name", Text), + Column("t1id", Integer, ForeignKey("table1.id")), + ) meta.create_all(engine) class A(MyClass): @@ -82,16 +101,14 @@ if __name__ == '__main__': class B(MyClass): pass - mapper(A, table1, properties={ - 'bs': relationship(B) - }) + mapper(A, table1, properties={"bs": relationship(B)}) mapper(B, table2) - a1 = A(name='a1', bs=[B(name='b1'), B(name='b2')]) + a1 = A(name="a1", bs=[B(name="b1"), B(name="b2")]) - assert a1.name == 'a1' - assert a1.bs[0].name == 'b1' + assert a1.name == "a1" + assert a1.bs[0].name == "b1" sess = Session(engine) sess.add(a1) @@ -100,8 +117,8 @@ if __name__ == '__main__': a1 = sess.query(A).get(a1.id) - assert a1.name == 'a1' - assert a1.bs[0].name == 'b1' + assert a1.name == "a1" + assert a1.bs[0].name == "b1" a1.bs.remove(a1.bs[0]) diff --git a/examples/custom_attributes/listen_for_events.py b/examples/custom_attributes/listen_for_events.py index 0aeebc1d1..e3ef4cbea 100644 --- a/examples/custom_attributes/listen_for_events.py +++ b/examples/custom_attributes/listen_for_events.py @@ -5,6 +5,7 @@ and listen for change events. from sqlalchemy import event + def configure_listener(class_, key, inst): def append(instance, value, initiator): instance.receive_change_event("append", key, value, None) @@ -15,19 +16,18 @@ def configure_listener(class_, key, inst): def set_(instance, value, oldvalue, initiator): instance.receive_change_event("set", key, value, oldvalue) - event.listen(inst, 'append', append) - event.listen(inst, 'remove', remove) - event.listen(inst, 'set', set_) + event.listen(inst, "append", append) + event.listen(inst, "remove", remove) + event.listen(inst, "set", set_) -if __name__ == '__main__': +if __name__ == "__main__": from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base class Base(object): - def receive_change_event(self, verb, key, value, oldvalue): s = "Value '%s' %s on attribute '%s', " % (value, verb, key) if oldvalue: @@ -37,7 +37,7 @@ if __name__ == '__main__': Base = declarative_base(cls=Base) - event.listen(Base, 'attribute_instrument', configure_listener) + event.listen(Base, "attribute_instrument", configure_listener) class MyMappedClass(Base): __tablename__ = "mytable" @@ -61,9 +61,7 @@ if __name__ == '__main__': # classes are instrumented. Demonstrate the events ! - m1 = MyMappedClass(data='m1', related=Related(data='r1')) - m1.data = 'm1mod' - m1.related.mapped.append(MyMappedClass(data='m2')) + m1 = MyMappedClass(data="m1", related=Related(data="r1")) + m1.data = "m1mod" + m1.related.mapped.append(MyMappedClass(data="m2")) del m1.data - - diff --git a/examples/dogpile_caching/advanced.py b/examples/dogpile_caching/advanced.py index dc2ed0771..8f395bd7b 100644 --- a/examples/dogpile_caching/advanced.py +++ b/examples/dogpile_caching/advanced.py @@ -7,6 +7,7 @@ from .environment import Session from .model import Person, cache_address_bits from .caching_query import FromCache, RelationshipCache + def load_name_range(start, end, invalidate=False): """Load Person objects on a range of names. @@ -24,10 +25,14 @@ def load_name_range(start, end, invalidate=False): SQL that emits for unloaded Person objects as well as the distribution of data within the cache. """ - q = Session.query(Person).\ - filter(Person.name.between("person %.2d" % start, "person %.2d" % end)).\ - options(cache_address_bits).\ - options(FromCache("default", "name_range")) + q = ( + Session.query(Person) + .filter( + Person.name.between("person %.2d" % start, "person %.2d" % end) + ) + .options(cache_address_bits) + .options(FromCache("default", "name_range")) + ) # have the "addresses" collection cached separately # each lazyload of Person.addresses loads from cache. @@ -37,7 +42,7 @@ def load_name_range(start, end, invalidate=False): # be cached together. This issues a bigger SQL statement and caches # a single, larger value in the cache per person rather than two # separate ones. - #q = q.options(joinedload(Person.addresses)) + # q = q.options(joinedload(Person.addresses)) # if requested, invalidate the cache on current criterion. if invalidate: @@ -45,6 +50,7 @@ def load_name_range(start, end, invalidate=False): return q.all() + print("two through twelve, possibly from cache:\n") print(", ".join([p.name for p in load_name_range(2, 12)])) @@ -61,7 +67,9 @@ print(", ".join([p.name for p in load_name_range(25, 40, True)])) # illustrate the address loading from either cache/already # on the Person -print("\n\nPeople plus addresses, two through twelve, addresses possibly from cache") +print( + "\n\nPeople plus addresses, two through twelve, addresses possibly from cache" +) for p in load_name_range(2, 12): print(p.format_full()) @@ -71,5 +79,7 @@ print("\n\nPeople plus addresses, two through twelve, addresses from cache") for p in load_name_range(2, 12): print(p.format_full()) -print("\n\nIf this was the first run of advanced.py, try "\ - "a second run. Only one SQL statement will be emitted.") +print( + "\n\nIf this was the first run of advanced.py, try " + "a second run. Only one SQL statement will be emitted." +) diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index 6ad2dba4d..060c14613 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -59,7 +59,7 @@ class CachingQuery(Query): """ super_ = super(CachingQuery, self) - if hasattr(self, '_cache_region'): + if hasattr(self, "_cache_region"): return self.get_value(createfunc=lambda: list(super_.__iter__())) else: return super_.__iter__() @@ -78,13 +78,11 @@ class CachingQuery(Query): """ super_ = super(CachingQuery, self) - if context.query is not self and hasattr(self, '_cache_region'): + if context.query is not self and hasattr(self, "_cache_region"): # special logic called when the Query._execute_and_instances() # method is called directly from the baked query return self.get_value( - createfunc=lambda: list( - super_._execute_and_instances(context) - ) + createfunc=lambda: list(super_._execute_and_instances(context)) ) else: return super_._execute_and_instances(context) @@ -105,8 +103,13 @@ class CachingQuery(Query): dogpile_region, cache_key = self._get_cache_plus_key() dogpile_region.delete(cache_key) - def get_value(self, merge=True, createfunc=None, - expiration_time=None, ignore_expiration=False): + def get_value( + self, + merge=True, + createfunc=None, + expiration_time=None, + ignore_expiration=False, + ): """Return the value from the cache for this query. Raise KeyError if no value present and no @@ -119,19 +122,20 @@ class CachingQuery(Query): # but is expired, return it anyway. This doesn't make sense # with createfunc, which says, if the value is expired, generate # a new value. - assert not ignore_expiration or not createfunc, \ - "Can't ignore expiration and also provide createfunc" + assert ( + not ignore_expiration or not createfunc + ), "Can't ignore expiration and also provide createfunc" if ignore_expiration or not createfunc: - cached_value = dogpile_region.get(cache_key, - expiration_time=expiration_time, - ignore_expiration=ignore_expiration) + cached_value = dogpile_region.get( + cache_key, + expiration_time=expiration_time, + ignore_expiration=ignore_expiration, + ) else: cached_value = dogpile_region.get_or_create( - cache_key, - createfunc, - expiration_time=expiration_time - ) + cache_key, createfunc, expiration_time=expiration_time + ) if cached_value is NO_VALUE: raise KeyError(cache_key) if merge: @@ -144,11 +148,14 @@ class CachingQuery(Query): dogpile_region, cache_key = self._get_cache_plus_key() dogpile_region.set(cache_key, value) + def query_callable(regions, query_cls=CachingQuery): def query(*arg, **kw): return query_cls(regions, *arg, **kw) + return query + def _key_from_query(query, qualifier=None): """Given a Query, create a cache key. @@ -168,9 +175,8 @@ def _key_from_query(query, qualifier=None): # here we return the key as a long string. our "key mangler" # set up with the region will boil it down to an md5. - return " ".join( - [str(compiled)] + - [str(params[k]) for k in sorted(params)]) + return " ".join([str(compiled)] + [str(params[k]) for k in sorted(params)]) + class FromCache(MapperOption): """Specifies that a Query should load results from a cache.""" @@ -198,6 +204,7 @@ class FromCache(MapperOption): """Process a Query during normal loading operation.""" query._cache_region = self + class RelationshipCache(MapperOption): """Specifies that a Query as called within a "lazy load" should load results from a cache.""" @@ -237,7 +244,9 @@ class RelationshipCache(MapperOption): for cls in mapper.class_.__mro__: if (cls, key) in self._relationship_options: - relationship_option = self._relationship_options[(cls, key)] + relationship_option = self._relationship_options[ + (cls, key) + ] query._cache_region = relationship_option break @@ -264,4 +273,3 @@ class RelationshipCache(MapperOption): """ return None - diff --git a/examples/dogpile_caching/environment.py b/examples/dogpile_caching/environment.py index 130dfdb2b..13bd0a310 100644 --- a/examples/dogpile_caching/environment.py +++ b/examples/dogpile_caching/environment.py @@ -10,6 +10,7 @@ from dogpile.cache.region import make_region import os from hashlib import md5 import sys + py2k = sys.version_info < (3, 0) if py2k: @@ -23,9 +24,7 @@ regions = {} # using a callable that will associate the dictionary # of regions with the Query. Session = scoped_session( - sessionmaker( - query_cls=caching_query.query_callable(regions) - ) + sessionmaker(query_cls=caching_query.query_callable(regions)) ) # global declarative base class. @@ -42,7 +41,7 @@ if not os.path.exists(root): os.makedirs(root) dbfile = os.path.join(root, "dogpile_demo.db") -engine = create_engine('sqlite:///%s' % dbfile, echo=True) +engine = create_engine("sqlite:///%s" % dbfile, echo=True) Session.configure(bind=engine) @@ -51,10 +50,11 @@ def md5_key_mangler(key): distill them into an md5 hash. """ - return md5(key.encode('ascii')).hexdigest() + return md5(key.encode("ascii")).hexdigest() + # configure the "default" cache region. -regions['default'] = make_region( +regions["default"] = make_region( # the "dbm" backend needs # string-encoded keys key_mangler=md5_key_mangler @@ -63,11 +63,9 @@ regions['default'] = make_region( # serialized persistence. Normally # memcached or similar is a better choice # for caching. - 'dogpile.cache.dbm', + "dogpile.cache.dbm", expiration_time=3600, - arguments={ - "filename": os.path.join(root, "cache.dbm") - } + arguments={"filename": os.path.join(root, "cache.dbm")}, ) # optional; call invalidate() on the region @@ -83,6 +81,7 @@ installed = False def bootstrap(): global installed from . import fixture_data + if not os.path.exists(dbfile): fixture_data.install() - installed = True
\ No newline at end of file + installed = True diff --git a/examples/dogpile_caching/fixture_data.py b/examples/dogpile_caching/fixture_data.py index 465171891..e301db2a4 100644 --- a/examples/dogpile_caching/fixture_data.py +++ b/examples/dogpile_caching/fixture_data.py @@ -12,13 +12,19 @@ def install(): Base.metadata.create_all(Session().bind) data = [ - ('Chicago', 'United States', ('60601', '60602', '60603', '60604')), - ('Montreal', 'Canada', ('H2S 3K9', 'H2B 1V4', 'H7G 2T8')), - ('Edmonton', 'Canada', ('T5J 1R9', 'T5J 1Z4', 'T5H 1P6')), - ('New York', 'United States', - ('10001', '10002', '10003', '10004', '10005', '10006')), - ('San Francisco', 'United States', - ('94102', '94103', '94104', '94105', '94107', '94108')) + ("Chicago", "United States", ("60601", "60602", "60603", "60604")), + ("Montreal", "Canada", ("H2S 3K9", "H2B 1V4", "H7G 2T8")), + ("Edmonton", "Canada", ("T5J 1R9", "T5J 1Z4", "T5H 1P6")), + ( + "New York", + "United States", + ("10001", "10002", "10003", "10004", "10005", "10006"), + ), + ( + "San Francisco", + "United States", + ("94102", "94103", "94104", "94105", "94107", "94108"), + ), ] countries = {} @@ -40,8 +46,9 @@ def install(): Address( street="street %.2d" % i, postal_code=all_post_codes[ - random.randint(0, len(all_post_codes) - 1)] - ) + random.randint(0, len(all_post_codes) - 1) + ], + ), ) Session.add(person) diff --git a/examples/dogpile_caching/helloworld.py b/examples/dogpile_caching/helloworld.py index 0dbde5eaf..eb565344e 100644 --- a/examples/dogpile_caching/helloworld.py +++ b/examples/dogpile_caching/helloworld.py @@ -21,28 +21,34 @@ people = Session.query(Person).options(FromCache("default")).all() # Specifying a different query produces a different cache key, so # these results are independently cached. print("loading people two through twelve") -people_two_through_twelve = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).\ - all() +people_two_through_twelve = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 02", "person 12")) + .all() +) # the data is cached under string structure of the SQL statement, *plus* # the bind parameters of the query. So this query, having # different literal parameters under "Person.name.between()" than the # previous one, issues new SQL... print("loading people five through fifteen") -people_five_through_fifteen = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 05", "person 15")).\ - all() +people_five_through_fifteen = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 05", "person 15")) + .all() +) # ... but using the same params as are already cached, no SQL print("loading people two through twelve...again!") -people_two_through_twelve = Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).\ - all() +people_two_through_twelve = ( + Session.query(Person) + .options(FromCache("default")) + .filter(Person.name.between("person 02", "person 12")) + .all() +) # invalidate the cache for the three queries we've done. Recreate @@ -51,10 +57,9 @@ people_two_through_twelve = Session.query(Person).\ # same order, then call invalidate(). print("invalidating everything") Session.query(Person).options(FromCache("default")).invalidate() -Session.query(Person).\ - options(FromCache("default")).\ - filter(Person.name.between("person 02", "person 12")).invalidate() -Session.query(Person).\ - options(FromCache("default", "people_on_range")).\ - filter(Person.name.between("person 05", "person 15")).invalidate() - +Session.query(Person).options(FromCache("default")).filter( + Person.name.between("person 02", "person 12") +).invalidate() +Session.query(Person).options(FromCache("default", "people_on_range")).filter( + Person.name.between("person 05", "person 15") +).invalidate() diff --git a/examples/dogpile_caching/local_session_caching.py b/examples/dogpile_caching/local_session_caching.py index 633252fc7..358886bf0 100644 --- a/examples/dogpile_caching/local_session_caching.py +++ b/examples/dogpile_caching/local_session_caching.py @@ -30,7 +30,7 @@ class ScopedSessionBackend(CacheBackend): """ def __init__(self, arguments): - self.scoped_session = arguments['scoped_session'] + self.scoped_session = arguments["scoped_session"] def get(self, key): return self._cache_dictionary.get(key, NO_VALUE) @@ -52,10 +52,11 @@ class ScopedSessionBackend(CacheBackend): sess._cache_dictionary = cache_dict = {} return cache_dict + register_backend("sqlalchemy.session", __name__, "ScopedSessionBackend") -if __name__ == '__main__': +if __name__ == "__main__": from .environment import Session, regions from .caching_query import FromCache from dogpile.cache import make_region @@ -63,20 +64,19 @@ if __name__ == '__main__': # set up a region based on the ScopedSessionBackend, # pointing to the scoped_session declared in the example # environment. - regions['local_session'] = make_region().configure( - 'sqlalchemy.session', - arguments={ - "scoped_session": Session - } + regions["local_session"] = make_region().configure( + "sqlalchemy.session", arguments={"scoped_session": Session} ) from .model import Person # query to load Person by name, with criterion # of "person 10" - q = Session.query(Person).\ - options(FromCache("local_session")).\ - filter(Person.name == "person 10") + q = ( + Session.query(Person) + .options(FromCache("local_session")) + .filter(Person.name == "person 10") + ) # load from DB person10 = q.one() diff --git a/examples/dogpile_caching/model.py b/examples/dogpile_caching/model.py index 3eb02108c..f6a259820 100644 --- a/examples/dogpile_caching/model.py +++ b/examples/dogpile_caching/model.py @@ -14,7 +14,7 @@ from .environment import Base, bootstrap class Country(Base): - __tablename__ = 'country' + __tablename__ = "country" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) @@ -24,11 +24,11 @@ class Country(Base): class City(Base): - __tablename__ = 'city' + __tablename__ = "city" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) - country_id = Column(Integer, ForeignKey('country.id'), nullable=False) + country_id = Column(Integer, ForeignKey("country.id"), nullable=False) country = relationship(Country) def __init__(self, name, country): @@ -37,11 +37,11 @@ class City(Base): class PostalCode(Base): - __tablename__ = 'postal_code' + __tablename__ = "postal_code" id = Column(Integer, primary_key=True) code = Column(String(10), nullable=False) - city_id = Column(Integer, ForeignKey('city.id'), nullable=False) + city_id = Column(Integer, ForeignKey("city.id"), nullable=False) city = relationship(City) @property @@ -54,12 +54,12 @@ class PostalCode(Base): class Address(Base): - __tablename__ = 'address' + __tablename__ = "address" id = Column(Integer, primary_key=True) - person_id = Column(Integer, ForeignKey('person.id'), nullable=False) + person_id = Column(Integer, ForeignKey("person.id"), nullable=False) street = Column(String(200), nullable=False) - postal_code_id = Column(Integer, ForeignKey('postal_code.id')) + postal_code_id = Column(Integer, ForeignKey("postal_code.id")) postal_code = relationship(PostalCode) @property @@ -71,15 +71,16 @@ class Address(Base): return self.postal_code.country def __str__(self): - return ( - "%s\t%s, %s\t%s" % ( - self.street, self.city.name, - self.postal_code.code, self.country.name) + return "%s\t%s, %s\t%s" % ( + self.street, + self.city.name, + self.postal_code.code, + self.country.name, ) class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) name = Column(String(100), nullable=False) @@ -98,14 +99,14 @@ class Person(Base): def format_full(self): return "\t".join([str(x) for x in [self] + list(self.addresses)]) + # Caching options. A set of three RelationshipCache options # which can be applied to Query(), causing the "lazy load" # of these attributes to be loaded from cache. -cache_address_bits = RelationshipCache(PostalCode.city, "default").\ - and_( - RelationshipCache(City.country, "default") -).and_( - RelationshipCache(Address.postal_code, "default") +cache_address_bits = ( + RelationshipCache(PostalCode.city, "default") + .and_(RelationshipCache(City.country, "default")) + .and_(RelationshipCache(Address.postal_code, "default")) ) bootstrap() diff --git a/examples/dogpile_caching/relationship_caching.py b/examples/dogpile_caching/relationship_caching.py index 920d696f8..76c7e767f 100644 --- a/examples/dogpile_caching/relationship_caching.py +++ b/examples/dogpile_caching/relationship_caching.py @@ -12,7 +12,8 @@ from sqlalchemy.orm import joinedload import os for p in Session.query(Person).options( - joinedload(Person.addresses), cache_address_bits): + joinedload(Person.addresses), cache_address_bits +): print(p.format_full()) @@ -25,5 +26,5 @@ print( "related data is pulled from cache.\n" "To clear the cache, delete the file %r. \n" "This will cause a re-load of cities, postal codes and countries on " - "the next run.\n" - % os.path.join(root, 'cache.dbm')) + "the next run.\n" % os.path.join(root, "cache.dbm") +) diff --git a/examples/dynamic_dict/__init__.py b/examples/dynamic_dict/__init__.py index e592ea200..ed31df062 100644 --- a/examples/dynamic_dict/__init__.py +++ b/examples/dynamic_dict/__init__.py @@ -5,4 +5,4 @@ full collection at once. .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/dynamic_dict/dynamic_dict.py b/examples/dynamic_dict/dynamic_dict.py index 530674f2e..62da55c38 100644 --- a/examples/dynamic_dict/dynamic_dict.py +++ b/examples/dynamic_dict/dynamic_dict.py @@ -14,7 +14,7 @@ class ProxyDict(object): return [x[0] for x in self.collection.values(descriptor)] def __getitem__(self, key): - x = self.collection.filter_by(**{self.keyname:key}).first() + x = self.collection.filter_by(**{self.keyname: key}).first() if x: return x else: @@ -28,43 +28,48 @@ class ProxyDict(object): pass self.collection.append(value) + from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import create_engine, Column, Integer, String, ForeignKey from sqlalchemy.orm import sessionmaker, relationship -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base = declarative_base(engine) + class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) name = Column(String(50)) - _collection = relationship("Child", lazy="dynamic", - cascade="all, delete-orphan") + _collection = relationship( + "Child", lazy="dynamic", cascade="all, delete-orphan" + ) @property def child_map(self): - return ProxyDict(self, '_collection', Child, 'key') + return ProxyDict(self, "_collection", Child, "key") + class Child(Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) key = Column(String(50)) - parent_id = Column(Integer, ForeignKey('parent.id')) + parent_id = Column(Integer, ForeignKey("parent.id")) def __repr__(self): return "Child(key=%r)" % self.key + Base.metadata.create_all() sess = sessionmaker()() -p1 = Parent(name='p1') +p1 = Parent(name="p1") sess.add(p1) print("\n---------begin setting nodes, autoflush occurs\n") -p1.child_map['k1'] = Child(key='k1') -p1.child_map['k2'] = Child(key='k2') +p1.child_map["k1"] = Child(key="k1") +p1.child_map["k2"] = Child(key="k2") # this will autoflush the current map. # ['k1', 'k2'] @@ -73,16 +78,15 @@ print(list(p1.child_map.keys())) # k1 print("\n---------print 'k1' node\n") -print(p1.child_map['k1']) +print(p1.child_map["k1"]) print("\n---------update 'k2' node - must find existing, and replace\n") -p1.child_map['k2'] = Child(key='k2') +p1.child_map["k2"] = Child(key="k2") print("\n---------print 'k2' key - flushes first\n") # k2 -print(p1.child_map['k2']) +print(p1.child_map["k2"]) print("\n---------print all child nodes\n") # [k1, k2b] print(sess.query(Child).all()) - diff --git a/examples/elementtree/__init__.py b/examples/elementtree/__init__.py index 66e9cfbbe..82d00ff5a 100644 --- a/examples/elementtree/__init__.py +++ b/examples/elementtree/__init__.py @@ -22,4 +22,4 @@ E.g.:: .. autosource:: :files: pickle.py, adjacency_list.py, optimized_al.py -"""
\ No newline at end of file +""" diff --git a/examples/elementtree/adjacency_list.py b/examples/elementtree/adjacency_list.py index 5e27ba9ca..1f7161212 100644 --- a/examples/elementtree/adjacency_list.py +++ b/examples/elementtree/adjacency_list.py @@ -15,42 +15,63 @@ styles of persistence are identical, as is the structure of the main Document cl """ ################################# PART I - Imports/Coniguration #################################### -from sqlalchemy import (MetaData, Table, Column, Integer, String, ForeignKey, - Unicode, and_, create_engine) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + Unicode, + and_, + create_engine, +) from sqlalchemy.orm import mapper, relationship, Session, lazyload import sys, os, io, re from xml.etree import ElementTree -e = create_engine('sqlite://') +e = create_engine("sqlite://") meta = MetaData() ################################# PART II - Table Metadata ######################################### # stores a top level record of an XML document. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), - Column('element_id', Integer, ForeignKey('elements.element_id')) +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), + Column("element_id", Integer, ForeignKey("elements.element_id")), ) # stores XML nodes in an adjacency list model. This corresponds to # Element and SubElement objects. -elements = Table('elements', meta, - Column('element_id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('elements.element_id')), - Column('tag', Unicode(30), nullable=False), - Column('text', Unicode), - Column('tail', Unicode) - ) +elements = Table( + "elements", + meta, + Column("element_id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("elements.element_id")), + Column("tag", Unicode(30), nullable=False), + Column("text", Unicode), + Column("tail", Unicode), +) # stores attributes. This corresponds to the dictionary of attributes # stored by an Element or SubElement. -attributes = Table('attributes', meta, - Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True), - Column('name', Unicode(100), nullable=False, primary_key=True), - Column('value', Unicode(255))) +attributes = Table( + "attributes", + meta, + Column( + "element_id", + Integer, + ForeignKey("elements.element_id"), + primary_key=True, + ), + Column("name", Unicode(100), nullable=False, primary_key=True), + Column("value", Unicode(255)), +) meta.create_all(e) @@ -68,6 +89,7 @@ class Document(object): self.element.write(buf) return buf.getvalue() + #################################### PART IV - Persistence Mapping ################################# # Node class. a non-public class which will represent @@ -78,6 +100,7 @@ class Document(object): class _Node(object): pass + # Attribute class. also internal, this will represent the key/value attributes stored for # a particular Node. class _Attribute(object): @@ -85,16 +108,25 @@ class _Attribute(object): self.name = name self.value = value + # setup mappers. Document will eagerly load a list of _Node objects. -mapper(Document, documents, properties={ - '_root':relationship(_Node, lazy='joined', cascade="all") -}) +mapper( + Document, + documents, + properties={"_root": relationship(_Node, lazy="joined", cascade="all")}, +) -mapper(_Node, elements, properties={ - 'children':relationship(_Node, cascade="all"), - # eagerly load attributes - 'attributes':relationship(_Attribute, lazy='joined', cascade="all, delete-orphan"), -}) +mapper( + _Node, + elements, + properties={ + "children": relationship(_Node, cascade="all"), + # eagerly load attributes + "attributes": relationship( + _Attribute, lazy="joined", cascade="all, delete-orphan" + ), + }, +) mapper(_Attribute, attributes) @@ -106,7 +138,7 @@ class ElementTreeMarshal(object): if document is None: return self - if hasattr(document, '_element'): + if hasattr(document, "_element"): return document._element def traverse(node, parent=None): @@ -132,7 +164,9 @@ class ElementTreeMarshal(object): n.text = str(node.text) n.tail = str(node.tail) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(str(k), str(v)) for k, v in node.attrib.items()] + n.attributes = [ + _Attribute(str(k), str(v)) for k, v in node.attrib.items() + ] return n document._root = traverse(element.getroot()) @@ -142,6 +176,7 @@ class ElementTreeMarshal(object): del document._element document._root = [] + # override Document's "element" attribute with the marshaller. Document.element = ElementTreeMarshal() @@ -153,7 +188,7 @@ line = "\n--------------------------------------------------------" session = Session(e) # get ElementTree documents -for file in ('test.xml', 'test2.xml', 'test3.xml'): +for file in ("test.xml", "test2.xml", "test3.xml"): filename = os.path.join(os.path.dirname(__file__), file) doc = ElementTree.parse(filename) session.add(Document(file, doc)) @@ -170,10 +205,16 @@ print(document) ############################################ PART VI - Searching for Paths ######################### # manually search for a document which contains "/somefile/header/field1:hi" -d = session.query(Document).join('_root', aliased=True).filter(_Node.tag=='somefile').\ - join('children', aliased=True, from_joinpoint=True).filter(_Node.tag=='header').\ - join('children', aliased=True, from_joinpoint=True).filter( - and_(_Node.tag=='field1', _Node.text=='hi')).one() +d = ( + session.query(Document) + .join("_root", aliased=True) + .filter(_Node.tag == "somefile") + .join("children", aliased=True, from_joinpoint=True) + .filter(_Node.tag == "header") + .join("children", aliased=True, from_joinpoint=True) + .filter(and_(_Node.tag == "field1", _Node.text == "hi")) + .one() +) print(d) # generalize the above approach into an extremely impoverished xpath function: @@ -181,26 +222,39 @@ def find_document(path, compareto): j = documents prev_elements = None query = session.query(Document) - attribute = '_root' - for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): + attribute = "_root" + for i, match in enumerate( + re.finditer(r"/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?", path) + ): (token, attrname, attrvalue) = match.group(1, 2, 3) - query = query.join(attribute, aliased=True, from_joinpoint=True).filter(_Node.tag==token) - attribute = 'children' + query = query.join( + attribute, aliased=True, from_joinpoint=True + ).filter(_Node.tag == token) + attribute = "children" if attrname: if attrvalue: - query = query.join('attributes', aliased=True, from_joinpoint=True).filter( - and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) + query = query.join( + "attributes", aliased=True, from_joinpoint=True + ).filter( + and_( + _Attribute.name == attrname, + _Attribute.value == attrvalue, + ) + ) else: - query = query.join('attributes', aliased=True, from_joinpoint=True).filter( - _Attribute.name==attrname) - return query.options(lazyload('_root')).filter(_Node.text==compareto).all() + query = query.join( + "attributes", aliased=True, from_joinpoint=True + ).filter(_Attribute.name == attrname) + return ( + query.options(lazyload("_root")).filter(_Node.text == compareto).all() + ) + for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') - ): + ("/somefile/header/field1", "hi"), + ("/somefile/field1", "hi"), + ("/somefile/header/field2", "there"), + ("/somefile/header/field2[@attr=foo]", "there"), +): print("\nDocuments containing '%s=%s':" % (path, compareto), line) print([d.filename for d in find_document(path, compareto)]) - diff --git a/examples/elementtree/optimized_al.py b/examples/elementtree/optimized_al.py index e13f5b0ee..8e9c48b96 100644 --- a/examples/elementtree/optimized_al.py +++ b/examples/elementtree/optimized_al.py @@ -8,42 +8,63 @@ """ ##################### PART I - Imports/Configuration ######################### -from sqlalchemy import (MetaData, Table, Column, Integer, String, ForeignKey, - Unicode, and_, create_engine) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + Unicode, + and_, + create_engine, +) from sqlalchemy.orm import mapper, relationship, Session, lazyload import sys, os, io, re from xml.etree import ElementTree -e = create_engine('sqlite://', echo=True) +e = create_engine("sqlite://", echo=True) meta = MetaData() ####################### PART II - Table Metadata ############################# # stores a top level record of an XML document. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), ) # stores XML nodes in an adjacency list model. This corresponds to # Element and SubElement objects. -elements = Table('elements', meta, - Column('element_id', Integer, primary_key=True), - Column('parent_id', Integer, ForeignKey('elements.element_id')), - Column('document_id', Integer, ForeignKey('documents.document_id')), - Column('tag', Unicode(30), nullable=False), - Column('text', Unicode), - Column('tail', Unicode) - ) +elements = Table( + "elements", + meta, + Column("element_id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey("elements.element_id")), + Column("document_id", Integer, ForeignKey("documents.document_id")), + Column("tag", Unicode(30), nullable=False), + Column("text", Unicode), + Column("tail", Unicode), +) # stores attributes. This corresponds to the dictionary of attributes # stored by an Element or SubElement. -attributes = Table('attributes', meta, - Column('element_id', Integer, ForeignKey('elements.element_id'), primary_key=True), - Column('name', Unicode(100), nullable=False, primary_key=True), - Column('value', Unicode(255))) +attributes = Table( + "attributes", + meta, + Column( + "element_id", + Integer, + ForeignKey("elements.element_id"), + primary_key=True, + ), + Column("name", Unicode(100), nullable=False, primary_key=True), + Column("value", Unicode(255)), +) meta.create_all(e) @@ -61,6 +82,7 @@ class Document(object): self.element.write(buf) return buf.getvalue() + ########################## PART IV - Persistence Mapping ##################### # Node class. a non-public class which will represent @@ -71,6 +93,7 @@ class Document(object): class _Node(object): pass + # Attribute class. also internal, this will represent the key/value attributes stored for # a particular Node. class _Attribute(object): @@ -78,21 +101,36 @@ class _Attribute(object): self.name = name self.value = value + # setup mappers. Document will eagerly load a list of _Node objects. # they will be ordered in primary key/insert order, so that we can reconstruct # an ElementTree structure from the list. -mapper(Document, documents, properties={ - '_nodes':relationship(_Node, lazy='joined', cascade="all, delete-orphan") -}) +mapper( + Document, + documents, + properties={ + "_nodes": relationship( + _Node, lazy="joined", cascade="all, delete-orphan" + ) + }, +) # the _Node objects change the way they load so that a list of _Nodes will organize # themselves hierarchically using the ElementTreeMarshal. this depends on the ordering of # nodes being hierarchical as well; relationship() always applies at least ROWID/primary key # ordering to rows which will suffice. -mapper(_Node, elements, properties={ - 'children':relationship(_Node, lazy=None), # doesnt load; used only for the save relationship - 'attributes':relationship(_Attribute, lazy='joined', cascade="all, delete-orphan"), # eagerly load attributes -}) +mapper( + _Node, + elements, + properties={ + "children": relationship( + _Node, lazy=None + ), # doesnt load; used only for the save relationship + "attributes": relationship( + _Attribute, lazy="joined", cascade="all, delete-orphan" + ), # eagerly load attributes + }, +) mapper(_Attribute, attributes) @@ -104,7 +142,7 @@ class ElementTreeMarshal(object): if document is None: return self - if hasattr(document, '_element'): + if hasattr(document, "_element"): return document._element nodes = {} @@ -134,7 +172,9 @@ class ElementTreeMarshal(object): n.tail = str(node.tail) document._nodes.append(n) n.children = [traverse(n2) for n2 in node] - n.attributes = [_Attribute(str(k), str(v)) for k, v in node.attrib.items()] + n.attributes = [ + _Attribute(str(k), str(v)) for k, v in node.attrib.items() + ] return n traverse(element.getroot()) @@ -144,6 +184,7 @@ class ElementTreeMarshal(object): del document._element document._nodes = [] + # override Document's "element" attribute with the marshaller. Document.element = ElementTreeMarshal() @@ -155,7 +196,7 @@ line = "\n--------------------------------------------------------" session = Session(e) # get ElementTree documents -for file in ('test.xml', 'test2.xml', 'test3.xml'): +for file in ("test.xml", "test2.xml", "test3.xml"): filename = os.path.join(os.path.dirname(__file__), file) doc = ElementTree.parse(filename) session.add(Document(file, doc)) @@ -173,13 +214,16 @@ print(document) # manually search for a document which contains "/somefile/header/field1:hi" print("\nManual search for /somefile/header/field1=='hi':", line) -d = session.query(Document).join('_nodes', aliased=True).\ - filter(and_(_Node.parent_id==None, _Node.tag=='somefile')).\ - join('children', aliased=True, from_joinpoint=True).\ - filter(_Node.tag=='header').\ - join('children', aliased=True, from_joinpoint=True).\ - filter(and_(_Node.tag=='field1', _Node.text=='hi')).\ - one() +d = ( + session.query(Document) + .join("_nodes", aliased=True) + .filter(and_(_Node.parent_id == None, _Node.tag == "somefile")) + .join("children", aliased=True, from_joinpoint=True) + .filter(_Node.tag == "header") + .join("children", aliased=True, from_joinpoint=True) + .filter(and_(_Node.tag == "field1", _Node.text == "hi")) + .one() +) print(d) # generalize the above approach into an extremely impoverished xpath function: @@ -188,28 +232,39 @@ def find_document(path, compareto): prev_elements = None query = session.query(Document) first = True - for i, match in enumerate(re.finditer(r'/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?', path)): + for i, match in enumerate( + re.finditer(r"/([\w_]+)(?:\[@([\w_]+)(?:=(.*))?\])?", path) + ): (token, attrname, attrvalue) = match.group(1, 2, 3) if first: - query = query.join('_nodes', aliased=True).filter(_Node.parent_id==None) + query = query.join("_nodes", aliased=True).filter( + _Node.parent_id == None + ) first = False else: - query = query.join('children', aliased=True, from_joinpoint=True) - query = query.filter(_Node.tag==token) + query = query.join("children", aliased=True, from_joinpoint=True) + query = query.filter(_Node.tag == token) if attrname: - query = query.join('attributes', aliased=True, from_joinpoint=True) + query = query.join("attributes", aliased=True, from_joinpoint=True) if attrvalue: - query = query.filter(and_(_Attribute.name==attrname, _Attribute.value==attrvalue)) + query = query.filter( + and_( + _Attribute.name == attrname, + _Attribute.value == attrvalue, + ) + ) else: - query = query.filter(_Attribute.name==attrname) - return query.options(lazyload('_nodes')).filter(_Node.text==compareto).all() + query = query.filter(_Attribute.name == attrname) + return ( + query.options(lazyload("_nodes")).filter(_Node.text == compareto).all() + ) + for path, compareto in ( - ('/somefile/header/field1', 'hi'), - ('/somefile/field1', 'hi'), - ('/somefile/header/field2', 'there'), - ('/somefile/header/field2[@attr=foo]', 'there') - ): + ("/somefile/header/field1", "hi"), + ("/somefile/field1", "hi"), + ("/somefile/header/field2", "there"), + ("/somefile/header/field2[@attr=foo]", "there"), +): print("\nDocuments containing '%s=%s':" % (path, compareto), line) print([d.filename for d in find_document(path, compareto)]) - diff --git a/examples/elementtree/pickle.py b/examples/elementtree/pickle.py index d40af275b..a86fe30e5 100644 --- a/examples/elementtree/pickle.py +++ b/examples/elementtree/pickle.py @@ -6,15 +6,22 @@ structure in distinct rows using two additional mapped entities. Note that the styles of persistence are identical, as is the structure of the main Document class. """ -from sqlalchemy import (create_engine, MetaData, Table, Column, Integer, String, - PickleType) +from sqlalchemy import ( + create_engine, + MetaData, + Table, + Column, + Integer, + String, + PickleType, +) from sqlalchemy.orm import mapper, Session import sys, os from xml.etree import ElementTree -e = create_engine('sqlite://') +e = create_engine("sqlite://") meta = MetaData() # setup a comparator for the PickleType since it's a mutable @@ -22,12 +29,15 @@ meta = MetaData() def are_elements_equal(x, y): return x == y + # stores a top level record of an XML document. # the "element" column will store the ElementTree document as a BLOB. -documents = Table('documents', meta, - Column('document_id', Integer, primary_key=True), - Column('filename', String(30), unique=True), - Column('element', PickleType(comparator=are_elements_equal)) +documents = Table( + "documents", + meta, + Column("document_id", Integer, primary_key=True), + Column("filename", String(30), unique=True), + Column("element", PickleType(comparator=are_elements_equal)), ) meta.create_all(e) @@ -39,6 +49,7 @@ class Document(object): self.filename = name self.element = element + # setup mapper. mapper(Document, documents) @@ -58,4 +69,3 @@ document = session.query(Document).filter_by(filename="test.xml").first() # print document.element.write(sys.stdout) - diff --git a/examples/generic_associations/__init__.py b/examples/generic_associations/__init__.py index b6593b4f4..9d103b73e 100644 --- a/examples/generic_associations/__init__.py +++ b/examples/generic_associations/__init__.py @@ -15,4 +15,4 @@ are modernized versions of recipes presented in the 2007 blog post .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/generic_associations/discriminator_on_association.py b/examples/generic_associations/discriminator_on_association.py index c3501eefb..38f2370f3 100644 --- a/examples/generic_associations/discriminator_on_association.py +++ b/examples/generic_associations/discriminator_on_association.py @@ -16,27 +16,31 @@ objects, but is also slightly more complex. """ from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, ForeignKey +from sqlalchemy import create_engine, Integer, Column, String, ForeignKey from sqlalchemy.orm import Session, relationship, backref from sqlalchemy.ext.associationproxy import association_proxy + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class AddressAssociation(Base): """Associates a collection of Address objects with a particular parent. """ + __tablename__ = "address_association" discriminator = Column(String) @@ -44,6 +48,7 @@ class AddressAssociation(Base): __mapper_args__ = {"polymorphic_on": discriminator} + class Address(Base): """The Address class. @@ -51,6 +56,7 @@ class Address(Base): single table. """ + association_id = Column(Integer, ForeignKey("address_association.id")) street = Column(String) city = Column(String) @@ -60,15 +66,20 @@ class Address(Base): parent = association_proxy("association", "parent") def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a relationship to the address_association table for each parent. """ + @declared_attr def address_association_id(cls): return Column(Integer, ForeignKey("address_association.id")) @@ -79,63 +90,62 @@ class HasAddresses(object): discriminator = name.lower() assoc_cls = type( - "%sAddressAssociation" % name, - (AddressAssociation, ), - dict( - __tablename__=None, - __mapper_args__={ - "polymorphic_identity": discriminator - } - ) - ) + "%sAddressAssociation" % name, + (AddressAssociation,), + dict( + __tablename__=None, + __mapper_args__={"polymorphic_identity": discriminator}, + ), + ) cls.addresses = association_proxy( - "address_association", "addresses", - creator=lambda addresses: assoc_cls(addresses=addresses) - ) - return relationship(assoc_cls, - backref=backref("parent", uselist=False)) + "address_association", + "addresses", + creator=lambda addresses: assoc_cls(addresses=addresses), + ) + return relationship( + assoc_cls, backref=backref("parent", uselist=False) + ) class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - print(address.parent)
\ No newline at end of file + print(address.parent) diff --git a/examples/generic_associations/generic_fk.py b/examples/generic_associations/generic_fk.py index 31d2c138d..ded8f749d 100644 --- a/examples/generic_associations/generic_fk.py +++ b/examples/generic_associations/generic_fk.py @@ -20,8 +20,7 @@ or "table_per_association" instead of this approach. """ from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, and_ +from sqlalchemy import create_engine, Integer, Column, String, and_ from sqlalchemy.orm import Session, relationship, foreign, remote, backref from sqlalchemy import event @@ -32,11 +31,14 @@ class Base(object): and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(Base): """The Address class. @@ -44,6 +46,7 @@ class Address(Base): single table. """ + street = Column(String) city = Column(String) zip = Column(String) @@ -66,9 +69,13 @@ class Address(Base): return getattr(self, "parent_%s" % self.discriminator) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a relationship to @@ -76,63 +83,66 @@ class HasAddresses(object): """ + @event.listens_for(HasAddresses, "mapper_configured", propagate=True) def setup_listener(mapper, class_): name = class_.__name__ discriminator = name.lower() - class_.addresses = relationship(Address, - primaryjoin=and_( - class_.id == foreign(remote(Address.parent_id)), - Address.discriminator == discriminator - ), - backref=backref( - "parent_%s" % discriminator, - primaryjoin=remote(class_.id) == foreign(Address.parent_id) - ) - ) + class_.addresses = relationship( + Address, + primaryjoin=and_( + class_.id == foreign(remote(Address.parent_id)), + Address.discriminator == discriminator, + ), + backref=backref( + "parent_%s" % discriminator, + primaryjoin=remote(class_.id) == foreign(Address.parent_id), + ), + ) + @event.listens_for(class_.addresses, "append") def append_address(target, value, initiator): value.discriminator = discriminator + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - print(address.parent)
\ No newline at end of file + print(address.parent) diff --git a/examples/generic_associations/table_per_association.py b/examples/generic_associations/table_per_association.py index d54d2f1fa..7de246934 100644 --- a/examples/generic_associations/table_per_association.py +++ b/examples/generic_associations/table_per_association.py @@ -12,21 +12,31 @@ has no dependency on the system. """ from sqlalchemy.ext.declarative import as_declarative, declared_attr -from sqlalchemy import create_engine, Integer, Column, \ - String, ForeignKey, Table +from sqlalchemy import ( + create_engine, + Integer, + Column, + String, + ForeignKey, + Table, +) from sqlalchemy.orm import Session, relationship + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(Base): """The Address class. @@ -34,72 +44,79 @@ class Address(Base): single table. """ + street = Column(String) city = Column(String) zip = Column(String) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a new address_association table for each parent. """ + @declared_attr def addresses(cls): address_association = Table( "%s_addresses" % cls.__tablename__, cls.metadata, - Column("address_id", ForeignKey("address.id"), - primary_key=True), - Column("%s_id" % cls.__tablename__, - ForeignKey("%s.id" % cls.__tablename__), - primary_key=True), + Column("address_id", ForeignKey("address.id"), primary_key=True), + Column( + "%s_id" % cls.__tablename__, + ForeignKey("%s.id" % cls.__tablename__), + primary_key=True, + ), ) return relationship(Address, secondary=address_association) + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Address(street="2569 west elm", city="Detroit", zip="56785") + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - # no parent here
\ No newline at end of file + # no parent here diff --git a/examples/generic_associations/table_per_related.py b/examples/generic_associations/table_per_related.py index 51c9f1b26..9c5e0e179 100644 --- a/examples/generic_associations/table_per_related.py +++ b/examples/generic_associations/table_per_related.py @@ -20,17 +20,21 @@ from sqlalchemy.ext.declarative import as_declarative, declared_attr from sqlalchemy import create_engine, Integer, Column, String, ForeignKey from sqlalchemy.orm import Session, relationship + @as_declarative() class Base(object): """Base class which provides automated table name and surrogate primary key column. """ + @declared_attr def __tablename__(cls): return cls.__name__.lower() + id = Column(Integer, primary_key=True) + class Address(object): """Define columns that will be present in each 'Address' table. @@ -40,74 +44,82 @@ class Address(object): should be set up using @declared_attr. """ + street = Column(String) city = Column(String) zip = Column(String) def __repr__(self): - return "%s(street=%r, city=%r, zip=%r)" % \ - (self.__class__.__name__, self.street, - self.city, self.zip) + return "%s(street=%r, city=%r, zip=%r)" % ( + self.__class__.__name__, + self.street, + self.city, + self.zip, + ) + class HasAddresses(object): """HasAddresses mixin, creates a new Address class for each parent. """ + @declared_attr def addresses(cls): cls.Address = type( "%sAddress" % cls.__name__, - (Address, Base,), + (Address, Base), dict( - __tablename__="%s_address" % - cls.__tablename__, - parent_id=Column(Integer, - ForeignKey("%s.id" % cls.__tablename__)), - parent=relationship(cls) - ) + __tablename__="%s_address" % cls.__tablename__, + parent_id=Column( + Integer, ForeignKey("%s.id" % cls.__tablename__) + ), + parent=relationship(cls), + ), ) return relationship(cls.Address) + class Customer(HasAddresses, Base): name = Column(String) + class Supplier(HasAddresses, Base): company_name = Column(String) -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -session.add_all([ - Customer( - name='customer 1', - addresses=[ - Customer.Address( - street='123 anywhere street', - city="New York", - zip="10110"), - Customer.Address( - street='40 main street', - city="San Francisco", - zip="95732") - ] - ), - Supplier( - company_name="Ace Hammers", - addresses=[ - Supplier.Address( - street='2569 west elm', - city="Detroit", - zip="56785") - ] - ), -]) +session.add_all( + [ + Customer( + name="customer 1", + addresses=[ + Customer.Address( + street="123 anywhere street", city="New York", zip="10110" + ), + Customer.Address( + street="40 main street", city="San Francisco", zip="95732" + ), + ], + ), + Supplier( + company_name="Ace Hammers", + addresses=[ + Supplier.Address( + street="2569 west elm", city="Detroit", zip="56785" + ) + ], + ), + ] +) session.commit() for customer in session.query(Customer): for address in customer.addresses: print(address) - print(address.parent)
\ No newline at end of file + print(address.parent) diff --git a/examples/graphs/__init__.py b/examples/graphs/__init__.py index 57d41453b..0f8fe58a7 100644 --- a/examples/graphs/__init__.py +++ b/examples/graphs/__init__.py @@ -10,4 +10,4 @@ and querying for lower- and upper- neighbors are illustrated:: .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/graphs/directed_graph.py b/examples/graphs/directed_graph.py index 7bcfc5683..af85a4295 100644 --- a/examples/graphs/directed_graph.py +++ b/examples/graphs/directed_graph.py @@ -1,7 +1,6 @@ """a directed graph example.""" -from sqlalchemy import Column, Integer, ForeignKey, \ - create_engine +from sqlalchemy import Column, Integer, ForeignKey, create_engine from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.ext.declarative import declarative_base @@ -9,7 +8,7 @@ Base = declarative_base() class Node(Base): - __tablename__ = 'node' + __tablename__ = "node" node_id = Column(Integer, primary_key=True) @@ -21,33 +20,26 @@ class Node(Base): class Edge(Base): - __tablename__ = 'edge' + __tablename__ = "edge" - lower_id = Column( - Integer, - ForeignKey('node.node_id'), - primary_key=True) + lower_id = Column(Integer, ForeignKey("node.node_id"), primary_key=True) - higher_id = Column( - Integer, - ForeignKey('node.node_id'), - primary_key=True) + higher_id = Column(Integer, ForeignKey("node.node_id"), primary_key=True) lower_node = relationship( - Node, - primaryjoin=lower_id == Node.node_id, - backref='lower_edges') + Node, primaryjoin=lower_id == Node.node_id, backref="lower_edges" + ) higher_node = relationship( - Node, - primaryjoin=higher_id == Node.node_id, - backref='higher_edges') + Node, primaryjoin=higher_id == Node.node_id, backref="higher_edges" + ) def __init__(self, n1, n2): self.lower_node = n1 self.higher_node = n2 -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = sessionmaker(engine)() @@ -80,4 +72,3 @@ assert [x for x in n3.higher_neighbors()] == [n6] assert [x for x in n3.lower_neighbors()] == [n1] assert [x for x in n2.lower_neighbors()] == [n1] assert [x for x in n2.higher_neighbors()] == [n1, n5, n7] - diff --git a/examples/inheritance/concrete.py b/examples/inheritance/concrete.py index 258f41025..2245aa4e0 100644 --- a/examples/inheritance/concrete.py +++ b/examples/inheritance/concrete.py @@ -1,7 +1,14 @@ """Concrete-table (table-per-class) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import ConcreteBase @@ -11,107 +18,105 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(ConcreteBase, Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'person', - } + __mapper_args__ = {"polymorphic_identity": "person"} def __repr__(self): return "Ordinary person %s" % self.name class Engineer(Person): - __tablename__ = 'engineer' + __tablename__ = "engineer" id = Column(Integer, primary_key=True) name = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) status = Column(String(30)) engineer_name = Column(String(30)) primary_language = Column(String(30)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - 'concrete': True - } + __mapper_args__ = {"polymorphic_identity": "engineer", "concrete": True} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) class Manager(Person): - __tablename__ = 'manager' + __tablename__ = "manager" id = Column(Integer, primary_key=True) name = Column(String(50)) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) status = Column(String(30)) manager_name = Column(String(30)) company = relationship("Company", back_populates="employees") - __mapper_args__ = { - 'polymorphic_identity': 'manager', - 'concrete': True - } + __mapper_args__ = {"polymorphic_identity": "manager", "concrete": True} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, + self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -120,14 +125,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -138,24 +144,28 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). - filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(Company). - join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/inheritance/joined.py b/examples/inheritance/joined.py index f9322158e..a3a61d763 100644 --- a/examples/inheritance/joined.py +++ b/examples/inheritance/joined.py @@ -1,7 +1,14 @@ """Joined-table (table-per-subclass) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base @@ -9,31 +16,30 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) type = Column(String(50)) company = relationship("Company", back_populates="employees") __mapper_args__ = { - 'polymorphic_identity': 'person', - 'polymorphic_on': type + "polymorphic_identity": "person", + "polymorphic_on": type, } def __repr__(self): @@ -41,67 +47,70 @@ class Person(Base): class Engineer(Person): - __tablename__ = 'engineer' - id = Column(ForeignKey('person.id'), primary_key=True) + __tablename__ = "engineer" + id = Column(ForeignKey("person.id"), primary_key=True) status = Column(String(30)) engineer_name = Column(String(30)) primary_language = Column(String(30)) - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - } + __mapper_args__ = {"polymorphic_identity": "engineer"} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) class Manager(Person): - __tablename__ = 'manager' - id = Column(ForeignKey('person.id'), primary_key=True) + __tablename__ = "manager" + id = Column(ForeignKey("person.id"), primary_key=True) status = Column(String(30)) manager_name = Column(String(30)) - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, + self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -110,14 +119,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -128,13 +138,14 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). - filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company. @@ -144,12 +155,15 @@ print( # loading. eng_manager = with_polymorphic(Person, [Engineer, Manager], flat=True) print( - session.query(Company). - join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/inheritance/single.py b/examples/inheritance/single.py index 56397f540..46d690c94 100644 --- a/examples/inheritance/single.py +++ b/examples/inheritance/single.py @@ -1,7 +1,14 @@ """Single-table (table-per-hierarchy) inheritance example.""" -from sqlalchemy import Column, Integer, String, \ - ForeignKey, create_engine, inspect, or_ +from sqlalchemy import ( + Column, + Integer, + String, + ForeignKey, + create_engine, + inspect, + or_, +) from sqlalchemy.orm import relationship, Session, with_polymorphic from sqlalchemy.ext.declarative import declarative_base, declared_attr @@ -9,31 +16,30 @@ Base = declarative_base() class Company(Base): - __tablename__ = 'company' + __tablename__ = "company" id = Column(Integer, primary_key=True) name = Column(String(50)) employees = relationship( - "Person", - back_populates='company', - cascade='all, delete-orphan') + "Person", back_populates="company", cascade="all, delete-orphan" + ) def __repr__(self): return "Company %s" % self.name class Person(Base): - __tablename__ = 'person' + __tablename__ = "person" id = Column(Integer, primary_key=True) - company_id = Column(ForeignKey('company.id')) + company_id = Column(ForeignKey("company.id")) name = Column(String(50)) type = Column(String(50)) company = relationship("Company", back_populates="employees") __mapper_args__ = { - 'polymorphic_identity': 'person', - 'polymorphic_on': type + "polymorphic_identity": "person", + "polymorphic_on": type, } def __repr__(self): @@ -50,19 +56,20 @@ class Engineer(Person): # declarative/inheritance.html#resolving-column-conflicts @declared_attr def status(cls): - return Person.__table__.c.get('status', Column(String(30))) + return Person.__table__.c.get("status", Column(String(30))) - __mapper_args__ = { - 'polymorphic_identity': 'engineer', - } + __mapper_args__ = {"polymorphic_identity": "engineer"} def __repr__(self): return ( "Engineer %s, status %s, engineer_name %s, " - "primary_language %s" % - ( - self.name, self.status, - self.engineer_name, self.primary_language) + "primary_language %s" + % ( + self.name, + self.status, + self.engineer_name, + self.primary_language, + ) ) @@ -71,43 +78,45 @@ class Manager(Person): @declared_attr def status(cls): - return Person.__table__.c.get('status', Column(String(30))) + return Person.__table__.c.get("status", Column(String(30))) - __mapper_args__ = { - 'polymorphic_identity': 'manager', - } + __mapper_args__ = {"polymorphic_identity": "manager"} def __repr__(self): return "Manager %s, status %s, manager_name %s" % ( - self.name, self.status, self.manager_name) + self.name, + self.status, + self.manager_name, + ) -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) -c = Company(name='company1', employees=[ - Manager( - name='pointy haired boss', - status='AAB', - manager_name='manager1'), - Engineer( - name='dilbert', - status='BBA', - engineer_name='engineer1', - primary_language='java'), - Person(name='joesmith'), - Engineer( - name='wally', - status='CGG', - engineer_name='engineer2', - primary_language='python'), - Manager( - name='jsmith', - status='ABA', - manager_name='manager2') -]) +c = Company( + name="company1", + employees=[ + Manager( + name="pointy haired boss", status="AAB", manager_name="manager1" + ), + Engineer( + name="dilbert", + status="BBA", + engineer_name="engineer1", + primary_language="java", + ), + Person(name="joesmith"), + Engineer( + name="wally", + status="CGG", + engineer_name="engineer2", + primary_language="python", + ), + Manager(name="jsmith", status="ABA", manager_name="manager2"), + ], +) session.add(c) session.commit() @@ -116,14 +125,15 @@ c = session.query(Company).get(1) for e in c.employees: print(e, inspect(e).key, e.company) assert set([e.name for e in c.employees]) == set( - ['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith']) + ["pointy haired boss", "dilbert", "joesmith", "wally", "jsmith"] +) print("\n") -dilbert = session.query(Person).filter_by(name='dilbert').one() -dilbert2 = session.query(Engineer).filter_by(name='dilbert').one() +dilbert = session.query(Person).filter_by(name="dilbert").one() +dilbert2 = session.query(Engineer).filter_by(name="dilbert").one() assert dilbert is dilbert2 -dilbert.engineer_name = 'hes dilbert!' +dilbert.engineer_name = "hes dilbert!" session.commit() @@ -134,24 +144,28 @@ for e in c.employees: # query using with_polymorphic. eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(eng_manager). - filter( + session.query(eng_manager) + .filter( or_( - eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2' + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", ) - ).all() + ) + .all() ) # illustrate join from Company, eng_manager = with_polymorphic(Person, [Engineer, Manager]) print( - session.query(Company). - join( - Company.employees.of_type(eng_manager) - ).filter( - or_(eng_manager.Engineer.engineer_name == 'engineer1', - eng_manager.Manager.manager_name == 'manager2') - ).all()) + session.query(Company) + .join(Company.employees.of_type(eng_manager)) + .filter( + or_( + eng_manager.Engineer.engineer_name == "engineer1", + eng_manager.Manager.manager_name == "manager2", + ) + ) + .all() +) session.commit() diff --git a/examples/join_conditions/__init__.py b/examples/join_conditions/__init__.py index 3a561d084..d67eb68e4 100644 --- a/examples/join_conditions/__init__.py +++ b/examples/join_conditions/__init__.py @@ -4,4 +4,4 @@ of join conditions. .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/join_conditions/cast.py b/examples/join_conditions/cast.py index 246bc1d57..7ea775689 100644 --- a/examples/join_conditions/cast.py +++ b/examples/join_conditions/cast.py @@ -35,6 +35,7 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() + class StringAsInt(TypeDecorator): """Coerce string->integer type. @@ -44,52 +45,55 @@ class StringAsInt(TypeDecorator): on the child during a flush. """ + impl = Integer + def process_bind_param(self, value, dialect): if value is not None: value = int(value) return value + class A(Base): """Parent. The referenced column is a string type.""" - __tablename__ = 'a' + __tablename__ = "a" id = Column(Integer, primary_key=True) a_id = Column(String) + class B(Base): """Child. The column we reference 'A' with is an integer.""" - __tablename__ = 'b' + __tablename__ = "b" id = Column(Integer, primary_key=True) a_id = Column(StringAsInt) - a = relationship("A", - # specify primaryjoin. The string form is optional - # here, but note that Declarative makes available all - # of the built-in functions we might need, including - # cast() and foreign(). - primaryjoin="cast(A.a_id, Integer) == foreign(B.a_id)", - backref="bs") + a = relationship( + "A", + # specify primaryjoin. The string form is optional + # here, but note that Declarative makes available all + # of the built-in functions we might need, including + # cast() and foreign(). + primaryjoin="cast(A.a_id, Integer) == foreign(B.a_id)", + backref="bs", + ) + # we demonstrate with SQLite, but the important part # is the CAST rendered in the SQL output. -e = create_engine('sqlite://', echo=True) +e = create_engine("sqlite://", echo=True) Base.metadata.create_all(e) s = Session(e) -s.add_all([ - A(a_id="1"), - A(a_id="2", bs=[B(), B()]), - A(a_id="3", bs=[B()]), -]) +s.add_all([A(a_id="1"), A(a_id="2", bs=[B(), B()]), A(a_id="3", bs=[B()])]) s.commit() b1 = s.query(B).filter_by(a_id="2").first() print(b1.a) a1 = s.query(A).filter_by(a_id="2").first() -print(a1.bs)
\ No newline at end of file +print(a1.bs) diff --git a/examples/join_conditions/threeway.py b/examples/join_conditions/threeway.py index 13df0f349..257002637 100644 --- a/examples/join_conditions/threeway.py +++ b/examples/join_conditions/threeway.py @@ -39,46 +39,56 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() + class First(Base): - __tablename__ = 'first' + __tablename__ = "first" first_id = Column(Integer, primary_key=True) partition_key = Column(String) def __repr__(self): - return ("First(%s, %s)" % (self.first_id, self.partition_key)) + return "First(%s, %s)" % (self.first_id, self.partition_key) + class Second(Base): - __tablename__ = 'second' + __tablename__ = "second" first_id = Column(Integer, primary_key=True) other_id = Column(Integer, primary_key=True) + class Partitioned(Base): - __tablename__ = 'partitioned' + __tablename__ = "partitioned" other_id = Column(Integer, primary_key=True) partition_key = Column(String, primary_key=True) def __repr__(self): - return ("Partitioned(%s, %s)" % (self.other_id, self.partition_key)) + return "Partitioned(%s, %s)" % (self.other_id, self.partition_key) j = join(Partitioned, Second, Partitioned.other_id == Second.other_id) -partitioned_second = mapper(Partitioned, j, non_primary=True, properties={ +partitioned_second = mapper( + Partitioned, + j, + non_primary=True, + properties={ # note we need to disambiguate columns here - the join() # will provide them as j.c.<tablename>_<colname> for access, # but they retain their real names in the mapping - "other_id": [j.c.partitioned_other_id, j.c.second_other_id], - }) + "other_id": [j.c.partitioned_other_id, j.c.second_other_id] + }, +) First.partitioned = relationship( - partitioned_second, - primaryjoin=and_( - First.partition_key == partitioned_second.c.partition_key, - First.first_id == foreign(partitioned_second.c.first_id) - ), innerjoin=True) + partitioned_second, + primaryjoin=and_( + First.partition_key == partitioned_second.c.partition_key, + First.first_id == foreign(partitioned_second.c.first_id), + ), + innerjoin=True, +) # when using any database other than SQLite, we will get a nested # join, e.g. "first JOIN (partitioned JOIN second ON ..) ON ..". @@ -87,17 +97,19 @@ e = create_engine("sqlite://", echo=True) Base.metadata.create_all(e) s = Session(e) -s.add_all([ - First(first_id=1, partition_key='p1'), - First(first_id=2, partition_key='p1'), - First(first_id=3, partition_key='p2'), - Second(first_id=1, other_id=1), - Second(first_id=2, other_id=1), - Second(first_id=3, other_id=2), - Partitioned(partition_key='p1', other_id=1), - Partitioned(partition_key='p1', other_id=2), - Partitioned(partition_key='p2', other_id=2), -]) +s.add_all( + [ + First(first_id=1, partition_key="p1"), + First(first_id=2, partition_key="p1"), + First(first_id=3, partition_key="p2"), + Second(first_id=1, other_id=1), + Second(first_id=2, other_id=1), + Second(first_id=3, other_id=2), + Partitioned(partition_key="p1", other_id=1), + Partitioned(partition_key="p1", other_id=2), + Partitioned(partition_key="p2", other_id=2), + ] +) s.commit() for row in s.query(First, Partitioned).join(First.partitioned): diff --git a/examples/large_collection/large_collection.py b/examples/large_collection/large_collection.py index 82d2e554b..eb014c6cb 100644 --- a/examples/large_collection/large_collection.py +++ b/examples/large_collection/large_collection.py @@ -1,54 +1,76 @@ - -from sqlalchemy import (MetaData, Table, Column, Integer, String, ForeignKey, - create_engine) -from sqlalchemy.orm import (mapper, relationship, sessionmaker) +from sqlalchemy import ( + MetaData, + Table, + Column, + Integer, + String, + ForeignKey, + create_engine, +) +from sqlalchemy.orm import mapper, relationship, sessionmaker meta = MetaData() -org_table = Table('organizations', meta, - Column('org_id', Integer, primary_key=True), - Column('org_name', String(50), nullable=False, key='name'), - mysql_engine='InnoDB') - -member_table = Table('members', meta, - Column('member_id', Integer, primary_key=True), - Column('member_name', String(50), nullable=False, key='name'), - Column('org_id', Integer, - ForeignKey('organizations.org_id', ondelete="CASCADE")), - mysql_engine='InnoDB') +org_table = Table( + "organizations", + meta, + Column("org_id", Integer, primary_key=True), + Column("org_name", String(50), nullable=False, key="name"), + mysql_engine="InnoDB", +) + +member_table = Table( + "members", + meta, + Column("member_id", Integer, primary_key=True), + Column("member_name", String(50), nullable=False, key="name"), + Column( + "org_id", + Integer, + ForeignKey("organizations.org_id", ondelete="CASCADE"), + ), + mysql_engine="InnoDB", +) class Organization(object): def __init__(self, name): self.name = name + class Member(object): def __init__(self, name): self.name = name -mapper(Organization, org_table, properties = { - 'members' : relationship(Member, - # Organization.members will be a Query object - no loading - # of the entire collection occurs unless requested - lazy="dynamic", - - # Member objects "belong" to their parent, are deleted when - # removed from the collection - cascade="all, delete-orphan", - - # "delete, delete-orphan" cascade does not load in objects on delete, - # allows ON DELETE CASCADE to handle it. - # this only works with a database that supports ON DELETE CASCADE - - # *not* sqlite or MySQL with MyISAM - passive_deletes=True, - ) -}) + +mapper( + Organization, + org_table, + properties={ + "members": relationship( + Member, + # Organization.members will be a Query object - no loading + # of the entire collection occurs unless requested + lazy="dynamic", + # Member objects "belong" to their parent, are deleted when + # removed from the collection + cascade="all, delete-orphan", + # "delete, delete-orphan" cascade does not load in objects on delete, + # allows ON DELETE CASCADE to handle it. + # this only works with a database that supports ON DELETE CASCADE - + # *not* sqlite or MySQL with MyISAM + passive_deletes=True, + ) + }, +) mapper(Member, member_table) -if __name__ == '__main__': - engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True) +if __name__ == "__main__": + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) meta.create_all(engine) # expire_on_commit=False means the session contents @@ -56,10 +78,10 @@ if __name__ == '__main__': sess = sessionmaker(engine, expire_on_commit=False)() # create org with some members - org = Organization('org one') - org.members.append(Member('member one')) - org.members.append(Member('member two')) - org.members.append(Member('member three')) + org = Organization("org one") + org.members.append(Member("member one")) + org.members.append(Member("member two")) + org.members.append(Member("member three")) sess.add(org) @@ -69,14 +91,14 @@ if __name__ == '__main__': # the 'members' collection is a Query. it issues # SQL as needed to load subsets of the collection. print("-------------------------\nload subset of members\n") - members = org.members.filter(member_table.c.name.like('%member t%')).all() + members = org.members.filter(member_table.c.name.like("%member t%")).all() print(members) # new Members can be appended without any # SQL being emitted to load the full collection - org.members.append(Member('member four')) - org.members.append(Member('member five')) - org.members.append(Member('member six')) + org.members.append(Member("member four")) + org.members.append(Member("member five")) + org.members.append(Member("member six")) print("-------------------------\nflush two - save 3 more members\n") sess.commit() @@ -85,7 +107,9 @@ if __name__ == '__main__': # SQL is only emitted for the head row - the Member rows # disappear automatically without the need for additional SQL. sess.delete(org) - print("-------------------------\nflush three - delete org, delete members in one statement\n") + print( + "-------------------------\nflush three - delete org, delete members in one statement\n" + ) sess.commit() print("-------------------------\nno Member rows should remain:\n") @@ -93,4 +117,4 @@ if __name__ == '__main__': sess.close() print("------------------------\ndone. dropping tables.") - meta.drop_all(engine)
\ No newline at end of file + meta.drop_all(engine) diff --git a/examples/materialized_paths/materialized_paths.py b/examples/materialized_paths/materialized_paths.py index 4ded90f7e..45ae0c193 100644 --- a/examples/materialized_paths/materialized_paths.py +++ b/examples/materialized_paths/materialized_paths.py @@ -44,21 +44,35 @@ class Node(Base): # To find the descendants of this node, we look for nodes whose path # starts with this node's path. descendants = relationship( - "Node", viewonly=True, order_by=path, - primaryjoin=remote(foreign(path)).like(path.concat(".%"))) + "Node", + viewonly=True, + order_by=path, + primaryjoin=remote(foreign(path)).like(path.concat(".%")), + ) # Finding the ancestors is a little bit trickier. We need to create a fake # secondary table since this behaves like a many-to-many join. - secondary = select([ - id.label("id"), - func.unnest(cast(func.string_to_array( - func.regexp_replace(path, r"\.?\d+$", ""), "."), - ARRAY(Integer))).label("ancestor_id") - ]).alias() - ancestors = relationship("Node", viewonly=True, secondary=secondary, - primaryjoin=id == secondary.c.id, - secondaryjoin=secondary.c.ancestor_id == id, - order_by=path) + secondary = select( + [ + id.label("id"), + func.unnest( + cast( + func.string_to_array( + func.regexp_replace(path, r"\.?\d+$", ""), "." + ), + ARRAY(Integer), + ) + ).label("ancestor_id"), + ] + ).alias() + ancestors = relationship( + "Node", + viewonly=True, + secondary=secondary, + primaryjoin=id == secondary.c.id, + secondaryjoin=secondary.c.ancestor_id == id, + order_by=path, + ) @property def depth(self): @@ -70,38 +84,44 @@ class Node(Base): def __str__(self): root_depth = self.depth s = [str(self.id)] - s.extend(((n.depth - root_depth) * " " + str(n.id)) - for n in self.descendants) + s.extend( + ((n.depth - root_depth) * " " + str(n.id)) + for n in self.descendants + ) return "\n".join(s) def move_to(self, new_parent): new_path = new_parent.path + "." + str(self.id) for n in self.descendants: - n.path = new_path + n.path[len(self.path):] + n.path = new_path + n.path[len(self.path) :] self.path = new_path if __name__ == "__main__": - engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True) + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) Base.metadata.create_all(engine) session = Session(engine) print("-" * 80) print("create a tree") - session.add_all([ - Node(id=1, path="1"), - Node(id=2, path="1.2"), - Node(id=3, path="1.3"), - Node(id=4, path="1.3.4"), - Node(id=5, path="1.3.5"), - Node(id=6, path="1.3.6"), - Node(id=7, path="1.7"), - Node(id=8, path="1.7.8"), - Node(id=9, path="1.7.9"), - Node(id=10, path="1.7.9.10"), - Node(id=11, path="1.7.11"), - ]) + session.add_all( + [ + Node(id=1, path="1"), + Node(id=2, path="1.2"), + Node(id=3, path="1.3"), + Node(id=4, path="1.3.4"), + Node(id=5, path="1.3.5"), + Node(id=6, path="1.3.6"), + Node(id=7, path="1.7"), + Node(id=8, path="1.7.8"), + Node(id=9, path="1.7.9"), + Node(id=10, path="1.7.9.10"), + Node(id=11, path="1.7.11"), + ] + ) session.flush() print(str(session.query(Node).get(1))) diff --git a/examples/nested_sets/__init__.py b/examples/nested_sets/__init__.py index 3e73bb13e..5fdfbcedc 100644 --- a/examples/nested_sets/__init__.py +++ b/examples/nested_sets/__init__.py @@ -3,4 +3,4 @@ pattern for hierarchical data using the SQLAlchemy ORM. .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/nested_sets/nested_sets.py b/examples/nested_sets/nested_sets.py index c64b15b61..705a3d279 100644 --- a/examples/nested_sets/nested_sets.py +++ b/examples/nested_sets/nested_sets.py @@ -4,19 +4,27 @@ http://www.intelligententerprise.com/001020/celko.jhtml """ -from sqlalchemy import (create_engine, Column, Integer, String, select, case, - func) +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + select, + case, + func, +) from sqlalchemy.orm import Session, aliased from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import event Base = declarative_base() + class Employee(Base): - __tablename__ = 'personnel' + __tablename__ = "personnel" __mapper_args__ = { - 'batch': False # allows extension to fire for each - # instance before going to the next. + "batch": False # allows extension to fire for each + # instance before going to the next. } parent = None @@ -29,6 +37,7 @@ class Employee(Base): def __repr__(self): return "Employee(%s, %d, %d)" % (self.emp, self.left, self.right) + @event.listens_for(Employee, "before_insert") def before_insert(mapper, connection, instance): if not instance.parent: @@ -37,23 +46,31 @@ def before_insert(mapper, connection, instance): else: personnel = mapper.mapped_table right_most_sibling = connection.scalar( - select([personnel.c.rgt]). - where(personnel.c.emp == instance.parent.emp) + select([personnel.c.rgt]).where( + personnel.c.emp == instance.parent.emp + ) ) connection.execute( - personnel.update( - personnel.c.rgt >= right_most_sibling).values( - lft=case( - [(personnel.c.lft > right_most_sibling, - personnel.c.lft + 2)], - else_=personnel.c.lft - ), - rgt=case( - [(personnel.c.rgt >= right_most_sibling, - personnel.c.rgt + 2)], - else_=personnel.c.rgt - ) + personnel.update(personnel.c.rgt >= right_most_sibling).values( + lft=case( + [ + ( + personnel.c.lft > right_most_sibling, + personnel.c.lft + 2, + ) + ], + else_=personnel.c.lft, + ), + rgt=case( + [ + ( + personnel.c.rgt >= right_most_sibling, + personnel.c.rgt + 2, + ) + ], + else_=personnel.c.rgt, + ), ) ) instance.left = right_most_sibling @@ -62,18 +79,19 @@ def before_insert(mapper, connection, instance): # before_update() would be needed to support moving of nodes # after_delete() would be needed to support removal of nodes. -engine = create_engine('sqlite://', echo=True) + +engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(bind=engine) -albert = Employee(emp='Albert') -bert = Employee(emp='Bert') -chuck = Employee(emp='Chuck') -donna = Employee(emp='Donna') -eddie = Employee(emp='Eddie') -fred = Employee(emp='Fred') +albert = Employee(emp="Albert") +bert = Employee(emp="Bert") +chuck = Employee(emp="Chuck") +donna = Employee(emp="Donna") +eddie = Employee(emp="Eddie") +fred = Employee(emp="Fred") bert.parent = albert chuck.parent = albert @@ -90,22 +108,28 @@ print(session.query(Employee).all()) # 1. Find an employee and all their supervisors, no matter how deep the tree. ealias = aliased(Employee) -print(session.query(Employee).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - filter(ealias.emp == 'Eddie').all()) - -#2. Find the employee and all their subordinates. +print( + session.query(Employee) + .filter(ealias.left.between(Employee.left, Employee.right)) + .filter(ealias.emp == "Eddie") + .all() +) + +# 2. Find the employee and all their subordinates. # (This query has a nice symmetry with the first query.) -print(session.query(Employee).\ - filter(Employee.left.between(ealias.left, ealias.right)).\ - filter(ealias.emp == 'Chuck').all()) - -#3. Find the level of each node, so you can print the tree +print( + session.query(Employee) + .filter(Employee.left.between(ealias.left, ealias.right)) + .filter(ealias.emp == "Chuck") + .all() +) + +# 3. Find the level of each node, so you can print the tree # as an indented listing. -for indentation, employee in session.query( - func.count(Employee.emp).label('indentation') - 1, ealias).\ - filter(ealias.left.between(Employee.left, Employee.right)).\ - group_by(ealias.emp).\ - order_by(ealias.left): +for indentation, employee in ( + session.query(func.count(Employee.emp).label("indentation") - 1, ealias) + .filter(ealias.left.between(Employee.left, Employee.right)) + .group_by(ealias.emp) + .order_by(ealias.left) +): print(" " * indentation + str(employee)) - diff --git a/examples/performance/__init__.py b/examples/performance/__init__.py index 6264ae9f7..b66199f3c 100644 --- a/examples/performance/__init__.py +++ b/examples/performance/__init__.py @@ -255,7 +255,8 @@ class Profiler(object): def profile(cls, fn): if cls.name is None: raise ValueError( - "Need to call Profile.init(<suitename>, <default_num>) first.") + "Need to call Profile.init(<suitename>, <default_num>) first." + ) cls.tests.append(fn) return fn @@ -270,7 +271,8 @@ class Profiler(object): def setup_once(cls, fn): if cls._setup_once is not None: raise ValueError( - "setup_once function already set to %s" % cls._setup_once) + "setup_once function already set to %s" % cls._setup_once + ) cls._setup_once = staticmethod(fn) return fn @@ -298,7 +300,7 @@ class Profiler(object): finally: pr.disable() - stats = pstats.Stats(pr).sort_stats('cumulative') + stats = pstats.Stats(pr).sort_stats("cumulative") self.stats.append(TestResult(self, fn, stats=stats)) return result @@ -326,7 +328,8 @@ class Profiler(object): if cls.name is None: parser.add_argument( - "name", choices=cls._suite_names(), help="suite to run") + "name", choices=cls._suite_names(), help="suite to run" + ) if len(sys.argv) > 1: potential_name = sys.argv[1] @@ -335,35 +338,44 @@ class Profiler(object): except ImportError: pass - parser.add_argument( - "--test", type=str, - help="run specific test name" - ) + parser.add_argument("--test", type=str, help="run specific test name") parser.add_argument( - '--dburl', type=str, default="sqlite:///profile.db", - help="database URL, default sqlite:///profile.db" + "--dburl", + type=str, + default="sqlite:///profile.db", + help="database URL, default sqlite:///profile.db", ) parser.add_argument( - '--num', type=int, default=cls.num, + "--num", + type=int, + default=cls.num, help="Number of iterations/items/etc for tests; " - "default is %d module-specific" % cls.num + "default is %d module-specific" % cls.num, ) parser.add_argument( - '--profile', action='store_true', - help='run profiling and dump call counts') + "--profile", + action="store_true", + help="run profiling and dump call counts", + ) parser.add_argument( - '--dump', action='store_true', - help='dump full call profile (implies --profile)') + "--dump", + action="store_true", + help="dump full call profile (implies --profile)", + ) parser.add_argument( - '--callers', action='store_true', - help='print callers as well (implies --dump)') + "--callers", + action="store_true", + help="print callers as well (implies --dump)", + ) parser.add_argument( - '--runsnake', action='store_true', - help='invoke runsnakerun (implies --profile)') + "--runsnake", + action="store_true", + help="invoke runsnakerun (implies --profile)", + ) parser.add_argument( - '--echo', action='store_true', - help="Echo SQL output") + "--echo", action="store_true", help="Echo SQL output" + ) args = parser.parse_args() args.dump = args.dump or args.callers @@ -378,7 +390,7 @@ class Profiler(object): def _suite_names(cls): suites = [] for file_ in os.listdir(os.path.dirname(__file__)): - match = re.match(r'^([a-z].*).py$', file_) + match = re.match(r"^([a-z].*).py$", file_) if match: suites.append(match.group(1)) return suites @@ -398,7 +410,10 @@ class TestResult(object): def _summary(self): summary = "%s : %s (%d iterations)" % ( - self.test.__name__, self.test.__doc__, self.profile.num) + self.test.__name__, + self.test.__doc__, + self.profile.num, + ) if self.total_time: summary += "; total time %f sec" % self.total_time if self.stats: @@ -412,7 +427,7 @@ class TestResult(object): self._dump() def _dump(self): - self.stats.sort_stats('time', 'calls') + self.stats.sort_stats("time", "calls") self.stats.print_stats() if self.profile.callers: self.stats.print_callers() @@ -424,5 +439,3 @@ class TestResult(object): os.system("runsnake %s" % filename) finally: os.remove(filename) - - diff --git a/examples/performance/__main__.py b/examples/performance/__main__.py index 5e05143bf..945458651 100644 --- a/examples/performance/__main__.py +++ b/examples/performance/__main__.py @@ -2,6 +2,5 @@ from . import Profiler -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() - diff --git a/examples/performance/bulk_inserts.py b/examples/performance/bulk_inserts.py index 9c3cff5b2..52f0f32e6 100644 --- a/examples/performance/bulk_inserts.py +++ b/examples/performance/bulk_inserts.py @@ -36,12 +36,15 @@ def test_flush_no_pk(n): """Individual INSERT statements via the ORM, calling upon last row id""" session = Session(bind=engine) for chunk in range(0, n, 1000): - session.add_all([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i) - for i in range(chunk, chunk + 1000) - ]) + session.add_all( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(chunk, chunk + 1000) + ] + ) session.flush() session.commit() @@ -50,13 +53,16 @@ def test_flush_no_pk(n): def test_bulk_save_return_pks(n): """Individual INSERT statements in "bulk", but calling upon last row id""" session = Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ], return_defaults=True) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ], + return_defaults=True, + ) session.commit() @@ -65,13 +71,16 @@ def test_flush_pk_given(n): """Batched INSERT statements via the ORM, PKs already defined""" session = Session(bind=engine) for chunk in range(0, n, 1000): - session.add_all([ - Customer( - id=i + 1, - name='customer name %d' % i, - description='customer description %d' % i) - for i in range(chunk, chunk + 1000) - ]) + session.add_all( + [ + Customer( + id=i + 1, + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(chunk, chunk + 1000) + ] + ) session.flush() session.commit() @@ -80,13 +89,15 @@ def test_flush_pk_given(n): def test_bulk_save(n): """Batched INSERT statements via the ORM in "bulk", discarding PKs.""" session = Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ]) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ] + ) session.commit() @@ -94,13 +105,16 @@ def test_bulk_save(n): def test_bulk_insert_mappings(n): """Batched INSERT statements via the ORM "bulk", using dictionaries.""" session = Session(bind=engine) - session.bulk_insert_mappings(Customer, [ - dict( - name='customer name %d' % i, - description='customer description %d' % i - ) - for i in range(n) - ]) + session.bulk_insert_mappings( + Customer, + [ + dict( + name="customer name %d" % i, + description="customer description %d" % i, + ) + for i in range(n) + ], + ) session.commit() @@ -112,11 +126,12 @@ def test_core_insert(n): Customer.__table__.insert(), [ dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) - ]) + ], + ) @Profiler.profile @@ -125,30 +140,30 @@ def test_dbapi_raw(n): conn = engine.pool._creator() cursor = conn.cursor() - compiled = Customer.__table__.insert().values( - name=bindparam('name'), - description=bindparam('description')).\ - compile(dialect=engine.dialect) + compiled = ( + Customer.__table__.insert() + .values(name=bindparam("name"), description=bindparam("description")) + .compile(dialect=engine.dialect) + ) if compiled.positional: args = ( - ('customer name %d' % i, 'customer description %d' % i) - for i in range(n)) + ("customer name %d" % i, "customer description %d" % i) + for i in range(n) + ) else: args = ( dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) ) - cursor.executemany( - str(compiled), - list(args) - ) + cursor.executemany(str(compiled), list(args)) conn.commit() conn.close() -if __name__ == '__main__': + +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/bulk_updates.py b/examples/performance/bulk_updates.py index 9522e4bf5..ebb700068 100644 --- a/examples/performance/bulk_updates.py +++ b/examples/performance/bulk_updates.py @@ -32,12 +32,16 @@ def setup_database(dburl, echo, num): s = Session(engine) for chunk in range(0, num, 10000): - s.bulk_insert_mappings(Customer, [ - { - 'name': 'customer name %d' % i, - 'description': 'customer description %d' % i - } for i in range(chunk, chunk + 10000) - ]) + s.bulk_insert_mappings( + Customer, + [ + { + "name": "customer name %d" % i, + "description": "customer description %d" % i, + } + for i in range(chunk, chunk + 10000) + ], + ) s.commit() @@ -46,8 +50,11 @@ def test_orm_flush(n): """UPDATE statements via the ORM flush process.""" session = Session(bind=engine) for chunk in range(0, n, 1000): - customers = session.query(Customer).\ - filter(Customer.id.between(chunk, chunk + 1000)).all() + customers = ( + session.query(Customer) + .filter(Customer.id.between(chunk, chunk + 1000)) + .all() + ) for customer in customers: customer.description += "updated" session.flush() diff --git a/examples/performance/large_resultsets.py b/examples/performance/large_resultsets.py index c13683040..ad1c23194 100644 --- a/examples/performance/large_resultsets.py +++ b/examples/performance/large_resultsets.py @@ -46,9 +46,12 @@ def setup_database(dburl, echo, num): Customer.__table__.insert(), params=[ { - 'name': 'customer name %d' % i, - 'description': 'customer description %d' % i - } for i in range(chunk, chunk + 10000)]) + "name": "customer name %d" % i, + "description": "customer description %d" % i, + } + for i in range(chunk, chunk + 10000) + ], + ) s.commit() @@ -74,8 +77,9 @@ def test_orm_bundles(n): """Load lightweight "bundle" objects using the ORM.""" sess = Session(engine) - bundle = Bundle('customer', - Customer.id, Customer.name, Customer.description) + bundle = Bundle( + "customer", Customer.id, Customer.name, Customer.description + ) for row in sess.query(bundle).yield_per(10000).limit(n): pass @@ -85,9 +89,11 @@ def test_orm_columns(n): """Load individual columns into named tuples using the ORM.""" sess = Session(engine) - for row in sess.query( - Customer.id, Customer.name, - Customer.description).yield_per(10000).limit(n): + for row in ( + sess.query(Customer.id, Customer.name, Customer.description) + .yield_per(10000) + .limit(n) + ): pass @@ -98,7 +104,7 @@ def test_core_fetchall(n): with engine.connect() as conn: result = conn.execute(Customer.__table__.select().limit(n)).fetchall() for row in result: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -106,14 +112,15 @@ def test_core_fetchmany_w_streaming(n): """Load Core result rows using fetchmany/streaming.""" with engine.connect() as conn: - result = conn.execution_options(stream_results=True).\ - execute(Customer.__table__.select().limit(n)) + result = conn.execution_options(stream_results=True).execute( + Customer.__table__.select().limit(n) + ) while True: chunk = result.fetchmany(10000) if not chunk: break for row in chunk: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -127,7 +134,7 @@ def test_core_fetchmany(n): if not chunk: break for row in chunk: - data = row['id'], row['name'], row['description'] + data = row["id"], row["name"], row["description"] @Profiler.profile @@ -145,10 +152,13 @@ def test_dbapi_fetchall_no_object(n): def _test_dbapi_raw(n, make_objects): - compiled = Customer.__table__.select().limit(n).\ - compile( - dialect=engine.dialect, - compile_kwargs={"literal_binds": True}) + compiled = ( + Customer.__table__.select() + .limit(n) + .compile( + dialect=engine.dialect, compile_kwargs={"literal_binds": True} + ) + ) if make_objects: # because if you're going to roll your own, you're probably @@ -170,7 +180,8 @@ def _test_dbapi_raw(n, make_objects): for row in cursor.fetchall(): # ensure that we fully fetch! customer = SimpleCustomer( - id=row[0], name=row[1], description=row[2]) + id=row[0], name=row[1], description=row[2] + ) else: for row in cursor.fetchall(): # ensure that we fully fetch! @@ -178,5 +189,6 @@ def _test_dbapi_raw(n, make_objects): conn.close() -if __name__ == '__main__': + +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/short_selects.py b/examples/performance/short_selects.py index 6f64aa63e..4a8d401ad 100644 --- a/examples/performance/short_selects.py +++ b/examples/performance/short_selects.py @@ -6,8 +6,14 @@ record by primary key from . import Profiler from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import Column, Integer, String, create_engine, \ - bindparam, select +from sqlalchemy import ( + Column, + Integer, + String, + create_engine, + bindparam, + select, +) from sqlalchemy.orm import Session, deferred from sqlalchemy.ext import baked import random @@ -29,6 +35,7 @@ class Customer(Base): y = deferred(Column(Integer)) z = deferred(Column(Integer)) + Profiler.init("short_selects", num=10000) @@ -39,16 +46,20 @@ def setup_database(dburl, echo, num): Base.metadata.drop_all(engine) Base.metadata.create_all(engine) sess = Session(engine) - sess.add_all([ - Customer( - id=i, name='c%d' % i, description="c%d" % i, - q=i * 10, - p=i * 20, - x=i * 30, - y=i * 40, - ) - for i in ids - ]) + sess.add_all( + [ + Customer( + id=i, + name="c%d" % i, + description="c%d" % i, + q=i * 10, + p=i * 20, + x=i * 30, + y=i * 40, + ) + for i in ids + ] + ) sess.commit() @@ -65,9 +76,9 @@ def test_orm_query_cols_only(n): """test an ORM query of only the entity columns.""" session = Session(bind=engine) for id_ in random.sample(ids, n): - session.query( - Customer.id, Customer.name, Customer.description - ).filter(Customer.id == id_).one() + session.query(Customer.id, Customer.name, Customer.description).filter( + Customer.id == id_ + ).one() @Profiler.profile @@ -77,7 +88,7 @@ def test_baked_query(n): s = Session(bind=engine) for id_ in random.sample(ids, n): q = bakery(lambda s: s.query(Customer)) - q += lambda q: q.filter(Customer.id == bindparam('id')) + q += lambda q: q.filter(Customer.id == bindparam("id")) q(s).params(id=id_).one() @@ -88,9 +99,9 @@ def test_baked_query_cols_only(n): s = Session(bind=engine) for id_ in random.sample(ids, n): q = bakery( - lambda s: s.query( - Customer.id, Customer.name, Customer.description)) - q += lambda q: q.filter(Customer.id == bindparam('id')) + lambda s: s.query(Customer.id, Customer.name, Customer.description) + ) + q += lambda q: q.filter(Customer.id == bindparam("id")) q(s).params(id=id_).one() @@ -109,7 +120,7 @@ def test_core_new_stmt_each_time(n): def test_core_reuse_stmt(n): """test core, reusing the same statement (but recompiling each time).""" - stmt = select([Customer.__table__]).where(Customer.id == bindparam('id')) + stmt = select([Customer.__table__]).where(Customer.id == bindparam("id")) with engine.connect() as conn: for id_ in random.sample(ids, n): @@ -122,13 +133,14 @@ def test_core_reuse_stmt_compiled_cache(n): """test core, reusing the same statement + compiled cache.""" compiled_cache = {} - stmt = select([Customer.__table__]).where(Customer.id == bindparam('id')) - with engine.connect().\ - execution_options(compiled_cache=compiled_cache) as conn: + stmt = select([Customer.__table__]).where(Customer.id == bindparam("id")) + with engine.connect().execution_options( + compiled_cache=compiled_cache + ) as conn: for id_ in random.sample(ids, n): row = conn.execute(stmt, id=id_).first() tuple(row) -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() diff --git a/examples/performance/single_inserts.py b/examples/performance/single_inserts.py index cfce90300..79e34dfe6 100644 --- a/examples/performance/single_inserts.py +++ b/examples/performance/single_inserts.py @@ -28,7 +28,7 @@ Profiler.init("single_inserts", num=10000) def setup_database(dburl, echo, num): global engine engine = create_engine(dburl, echo=echo) - if engine.dialect.name == 'sqlite': + if engine.dialect.name == "sqlite": engine.pool = pool.StaticPool(creator=engine.pool._creator) Base.metadata.drop_all(engine) Base.metadata.create_all(engine) @@ -42,8 +42,9 @@ def test_orm_commit(n): session = Session(bind=engine) session.add( Customer( - name='customer name %d' % i, - description='customer description %d' % i) + name="customer name %d" % i, + description="customer description %d" % i, + ) ) session.commit() @@ -54,11 +55,14 @@ def test_bulk_save(n): for i in range(n): session = Session(bind=engine) - session.bulk_save_objects([ - Customer( - name='customer name %d' % i, - description='customer description %d' % i - )]) + session.bulk_save_objects( + [ + Customer( + name="customer name %d" % i, + description="customer description %d" % i, + ) + ] + ) session.commit() @@ -68,11 +72,15 @@ def test_bulk_insert_dictionaries(n): for i in range(n): session = Session(bind=engine) - session.bulk_insert_mappings(Customer, [ - dict( - name='customer name %d' % i, - description='customer description %d' % i - )]) + session.bulk_insert_mappings( + Customer, + [ + dict( + name="customer name %d" % i, + description="customer description %d" % i, + ) + ], + ) session.commit() @@ -85,9 +93,9 @@ def test_core(n): conn.execute( Customer.__table__.insert(), dict( - name='customer name %d' % i, - description='customer description %d' % i - ) + name="customer name %d" % i, + description="customer description %d" % i, + ), ) @@ -102,9 +110,9 @@ def test_core_query_caching(n): conn.execution_options(compiled_cache=cache).execute( ins, dict( - name='customer name %d' % i, - description='customer description %d' % i - ) + name="customer name %d" % i, + description="customer description %d" % i, + ), ) @@ -123,20 +131,22 @@ def test_dbapi_raw_w_pool(n): def _test_dbapi_raw(n, connect): - compiled = Customer.__table__.insert().values( - name=bindparam('name'), - description=bindparam('description')).\ - compile(dialect=engine.dialect) + compiled = ( + Customer.__table__.insert() + .values(name=bindparam("name"), description=bindparam("description")) + .compile(dialect=engine.dialect) + ) if compiled.positional: args = ( - ('customer name %d' % i, 'customer description %d' % i) - for i in range(n)) + ("customer name %d" % i, "customer description %d" % i) + for i in range(n) + ) else: args = ( dict( - name='customer name %d' % i, - description='customer description %d' % i + name="customer name %d" % i, + description="customer description %d" % i, ) for i in range(n) ) @@ -162,5 +172,5 @@ def _test_dbapi_raw(n, connect): conn.close() -if __name__ == '__main__': +if __name__ == "__main__": Profiler.main() diff --git a/examples/postgis/__init__.py b/examples/postgis/__init__.py index 250d9ce87..66ae65d3c 100644 --- a/examples/postgis/__init__.py +++ b/examples/postgis/__init__.py @@ -36,4 +36,3 @@ E.g.:: .. autosource:: """ - diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index ffea3d018..508d63398 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -5,6 +5,7 @@ import binascii # Python datatypes + class GisElement(object): """Represents a geometry value.""" @@ -12,16 +13,21 @@ class GisElement(object): return self.desc def __repr__(self): - return "<%s at 0x%x; %r>" % (self.__class__.__name__, - id(self), self.desc) + return "<%s at 0x%x; %r>" % ( + self.__class__.__name__, + id(self), + self.desc, + ) + class BinaryGisElement(GisElement, expression.Function): """Represents a Geometry value expressed as binary.""" def __init__(self, data): self.data = data - expression.Function.__init__(self, "ST_GeomFromEWKB", data, - type_=Geometry(coerce_="binary")) + expression.Function.__init__( + self, "ST_GeomFromEWKB", data, type_=Geometry(coerce_="binary") + ) @property def desc(self): @@ -31,24 +37,26 @@ class BinaryGisElement(GisElement, expression.Function): def as_hex(self): return binascii.hexlify(self.data) + class TextualGisElement(GisElement, expression.Function): """Represents a Geometry value expressed as text.""" def __init__(self, desc, srid=-1): self.desc = desc - expression.Function.__init__(self, "ST_GeomFromText", desc, srid, - type_=Geometry) + expression.Function.__init__( + self, "ST_GeomFromText", desc, srid, type_=Geometry + ) # SQL datatypes. + class Geometry(UserDefinedType): """Base PostGIS Geometry column type.""" name = "GEOMETRY" - def __init__(self, dimension=None, srid=-1, - coerce_="text"): + def __init__(self, dimension=None, srid=-1, coerce_="text"): self.dimension = dimension self.srid = srid self.coerce = coerce_ @@ -58,11 +66,11 @@ class Geometry(UserDefinedType): # override the __eq__() operator def __eq__(self, other): - return self.op('~=')(other) + return self.op("~=")(other) # add a custom operator def intersects(self, other): - return self.op('&&')(other) + return self.op("&&")(other) # any number of GIS operators can be overridden/added here # using the techniques above. @@ -95,6 +103,7 @@ class Geometry(UserDefinedType): return value.desc else: return value + return process def result_processor(self, dialect, coltype): @@ -104,27 +113,35 @@ class Geometry(UserDefinedType): fac = BinaryGisElement else: assert False + def process(value): if value is not None: return fac(value) else: return value + return process def adapt(self, impltype): - return impltype(dimension=self.dimension, - srid=self.srid, coerce_=self.coerce) + return impltype( + dimension=self.dimension, srid=self.srid, coerce_=self.coerce + ) + # other datatypes can be added as needed. + class Point(Geometry): - name = 'POINT' + name = "POINT" + class Curve(Geometry): - name = 'CURVE' + name = "CURVE" + class LineString(Curve): - name = 'LINESTRING' + name = "LINESTRING" + # ... etc. @@ -135,6 +152,7 @@ class LineString(Curve): # versions don't appear to require these special steps anymore. However, # here we illustrate how to set up these features in any case. + def setup_ddl_events(): @event.listens_for(Table, "before_create") def before_create(target, connection, **kw): @@ -153,9 +171,10 @@ def setup_ddl_events(): dispatch("after-drop", target, connection) def dispatch(event, table, bind): - if event in ('before-create', 'before-drop'): - regular_cols = [c for c in table.c if not - isinstance(c.type, Geometry)] + if event in ("before-create", "before-drop"): + regular_cols = [ + c for c in table.c if not isinstance(c.type, Geometry) + ] gis_cols = set(table.c).difference(regular_cols) table.info["_saved_columns"] = table.c @@ -163,85 +182,129 @@ def setup_ddl_events(): # Geometry columns table.columns = expression.ColumnCollection(*regular_cols) - if event == 'before-drop': + if event == "before-drop": for c in gis_cols: bind.execute( - select([ + select( + [ func.DropGeometryColumn( - 'public', table.name, c.name)], - autocommit=True) - ) + "public", table.name, c.name + ) + ], + autocommit=True, + ) + ) - elif event == 'after-create': - table.columns = table.info.pop('_saved_columns') + elif event == "after-create": + table.columns = table.info.pop("_saved_columns") for c in table.c: if isinstance(c.type, Geometry): bind.execute( - select([ - func.AddGeometryColumn( - table.name, c.name, - c.type.srid, - c.type.name, - c.type.dimension)], - autocommit=True) + select( + [ + func.AddGeometryColumn( + table.name, + c.name, + c.type.srid, + c.type.name, + c.type.dimension, + ) + ], + autocommit=True, ) - elif event == 'after-drop': - table.columns = table.info.pop('_saved_columns') -setup_ddl_events() + ) + elif event == "after-drop": + table.columns = table.info.pop("_saved_columns") +setup_ddl_events() + # illustrate usage -if __name__ == '__main__': - from sqlalchemy import (create_engine, MetaData, Column, Integer, String, - func, select) +if __name__ == "__main__": + from sqlalchemy import ( + create_engine, + MetaData, + Column, + Integer, + String, + func, + select, + ) from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base - engine = create_engine('postgresql://scott:tiger@localhost/test', echo=True) + engine = create_engine( + "postgresql://scott:tiger@localhost/test", echo=True + ) metadata = MetaData(engine) Base = declarative_base(metadata=metadata) class Road(Base): - __tablename__ = 'roads' + __tablename__ = "roads" road_id = Column(Integer, primary_key=True) road_name = Column(String) road_geom = Column(Geometry(2)) - metadata.drop_all() metadata.create_all() session = sessionmaker(bind=engine)() # Add objects. We can use strings... - session.add_all([ - Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'), - Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'), - Road(road_name='Paul St', road_geom='LINESTRING(192783 228138,192612 229814)'), - Road(road_name='Graeme Ave', road_geom='LINESTRING(189412 252431,189631 259122)'), - Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'), - ]) + session.add_all( + [ + Road( + road_name="Jeff Rd", + road_geom="LINESTRING(191232 243118,191108 243242)", + ), + Road( + road_name="Geordie Rd", + road_geom="LINESTRING(189141 244158,189265 244817)", + ), + Road( + road_name="Paul St", + road_geom="LINESTRING(192783 228138,192612 229814)", + ), + Road( + road_name="Graeme Ave", + road_geom="LINESTRING(189412 252431,189631 259122)", + ), + Road( + road_name="Phil Tce", + road_geom="LINESTRING(190131 224148,190871 228134)", + ), + ] + ) # or use an explicit TextualGisElement (similar to saying func.GeomFromText()) - r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1)) + r = Road( + road_name="Dave Cres", + road_geom=TextualGisElement( + "LINESTRING(198231 263418,198213 268322)", -1 + ), + ) session.add(r) # pre flush, the TextualGisElement represents the string we sent. - assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)' + assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)" session.commit() # after flush and/or commit, all the TextualGisElements become PersistentGisElements. assert str(r.road_geom) == "LINESTRING(198231 263418,198213 268322)" - r1 = session.query(Road).filter(Road.road_name == 'Graeme Ave').one() + r1 = session.query(Road).filter(Road.road_name == "Graeme Ave").one() # illustrate the overridden __eq__() operator. # strings come in as TextualGisElements - r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one() + r2 = ( + session.query(Road) + .filter(Road.road_geom == "LINESTRING(189412 252431,189631 259122)") + .one() + ) r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one() @@ -250,22 +313,29 @@ if __name__ == '__main__': # core usage just fine: road_table = Road.__table__ - stmt = select([road_table]).where(road_table.c.road_geom.intersects(r1.road_geom)) + stmt = select([road_table]).where( + road_table.c.road_geom.intersects(r1.road_geom) + ) print(session.execute(stmt).fetchall()) # TODO: for some reason the auto-generated labels have the internal replacement # strings exposed, even though PG doesn't complain # look up the hex binary version, using SQLAlchemy casts - as_binary = session.scalar(select([type_coerce(r.road_geom, Geometry(coerce_="binary"))])) - assert as_binary.as_hex == \ - '01020000000200000000000000b832084100000000e813104100000000283208410000000088601041' + as_binary = session.scalar( + select([type_coerce(r.road_geom, Geometry(coerce_="binary"))]) + ) + assert ( + as_binary.as_hex + == "01020000000200000000000000b832084100000000e813104100000000283208410000000088601041" + ) # back again, same method ! - as_text = session.scalar(select([type_coerce(as_binary, Geometry(coerce_="text"))])) + as_text = session.scalar( + select([type_coerce(as_binary, Geometry(coerce_="text"))]) + ) assert as_text.desc == "LINESTRING(198231 263418,198213 268322)" - session.rollback() metadata.drop_all() diff --git a/examples/sharding/attribute_shard.py b/examples/sharding/attribute_shard.py index 0e19b69f3..48a3dc932 100644 --- a/examples/sharding/attribute_shard.py +++ b/examples/sharding/attribute_shard.py @@ -1,5 +1,13 @@ -from sqlalchemy import (create_engine, Table, Column, Integer, - String, ForeignKey, Float, DateTime) +from sqlalchemy import ( + create_engine, + Table, + Column, + Integer, + String, + ForeignKey, + Float, + DateTime, +) from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.sql import operators, visitors @@ -12,22 +20,24 @@ import datetime # causes the id_generator() to use the same connection as that # of an ongoing transaction within db1. echo = True -db1 = create_engine('sqlite://', echo=echo, pool_threadlocal=True) -db2 = create_engine('sqlite://', echo=echo) -db3 = create_engine('sqlite://', echo=echo) -db4 = create_engine('sqlite://', echo=echo) +db1 = create_engine("sqlite://", echo=echo, pool_threadlocal=True) +db2 = create_engine("sqlite://", echo=echo) +db3 = create_engine("sqlite://", echo=echo) +db4 = create_engine("sqlite://", echo=echo) # create session function. this binds the shard ids # to databases within a ShardedSession and returns it. create_session = sessionmaker(class_=ShardedSession) -create_session.configure(shards={ - 'north_america': db1, - 'asia': db2, - 'europe': db3, - 'south_america': db4 -}) +create_session.configure( + shards={ + "north_america": db1, + "asia": db2, + "europe": db3, + "south_america": db4, + } +) # mappings and tables @@ -40,9 +50,7 @@ Base = declarative_base() # #1. Any other method will do just as well; UUID, hilo, application-specific, # etc. -ids = Table( - 'ids', Base.metadata, - Column('nextid', Integer, nullable=False)) +ids = Table("ids", Base.metadata, Column("nextid", Integer, nullable=False)) def id_generator(ctx): @@ -52,6 +60,7 @@ def id_generator(ctx): conn.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) return nextid + # table setup. we'll store a lead table of continents/cities, and a secondary # table storing locations. a particular row will be placed in the database # whose shard id corresponds to the 'continent'. in this setup, secondary rows @@ -67,7 +76,7 @@ class WeatherLocation(Base): continent = Column(String(30), nullable=False) city = Column(String(50), nullable=False) - reports = relationship("Report", backref='location') + reports = relationship("Report", backref="location") def __init__(self, continent, city): self.continent = continent @@ -79,14 +88,17 @@ class Report(Base): id = Column(Integer, primary_key=True) location_id = Column( - 'location_id', Integer, ForeignKey('weather_locations.id')) - temperature = Column('temperature', Float) + "location_id", Integer, ForeignKey("weather_locations.id") + ) + temperature = Column("temperature", Float) report_time = Column( - 'report_time', DateTime, default=datetime.datetime.now) + "report_time", DateTime, default=datetime.datetime.now + ) def __init__(self, temperature): self.temperature = temperature + # create tables for db in (db1, db2, db3, db4): Base.metadata.drop_all(db) @@ -101,10 +113,10 @@ db1.execute(ids.insert(), nextid=1) # we'll use a straight mapping of a particular set of "country" # attributes to shard id. shard_lookup = { - 'North America': 'north_america', - 'Asia': 'asia', - 'Europe': 'europe', - 'South America': 'south_america' + "North America": "north_america", + "Asia": "asia", + "Europe": "europe", + "South America": "south_america", } @@ -139,7 +151,7 @@ def id_chooser(query, ident): # set things up. return [query.lazy_loaded_from.identity_token] else: - return ['north_america', 'asia', 'europe', 'south_america'] + return ["north_america", "asia", "europe", "south_america"] def query_chooser(query): @@ -168,7 +180,7 @@ def query_chooser(query): ids.extend(shard_lookup[v] for v in value) if len(ids) == 0: - return ['north_america', 'asia', 'europe', 'south_america'] + return ["north_america", "asia", "europe", "south_america"] else: return ids @@ -208,13 +220,16 @@ def _get_query_comparisons(query): def visit_binary(binary): # special handling for "col IN (params)" - if binary.left in clauses and \ - binary.operator == operators.in_op and \ - hasattr(binary.right, 'clauses'): + if ( + binary.left in clauses + and binary.operator == operators.in_op + and hasattr(binary.right, "clauses") + ): comparisons.append( ( - binary.left, binary.operator, - tuple(binds[bind] for bind in binary.right.clauses) + binary.left, + binary.operator, + tuple(binds[bind] for bind in binary.right.clauses), ) ) elif binary.left in clauses and binary.right in binds: @@ -232,29 +247,33 @@ def _get_query_comparisons(query): # into a list. if query._criterion is not None: visitors.traverse_depthfirst( - query._criterion, {}, - {'bindparam': visit_bindparam, - 'binary': visit_binary, - 'column': visit_column} + query._criterion, + {}, + { + "bindparam": visit_bindparam, + "binary": visit_binary, + "column": visit_column, + }, ) return comparisons + # further configure create_session to use these functions create_session.configure( shard_chooser=shard_chooser, id_chooser=id_chooser, - query_chooser=query_chooser + query_chooser=query_chooser, ) # save and load objects! -tokyo = WeatherLocation('Asia', 'Tokyo') -newyork = WeatherLocation('North America', 'New York') -toronto = WeatherLocation('North America', 'Toronto') -london = WeatherLocation('Europe', 'London') -dublin = WeatherLocation('Europe', 'Dublin') -brasilia = WeatherLocation('South America', 'Brasila') -quito = WeatherLocation('South America', 'Quito') +tokyo = WeatherLocation("Asia", "Tokyo") +newyork = WeatherLocation("North America", "New York") +toronto = WeatherLocation("North America", "Toronto") +london = WeatherLocation("Europe", "London") +dublin = WeatherLocation("Europe", "Dublin") +brasilia = WeatherLocation("South America", "Brasila") +quito = WeatherLocation("South America", "Quito") tokyo.reports.append(Report(80.0)) newyork.reports.append(Report(75)) @@ -271,12 +290,14 @@ assert t.city == tokyo.city assert t.reports[0].temperature == 80.0 north_american_cities = sess.query(WeatherLocation).filter( - WeatherLocation.continent == 'North America') -assert {c.city for c in north_american_cities} == {'New York', 'Toronto'} + WeatherLocation.continent == "North America" +) +assert {c.city for c in north_american_cities} == {"New York", "Toronto"} asia_and_europe = sess.query(WeatherLocation).filter( - WeatherLocation.continent.in_(['Europe', 'Asia'])) -assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'} + WeatherLocation.continent.in_(["Europe", "Asia"]) +) +assert {c.city for c in asia_and_europe} == {"Tokyo", "London", "Dublin"} # the Report class uses a simple integer primary key. So across two databases, # a primary key will be repeated. The "identity_token" tracks in memory @@ -284,8 +305,8 @@ assert {c.city for c in asia_and_europe} == {'Tokyo', 'London', 'Dublin'} newyork_report = newyork.reports[0] tokyo_report = tokyo.reports[0] -assert inspect(newyork_report).identity_key == (Report, (1, ), "north_america") -assert inspect(tokyo_report).identity_key == (Report, (1, ), "asia") +assert inspect(newyork_report).identity_key == (Report, (1,), "north_america") +assert inspect(tokyo_report).identity_key == (Report, (1,), "asia") # the token representing the originating shard is also available directly diff --git a/examples/space_invaders/__init__.py b/examples/space_invaders/__init__.py index 8816045dc..944f8bb46 100644 --- a/examples/space_invaders/__init__.py +++ b/examples/space_invaders/__init__.py @@ -21,4 +21,4 @@ enjoy! .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/space_invaders/space_invaders.py b/examples/space_invaders/space_invaders.py index 3ce280aec..d5437d8cf 100644 --- a/examples/space_invaders/space_invaders.py +++ b/examples/space_invaders/space_invaders.py @@ -1,7 +1,6 @@ import sys from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import create_engine, Integer, Column, ForeignKey, \ - String, func +from sqlalchemy import create_engine, Integer, Column, ForeignKey, String, func from sqlalchemy.orm import relationship, Session, joinedload from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method import curses @@ -18,7 +17,7 @@ if _PY3: logging.basicConfig( filename="space_invaders.log", - format="%(asctime)s,%(msecs)03d %(levelname)-5.5s %(message)s" + format="%(asctime)s,%(msecs)03d %(levelname)-5.5s %(message)s", ) logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) @@ -47,7 +46,7 @@ COLOR_MAP = { "M": curses.COLOR_MAGENTA, "R": curses.COLOR_RED, "W": curses.COLOR_WHITE, - "Y": curses.COLOR_YELLOW + "Y": curses.COLOR_YELLOW, } @@ -56,7 +55,8 @@ class Glyph(Base): to be painted on the screen. """ - __tablename__ = 'glyph' + + __tablename__ = "glyph" id = Column(Integer, primary_key=True) name = Column(String) type = Column(String) @@ -68,11 +68,9 @@ class Glyph(Base): def __init__(self, name, img, alt=None): self.name = name - self.data, self.width, self.height = \ - self._encode_glyph(img) + self.data, self.width, self.height = self._encode_glyph(img) if alt is not None: - self.alt_data, alt_w, alt_h = \ - self._encode_glyph(alt) + self.alt_data, alt_w, alt_h = self._encode_glyph(alt) def _encode_glyph(self, img): """Receive a textual description of the glyph and @@ -80,7 +78,7 @@ class Glyph(Base): GlyphCoordinate.render(). """ - img = re.sub(r'^\n', "", textwrap.dedent(img)) + img = re.sub(r"^\n", "", textwrap.dedent(img)) color = "W" lines = [line.rstrip() for line in img.split("\n")] data = [] @@ -89,15 +87,15 @@ class Glyph(Base): line = list(line) while line: char = line.pop(0) - if char == '#': + if char == "#": color = line.pop(0) continue render_line.append((color, char)) data.append(render_line) width = max([len(rl) for rl in data]) data = "".join( - "".join("%s%s" % (color, char) for color, char in render_line) + - ("W " * (width - len(render_line))) + "".join("%s%s" % (color, char) for color, char in render_line) + + ("W " * (width - len(render_line))) for render_line in data ) return data, width, len(lines) @@ -121,9 +119,10 @@ class GlyphCoordinate(Base): score value. """ - __tablename__ = 'glyph_coordinate' + + __tablename__ = "glyph_coordinate" id = Column(Integer, primary_key=True) - glyph_id = Column(Integer, ForeignKey('glyph.id')) + glyph_id = Column(Integer, ForeignKey("glyph.id")) x = Column(Integer) y = Column(Integer) tick = Column(Integer) @@ -132,11 +131,9 @@ class GlyphCoordinate(Base): glyph = relationship(Glyph, innerjoin=True) def __init__( - self, - session, glyph_name, x, y, - tick=None, label=None, score=None): - self.glyph = session.query(Glyph).\ - filter_by(name=glyph_name).one() + self, session, glyph_name, x, y, tick=None, label=None, score=None + ): + self.glyph = session.query(Glyph).filter_by(name=glyph_name).one() self.x = x self.y = y self.tick = tick @@ -152,8 +149,7 @@ class GlyphCoordinate(Base): glyph = self.glyph data = glyph.glyph_for_state(self, state) for color, char in [ - (data[i], data[i + 1]) - for i in xrange(0, len(data), 2) + (data[i], data[i + 1]) for i in xrange(0, len(data), 2) ]: x = self.x + col @@ -163,7 +159,8 @@ class GlyphCoordinate(Base): y + VERT_PADDING, x + HORIZ_PADDING, char, - _COLOR_PAIRS[color]) + _COLOR_PAIRS[color], + ) col += 1 if col == glyph.width: col = 0 @@ -186,10 +183,7 @@ class GlyphCoordinate(Base): width = min(glyph.width, MAX_X - x) or 1 for y_a in xrange(self.y, self.y + glyph.height): y = y_a - window.addstr( - y + VERT_PADDING, - x + HORIZ_PADDING, - " " * width) + window.addstr(y + VERT_PADDING, x + HORIZ_PADDING, " " * width) if self.label: self._render_label(window, True) @@ -236,21 +230,22 @@ class GlyphCoordinate(Base): the given GlyphCoordinate.""" return ~( - (self.x + self.width < other.x) | - (self.x > other.x + other.width) + (self.x + self.width < other.x) | (self.x > other.x + other.width) ) & ~( - (self.y + self.height < other.y) | - (self.y > other.y + other.height) + (self.y + self.height < other.y) + | (self.y > other.y + other.height) ) class EnemyGlyph(Glyph): """Describe an enemy.""" + __mapper_args__ = {"polymorphic_identity": "enemy"} class ArmyGlyph(EnemyGlyph): """Describe an enemy that's part of the "army". """ + __mapper_args__ = {"polymorphic_identity": "army"} def glyph_for_state(self, coord, state): @@ -262,6 +257,7 @@ class ArmyGlyph(EnemyGlyph): class SaucerGlyph(EnemyGlyph): """Describe the enemy saucer flying overhead.""" + __mapper_args__ = {"polymorphic_identity": "saucer"} def glyph_for_state(self, coord, state): @@ -273,21 +269,25 @@ class SaucerGlyph(EnemyGlyph): class MessageGlyph(Glyph): """Describe a glyph for displaying a message.""" + __mapper_args__ = {"polymorphic_identity": "message"} class PlayerGlyph(Glyph): """Describe a glyph representing the player.""" + __mapper_args__ = {"polymorphic_identity": "player"} class MissileGlyph(Glyph): """Describe a glyph representing a missile.""" + __mapper_args__ = {"polymorphic_identity": "missile"} class SplatGlyph(Glyph): """Describe a glyph representing a "splat".""" + __mapper_args__ = {"polymorphic_identity": "splat"} def glyph_for_state(self, coord, state): @@ -302,36 +302,39 @@ def init_glyph(session): """Create the glyphs used during play.""" enemy1 = ArmyGlyph( - "enemy1", """ + "enemy1", + """ #W-#B^#R-#B^#W- #G| | """, """ #W>#B^#R-#B^#W< #G^ ^ - """ + """, ) enemy2 = ArmyGlyph( - "enemy2", """ + "enemy2", + """ #W*** #R<#C~~~#R> """, """ #W@@@ #R<#C---#R> - """ + """, ) enemy3 = ArmyGlyph( - "enemy3", """ + "enemy3", + """ #Y((--)) #M-~-~-~ """, """ #Y[[--]] #M~-~-~- - """ + """, ) saucer = SaucerGlyph( @@ -351,35 +354,49 @@ def init_glyph(session): #M| #M- #Y+++ #M- #M| - """ + """, ) - ship = PlayerGlyph("ship", """ + ship = PlayerGlyph( + "ship", + """ #Y^ #G===== - """) + """, + ) - missile = MissileGlyph("missile", """ + missile = MissileGlyph( + "missile", + """ | - """) + """, + ) start = MessageGlyph( "start_message", "J = move left; L = move right; SPACE = fire\n" - " #GPress any key to start") - lose = MessageGlyph("lose_message", - "#YY O U L O S E ! ! !") - win = MessageGlyph( - "win_message", - "#RL E V E L C L E A R E D ! ! !" + " #GPress any key to start", ) + lose = MessageGlyph("lose_message", "#YY O U L O S E ! ! !") + win = MessageGlyph("win_message", "#RL E V E L C L E A R E D ! ! !") paused = MessageGlyph( - "pause_message", - "#WP A U S E D\n#GPress P to continue") + "pause_message", "#WP A U S E D\n#GPress P to continue" + ) session.add_all( - [enemy1, enemy2, enemy3, ship, saucer, - missile, start, lose, win, - paused, splat1]) + [ + enemy1, + enemy2, + enemy3, + ship, + saucer, + missile, + start, + lose, + win, + paused, + splat1, + ] + ) def setup_curses(): @@ -392,7 +409,8 @@ def setup_curses(): WINDOW_HEIGHT + (VERT_PADDING * 2), WINDOW_WIDTH + (HORIZ_PADDING * 2), WINDOW_TOP - VERT_PADDING, - WINDOW_LEFT - HORIZ_PADDING) + WINDOW_LEFT - HORIZ_PADDING, + ) curses.start_color() global _COLOR_PAIRS @@ -416,24 +434,25 @@ def init_positions(session): session.add( GlyphCoordinate( - session, "ship", - WINDOW_WIDTH // 2 - 2, - WINDOW_HEIGHT - 4) + session, "ship", WINDOW_WIDTH // 2 - 2, WINDOW_HEIGHT - 4 + ) ) arrangement = ( - ("enemy3", 50), ("enemy2", 25), - ("enemy1", 10), ("enemy2", 25), - ("enemy1", 10)) + ("enemy3", 50), + ("enemy2", 25), + ("enemy1", 10), + ("enemy2", 25), + ("enemy1", 10), + ) for (ship_vert, (etype, score)) in zip( - xrange(5, 30, ENEMY_VERT_SPACING), arrangement): + xrange(5, 30, ENEMY_VERT_SPACING), arrangement + ): for ship_horiz in xrange(0, 50, 10): session.add( GlyphCoordinate( - session, etype, - ship_horiz, - ship_vert, - score=score) + session, etype, ship_horiz, ship_vert, score=score + ) ) @@ -442,12 +461,9 @@ def draw(session, window, state): database and render. """ - for gcoord in session.query(GlyphCoordinate).\ - options(joinedload("glyph")): + for gcoord in session.query(GlyphCoordinate).options(joinedload("glyph")): gcoord.render(window, state) - window.addstr( - 1, WINDOW_WIDTH - 5, - "Score: %.4d" % state['score']) + window.addstr(1, WINDOW_WIDTH - 5, "Score: %.4d" % state["score"]) window.move(0, 0) window.refresh() @@ -456,11 +472,11 @@ def check_win(session, state): """Return the number of army glyphs remaining - the player wins if this is zero.""" - return session.query( - func.count(GlyphCoordinate.id) - ).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).scalar() + return ( + session.query(func.count(GlyphCoordinate.id)) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .scalar() + ) def check_lose(session, state): @@ -470,12 +486,14 @@ def check_lose(session, state): The player loses if this is non-zero.""" player = state["player"] - return session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).filter( - GlyphCoordinate.intersects(player) | - GlyphCoordinate.bottom_bound - ).count() + return ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .filter( + GlyphCoordinate.intersects(player) | GlyphCoordinate.bottom_bound + ) + .count() + ) def render_message(session, window, msg, x, y): @@ -490,9 +508,11 @@ def render_message(session, window, msg, x, y): msg = GlyphCoordinate(session, msg, x, y) # clear existing glyphs which intersect - for gly in session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph - ).filter(GlyphCoordinate.intersects(msg)): + for gly in ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph) + .filter(GlyphCoordinate.intersects(msg)) + ): gly.blank(window) # render @@ -551,12 +571,14 @@ def move_army(session, window, state): # get the lower/upper boundaries of the army # along the X axis. - min_x, max_x = session.query( - func.min(GlyphCoordinate.x), - func.max(GlyphCoordinate.x + GlyphCoordinate.width), - ).join( - GlyphCoordinate.glyph.of_type(ArmyGlyph) - ).first() + min_x, max_x = ( + session.query( + func.min(GlyphCoordinate.x), + func.max(GlyphCoordinate.x + GlyphCoordinate.width), + ) + .join(GlyphCoordinate.glyph.of_type(ArmyGlyph)) + .first() + ) if min_x is None or max_x is None: # no enemies @@ -603,27 +625,26 @@ def move_player(session, window, state): player.x -= 1 elif ch == FIRE_KEY and state["missile"] is None: state["missile"] = GlyphCoordinate( - session, - "missile", - player.x + 3, - player.y - 1) + session, "missile", player.x + 3, player.y - 1 + ) def move_missile(session, window, state): """Update the status of the current missile, if any.""" - if state["missile"] is None or \ - state["tick"] % 2 != 0: + if state["missile"] is None or state["tick"] % 2 != 0: return missile = state["missile"] # locate enemy glyphs which intersect with the # missile's current position; i.e. a hit - glyph = session.query(GlyphCoordinate).\ - join(GlyphCoordinate.glyph.of_type(EnemyGlyph)).\ - filter(GlyphCoordinate.intersects(missile)).\ - first() + glyph = ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(EnemyGlyph)) + .filter(GlyphCoordinate.intersects(missile)) + .first() + ) missile.blank(window) if glyph or missile.top_bound: # missle is done @@ -642,15 +663,13 @@ def move_saucer(session, window, state): saucer_interval = 500 saucer_speed_interval = 4 - if state["saucer"] is None and \ - state["tick"] % saucer_interval != 0: + if state["saucer"] is None and state["tick"] % saucer_interval != 0: return if state["saucer"] is None: state["saucer"] = saucer = GlyphCoordinate( - session, - "saucer", -6, 1, - score=random.randrange(100, 600, 100)) + session, "saucer", -6, 1, score=random.randrange(100, 600, 100) + ) elif state["tick"] % saucer_speed_interval == 0: saucer = state["saucer"] saucer.blank(window) @@ -663,8 +682,9 @@ def move_saucer(session, window, state): def update_splat(session, window, state): """Render splat animations.""" - for splat in session.query(GlyphCoordinate).\ - join(GlyphCoordinate.glyph.of_type(SplatGlyph)): + for splat in session.query(GlyphCoordinate).join( + GlyphCoordinate.glyph.of_type(SplatGlyph) + ): age = state["tick"] - splat.tick if age > 10: splat.blank(window) @@ -683,8 +703,13 @@ def score(session, window, state, glyph): state["score"] += glyph.score # render a splat ! GlyphCoordinate( - session, "splat1", glyph.x, glyph.y, - tick=state["tick"], label=str(glyph.score)) + session, + "splat1", + glyph.x, + glyph.y, + tick=state["tick"], + label=str(glyph.score), + ) def update_state(session, window, state): @@ -713,19 +738,23 @@ def start(session, window, state, continue_=False): init_positions(session) - player = session.query(GlyphCoordinate).join( - GlyphCoordinate.glyph.of_type(PlayerGlyph) - ).one() - state.update({ - "field_pos": 0, - "alt": False, - "tick": 0, - "missile": None, - "saucer": None, - "player": player, - "army_direction": 0, - "flip": False - }) + player = ( + session.query(GlyphCoordinate) + .join(GlyphCoordinate.glyph.of_type(PlayerGlyph)) + .one() + ) + state.update( + { + "field_pos": 0, + "alt": False, + "tick": 0, + "missile": None, + "saucer": None, + "player": player, + "army_direction": 0, + "flip": False, + } + ) if not continue_: state["score"] = 0 @@ -748,7 +777,8 @@ def main(): while True: update_state(session, window, state) draw(session, window, state) - time.sleep(.01) + time.sleep(0.01) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/versioned_history/__init__.py b/examples/versioned_history/__init__.py index 7478450ac..7670cd613 100644 --- a/examples/versioned_history/__init__.py +++ b/examples/versioned_history/__init__.py @@ -60,4 +60,4 @@ can be applied:: .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/versioned_history/history_meta.py b/examples/versioned_history/history_meta.py index bad60a398..749a6a5ca 100644 --- a/examples/versioned_history/history_meta.py +++ b/examples/versioned_history/history_meta.py @@ -30,7 +30,7 @@ def _history_mapper(local_mapper): getattr(local_mapper.class_, prop.key).impl.active_history = True super_mapper = local_mapper.inherits - super_history_mapper = getattr(cls, '__history_mapper__', None) + super_history_mapper = getattr(cls, "__history_mapper__", None) polymorphic_on = None super_fks = [] @@ -38,18 +38,20 @@ def _history_mapper(local_mapper): def _col_copy(col): orig = col col = col.copy() - orig.info['history_copy'] = col + orig.info["history_copy"] = col col.unique = False col.default = col.server_default = None col.autoincrement = False return col properties = util.OrderedDict() - if not super_mapper or \ - local_mapper.local_table is not super_mapper.local_table: + if ( + not super_mapper + or local_mapper.local_table is not super_mapper.local_table + ): cols = [] version_meta = {"version_meta": True} # add column.info to identify - # columns specific to versioning + # columns specific to versioning for column in local_mapper.local_table.c: if _is_versioning_col(column): @@ -57,12 +59,13 @@ def _history_mapper(local_mapper): col = _col_copy(column) - if super_mapper and \ - col_references_table(column, super_mapper.local_table): + if super_mapper and col_references_table( + column, super_mapper.local_table + ): super_fks.append( ( col.key, - list(super_history_mapper.local_table.primary_key)[0] + list(super_history_mapper.local_table.primary_key)[0], ) ) @@ -73,38 +76,48 @@ def _history_mapper(local_mapper): orig_prop = local_mapper.get_property_by_column(column) # carry over column re-mappings - if len(orig_prop.columns) > 1 or \ - orig_prop.columns[0].key != orig_prop.key: + if ( + len(orig_prop.columns) > 1 + or orig_prop.columns[0].key != orig_prop.key + ): properties[orig_prop.key] = tuple( - col.info['history_copy'] for col in orig_prop.columns) + col.info["history_copy"] for col in orig_prop.columns + ) if super_mapper: super_fks.append( - ( - 'version', super_history_mapper.local_table.c.version - ) + ("version", super_history_mapper.local_table.c.version) ) # "version" stores the integer version id. This column is # required. cols.append( Column( - 'version', Integer, primary_key=True, - autoincrement=False, info=version_meta)) + "version", + Integer, + primary_key=True, + autoincrement=False, + info=version_meta, + ) + ) # "changed" column stores the UTC timestamp of when the # history row was created. # This column is optional and can be omitted. - cols.append(Column( - 'changed', DateTime, - default=datetime.datetime.utcnow, - info=version_meta)) + cols.append( + Column( + "changed", + DateTime, + default=datetime.datetime.utcnow, + info=version_meta, + ) + ) if super_fks: cols.append(ForeignKeyConstraint(*zip(*super_fks))) table = Table( - local_mapper.local_table.name + '_history', + local_mapper.local_table.name + "_history", local_mapper.local_table.metadata, *cols, schema=local_mapper.local_table.schema @@ -122,9 +135,8 @@ def _history_mapper(local_mapper): bases = (super_history_mapper.class_,) if table is not None: - properties['changed'] = ( - (table.c.changed, ) + - tuple(super_history_mapper.attrs.changed.columns) + properties["changed"] = (table.c.changed,) + tuple( + super_history_mapper.attrs.changed.columns ) else: @@ -137,16 +149,17 @@ def _history_mapper(local_mapper): inherits=super_history_mapper, polymorphic_on=polymorphic_on, polymorphic_identity=local_mapper.polymorphic_identity, - properties=properties + properties=properties, ) cls.__history_mapper__ = m if not super_history_mapper: local_mapper.local_table.append_column( - Column('version', Integer, default=1, nullable=False) + Column("version", Integer, default=1, nullable=False) ) local_mapper.add_property( - "version", local_mapper.local_table.c.version) + "version", local_mapper.local_table.c.version + ) class Versioned(object): @@ -156,16 +169,17 @@ class Versioned(object): mp = mapper(cls, *arg, **kw) _history_mapper(mp) return mp + return map - __table_args__ = {'sqlite_autoincrement': True} + __table_args__ = {"sqlite_autoincrement": True} """Use sqlite_autoincrement, to ensure unique integer values are used for new rows even for rows taht have been deleted.""" def versioned_objects(iter): for obj in iter: - if hasattr(obj, '__history_mapper__'): + if hasattr(obj, "__history_mapper__"): yield obj @@ -181,8 +195,7 @@ def create_version(obj, session, deleted=False): obj_changed = False for om, hm in zip( - obj_mapper.iterate_to_root(), - history_mapper.iterate_to_root() + obj_mapper.iterate_to_root(), history_mapper.iterate_to_root() ): if hm.single: continue @@ -228,10 +241,12 @@ def create_version(obj, session, deleted=False): # not changed, but we have relationships. OK # check those too for prop in obj_mapper.iterate_properties: - if isinstance(prop, RelationshipProperty) and \ - attributes.get_history( - obj, prop.key, - passive=attributes.PASSIVE_NO_INITIALIZE).has_changes(): + if ( + isinstance(prop, RelationshipProperty) + and attributes.get_history( + obj, prop.key, passive=attributes.PASSIVE_NO_INITIALIZE + ).has_changes() + ): for p in prop.local_columns: if p.foreign_keys: obj_changed = True @@ -242,7 +257,7 @@ def create_version(obj, session, deleted=False): if not obj_changed and not deleted: return - attr['version'] = obj.version + attr["version"] = obj.version hist = history_cls() for key, value in attr.items(): setattr(hist, key, value) @@ -251,7 +266,7 @@ def create_version(obj, session, deleted=False): def versioned_session(session): - @event.listens_for(session, 'before_flush') + @event.listens_for(session, "before_flush") def before_flush(session, flush_context, instances): for obj in versioned_objects(session.dirty): create_version(obj, session) diff --git a/examples/versioned_history/test_versioning.py b/examples/versioned_history/test_versioning.py index 37ef73936..3270ad5fd 100644 --- a/examples/versioned_history/test_versioning.py +++ b/examples/versioned_history/test_versioning.py @@ -4,10 +4,22 @@ module functions.""" from unittest import TestCase from sqlalchemy.ext.declarative import declarative_base from .history_meta import Versioned, versioned_session -from sqlalchemy import create_engine, Column, Integer, String, \ - ForeignKey, Boolean, select -from sqlalchemy.orm import clear_mappers, Session, deferred, relationship, \ - column_property +from sqlalchemy import ( + create_engine, + Column, + Integer, + String, + ForeignKey, + Boolean, + select, +) +from sqlalchemy.orm import ( + clear_mappers, + Session, + deferred, + relationship, + column_property, +) from sqlalchemy.testing import AssertsCompiledSQL, eq_, assert_raises, ne_ from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.orm import exc as orm_exc @@ -20,11 +32,11 @@ engine = None def setup_module(): global engine - engine = create_engine('sqlite://', echo=True) + engine = create_engine("sqlite://", echo=True) class TestVersioning(TestCase, AssertsCompiledSQL): - __dialect__ = 'default' + __dialect__ = "default" def setUp(self): self.session = Session(engine) @@ -41,18 +53,18 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_plain(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() - sc.name = 'sc1modified' + sc.name = "sc1modified" sess.commit() assert sc.version == 2 @@ -60,56 +72,60 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1')] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1")], ) - sc.name = 'sc1modified2' + sc.name = "sc1modified2" eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified') - ] + SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + ], ) assert sc.version == 3 sess.commit() - sc.name = 'temp' - sc.name = 'sc1modified2' + sc.name = "temp" + sc.name = "sc1modified2" sess.commit() eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified') - ] + SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + ], ) sess.delete(sc) sess.commit() eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1'), - SomeClassHistory(version=2, name='sc1modified'), - SomeClassHistory(version=3, name='sc1modified2') - ] + SomeClassHistory(version=1, name="sc1"), + SomeClassHistory(version=2, name="sc1modified"), + SomeClassHistory(version=3, name="sc1modified2"), + ], ) def test_w_mapper_versioning(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -118,27 +134,24 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() s2 = Session(sess.bind) sc2 = s2.query(SomeClass).first() - sc2.name = 'sc1modified' + sc2.name = "sc1modified" - sc.name = 'sc1modified_again' + sc.name = "sc1modified_again" sess.commit() eq_(sc.version, 2) - assert_raises( - orm_exc.StaleDataError, - s2.flush - ) + assert_raises(orm_exc.StaleDataError, s2.flush) def test_from_null(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -149,14 +162,14 @@ class TestVersioning(TestCase, AssertsCompiledSQL): sess.add(sc) sess.commit() - sc.name = 'sc1' + sc.name = "sc1" sess.commit() assert sc.version == 2 def test_insert_null(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) boole = Column(Boolean, default=False) @@ -176,9 +189,10 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory.boole).order_by( - SomeClassHistory.id).all(), - [(True, ), (None, )] + sess.query(SomeClassHistory.boole) + .order_by(SomeClassHistory.id) + .all(), + [(True,), (None,)], ) eq_(sc.version, 3) @@ -187,7 +201,7 @@ class TestVersioning(TestCase, AssertsCompiledSQL): """test versioning of unloaded, deferred columns.""" class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) @@ -195,15 +209,15 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1', data='somedata') + sc = SomeClass(name="sc1", data="somedata") sess.add(sc) sess.commit() sess.close() sc = sess.query(SomeClass).first() - assert 'data' not in sc.__dict__ + assert "data" not in sc.__dict__ - sc.name = 'sc1modified' + sc.name = "sc1modified" sess.commit() assert sc.version == 2 @@ -211,137 +225,149 @@ class TestVersioning(TestCase, AssertsCompiledSQL): SomeClassHistory = SomeClass.__history_mapper__.class_ eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1', data='somedata')] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1", data="somedata")], ) def test_joined_inheritance(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClassSeparatePk(BaseClass): - __tablename__ = 'subtable1' + __tablename__ = "subtable1" id = column_property( - Column(Integer, primary_key=True), - BaseClass.id + Column(Integer, primary_key=True), BaseClass.id ) - base_id = Column(Integer, ForeignKey('basetable.id')) + base_id = Column(Integer, ForeignKey("basetable.id")) subdata1 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'sep'} + __mapper_args__ = {"polymorphic_identity": "sep"} class SubClassSamePk(BaseClass): - __tablename__ = 'subtable2' + __tablename__ = "subtable2" - id = Column( - Integer, ForeignKey('basetable.id'), primary_key=True) + id = Column(Integer, ForeignKey("basetable.id"), primary_key=True) subdata2 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'same'} + __mapper_args__ = {"polymorphic_identity": "same"} self.create_tables() sess = self.session - sep1 = SubClassSeparatePk(name='sep1', subdata1='sep1subdata') - base1 = BaseClass(name='base1') - same1 = SubClassSamePk(name='same1', subdata2='same1subdata') + sep1 = SubClassSeparatePk(name="sep1", subdata1="sep1subdata") + base1 = BaseClass(name="base1") + same1 = SubClassSamePk(name="same1", subdata2="same1subdata") sess.add_all([sep1, base1, same1]) sess.commit() - base1.name = 'base1mod' - same1.subdata2 = 'same1subdatamod' - sep1.name = 'sep1mod' + base1.name = "base1mod" + same1.subdata2 = "same1subdatamod" + sep1.name = "sep1mod" sess.commit() BaseClassHistory = BaseClass.__history_mapper__.class_ - SubClassSeparatePkHistory = \ + SubClassSeparatePkHistory = ( SubClassSeparatePk.__history_mapper__.class_ + ) SubClassSamePkHistory = SubClassSamePk.__history_mapper__.class_ eq_( sess.query(BaseClassHistory).order_by(BaseClassHistory.id).all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1) - ] + id=3, name="same1", type="same", version=1 + ), + ], ) - same1.subdata2 = 'same1subdatamod2' + same1.subdata2 = "same1subdatamod2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1), + id=3, name="same1", type="same", version=1 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=2) - ] + id=3, name="same1", type="same", version=2 + ), + ], ) - base1.name = 'base1mod2' + base1.name = "base1mod2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ SubClassSeparatePkHistory( - id=1, name='sep1', type='sep', version=1), - BaseClassHistory(id=2, name='base1', type='base', version=1), + id=1, name="sep1", type="sep", version=1 + ), + BaseClassHistory(id=2, name="base1", type="base", version=1), BaseClassHistory( - id=2, name='base1mod', type='base', version=2), + id=2, name="base1mod", type="base", version=2 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=1), + id=3, name="same1", type="same", version=1 + ), SubClassSamePkHistory( - id=3, name='same1', type='same', version=2) - ] + id=3, name="same1", type="same", version=2 + ), + ], ) def test_joined_inheritance_multilevel(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClass(BaseClass): - __tablename__ = 'subtable' + __tablename__ = "subtable" id = column_property( - Column(Integer, primary_key=True), - BaseClass.id + Column(Integer, primary_key=True), BaseClass.id ) - base_id = Column(Integer, ForeignKey('basetable.id')) + base_id = Column(Integer, ForeignKey("basetable.id")) subdata1 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'sub'} + __mapper_args__ = {"polymorphic_identity": "sub"} class SubSubClass(SubClass): - __tablename__ = 'subsubtable' + __tablename__ = "subsubtable" - id = Column(Integer, ForeignKey('subtable.id'), primary_key=True) + id = Column(Integer, ForeignKey("subtable.id"), primary_key=True) subdata2 = Column(String(50)) - __mapper_args__ = {'polymorphic_identity': 'subsub'} + __mapper_args__ = {"polymorphic_identity": "subsub"} self.create_tables() @@ -350,27 +376,18 @@ class TestVersioning(TestCase, AssertsCompiledSQL): q = sess.query(SubSubHistory) self.assert_compile( q, - - "SELECT " - "subsubtable_history.id AS subsubtable_history_id, " "subtable_history.id AS subtable_history_id, " "basetable_history.id AS basetable_history_id, " - "subsubtable_history.changed AS subsubtable_history_changed, " "subtable_history.changed AS subtable_history_changed, " "basetable_history.changed AS basetable_history_changed, " - "basetable_history.name AS basetable_history_name, " - "basetable_history.type AS basetable_history_type, " - "subsubtable_history.version AS subsubtable_history_version, " "subtable_history.version AS subtable_history_version, " "basetable_history.version AS basetable_history_version, " - - "subtable_history.base_id AS subtable_history_base_id, " "subtable_history.subdata1 AS subtable_history_subdata1, " "subsubtable_history.subdata2 AS subsubtable_history_subdata2 " @@ -380,64 +397,73 @@ class TestVersioning(TestCase, AssertsCompiledSQL): "AND basetable_history.version = subtable_history.version " "JOIN subsubtable_history ON subtable_history.id = " "subsubtable_history.id AND subtable_history.version = " - "subsubtable_history.version" + "subsubtable_history.version", ) - ssc = SubSubClass(name='ss1', subdata1='sd1', subdata2='sd2') + ssc = SubSubClass(name="ss1", subdata1="sd1", subdata2="sd2") sess.add(ssc) sess.commit() + eq_(sess.query(SubSubHistory).all(), []) + ssc.subdata1 = "sd11" + ssc.subdata2 = "sd22" + sess.commit() eq_( sess.query(SubSubHistory).all(), - [] + [ + SubSubHistory( + name="ss1", + subdata1="sd1", + subdata2="sd2", + type="subsub", + version=1, + ) + ], ) - ssc.subdata1 = 'sd11' - ssc.subdata2 = 'sd22' - sess.commit() eq_( - sess.query(SubSubHistory).all(), - [SubSubHistory(name='ss1', subdata1='sd1', - subdata2='sd2', type='subsub', version=1)] + ssc, + SubSubClass( + name="ss1", subdata1="sd11", subdata2="sd22", version=2 + ), ) - eq_(ssc, SubSubClass( - name='ss1', subdata1='sd11', - subdata2='sd22', version=2)) def test_joined_inheritance_changed(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(20)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base' + "polymorphic_on": type, + "polymorphic_identity": "base", } class SubClass(BaseClass): - __tablename__ = 'subtable' + __tablename__ = "subtable" - id = Column(Integer, ForeignKey('basetable.id'), primary_key=True) + id = Column(Integer, ForeignKey("basetable.id"), primary_key=True) - __mapper_args__ = {'polymorphic_identity': 'sep'} + __mapper_args__ = {"polymorphic_identity": "sep"} self.create_tables() BaseClassHistory = BaseClass.__history_mapper__.class_ SubClassHistory = SubClass.__history_mapper__.class_ sess = self.session - s1 = SubClass(name='s1') + s1 = SubClass(name="s1") sess.add(s1) sess.commit() - s1.name = 's2' + s1.name = "s2" sess.commit() actual_changed_base = sess.scalar( - select([BaseClass.__history_mapper__.local_table.c.changed])) + select([BaseClass.__history_mapper__.local_table.c.changed]) + ) actual_changed_sub = sess.scalar( - select([SubClass.__history_mapper__.local_table.c.changed])) + select([SubClass.__history_mapper__.local_table.c.changed]) + ) h1 = sess.query(BaseClassHistory).first() eq_(h1.changed, actual_changed_base) eq_(h1.changed, actual_changed_sub) @@ -448,53 +474,57 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_single_inheritance(self): class BaseClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'basetable' + __tablename__ = "basetable" id = Column(Integer, primary_key=True) name = Column(String(50)) type = Column(String(50)) __mapper_args__ = { - 'polymorphic_on': type, - 'polymorphic_identity': 'base'} + "polymorphic_on": type, + "polymorphic_identity": "base", + } class SubClass(BaseClass): subname = Column(String(50), unique=True) - __mapper_args__ = {'polymorphic_identity': 'sub'} + __mapper_args__ = {"polymorphic_identity": "sub"} self.create_tables() sess = self.session - b1 = BaseClass(name='b1') - sc = SubClass(name='s1', subname='sc1') + b1 = BaseClass(name="b1") + sc = SubClass(name="s1", subname="sc1") sess.add_all([b1, sc]) sess.commit() - b1.name = 'b1modified' + b1.name = "b1modified" BaseClassHistory = BaseClass.__history_mapper__.class_ SubClassHistory = SubClass.__history_mapper__.class_ eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), - [BaseClassHistory(id=1, name='b1', type='base', version=1)] + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), + [BaseClassHistory(id=1, name="b1", type="base", version=1)], ) - sc.name = 's1modified' - b1.name = 'b1modified2' + sc.name = "s1modified" + b1.name = "b1modified2" eq_( - sess.query(BaseClassHistory).order_by( - BaseClassHistory.id, BaseClassHistory.version).all(), + sess.query(BaseClassHistory) + .order_by(BaseClassHistory.id, BaseClassHistory.version) + .all(), [ - BaseClassHistory(id=1, name='b1', type='base', version=1), + BaseClassHistory(id=1, name="b1", type="base", version=1), BaseClassHistory( - id=1, name='b1modified', type='base', version=2), - SubClassHistory(id=2, name='s1', type='sub', version=1) - ] + id=1, name="b1modified", type="base", version=2 + ), + SubClassHistory(id=2, name="s1", type="sub", version=1), + ], ) # test the unique constraint on the subclass @@ -504,7 +534,7 @@ class TestVersioning(TestCase, AssertsCompiledSQL): def test_unique(self): class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50), unique=True) @@ -512,40 +542,39 @@ class TestVersioning(TestCase, AssertsCompiledSQL): self.create_tables() sess = self.session - sc = SomeClass(name='sc1', data='sc1') + sc = SomeClass(name="sc1", data="sc1") sess.add(sc) sess.commit() - sc.data = 'sc1modified' + sc.data = "sc1modified" sess.commit() assert sc.version == 2 - sc.data = 'sc1modified2' + sc.data = "sc1modified2" sess.commit() assert sc.version == 3 def test_relationship(self): - class SomeRelated(self.Base, ComparableEntity): - __tablename__ = 'somerelated' + __tablename__ = "somerelated" id = Column(Integer, primary_key=True) class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) - related_id = Column(Integer, ForeignKey('somerelated.id')) - related = relationship("SomeRelated", backref='classes') + related_id = Column(Integer, ForeignKey("somerelated.id")) + related = relationship("SomeRelated", backref="classes") SomeClassHistory = SomeClass.__history_mapper__.class_ self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() @@ -558,36 +587,37 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 2 eq_( - sess.query(SomeClassHistory).filter( - SomeClassHistory.version == 1).all(), - [SomeClassHistory(version=1, name='sc1', related_id=None)] + sess.query(SomeClassHistory) + .filter(SomeClassHistory.version == 1) + .all(), + [SomeClassHistory(version=1, name="sc1", related_id=None)], ) sc.related = None eq_( - sess.query(SomeClassHistory).order_by( - SomeClassHistory.version).all(), + sess.query(SomeClassHistory) + .order_by(SomeClassHistory.version) + .all(), [ - SomeClassHistory(version=1, name='sc1', related_id=None), - SomeClassHistory(version=2, name='sc1', related_id=sr1.id) - ] + SomeClassHistory(version=1, name="sc1", related_id=None), + SomeClassHistory(version=2, name="sc1", related_id=sr1.id), + ], ) assert sc.version == 3 def test_backref_relationship(self): - class SomeRelated(self.Base, ComparableEntity): - __tablename__ = 'somerelated' + __tablename__ = "somerelated" id = Column(Integer, primary_key=True) name = Column(String(50)) - related_id = Column(Integer, ForeignKey('sometable.id')) - related = relationship("SomeClass", backref='related') + related_id = Column(Integer, ForeignKey("sometable.id")) + related = relationship("SomeClass", backref="related") class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) @@ -599,13 +629,13 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 1 - sr = SomeRelated(name='sr', related=sc) + sr = SomeRelated(name="sr", related=sc) sess.add(sr) sess.commit() assert sc.version == 1 - sr.name = 'sr2' + sr.name = "sr2" sess.commit() assert sc.version == 1 @@ -616,9 +646,8 @@ class TestVersioning(TestCase, AssertsCompiledSQL): assert sc.version == 1 def test_create_double_flush(self): - class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(30)) @@ -629,56 +658,56 @@ class TestVersioning(TestCase, AssertsCompiledSQL): sc = SomeClass() self.session.add(sc) self.session.flush() - sc.name = 'Foo' + sc.name = "Foo" self.session.flush() assert sc.version == 2 def test_mutate_plain_column(self): class Document(self.Base, Versioned): - __tablename__ = 'document' + __tablename__ = "document" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, nullable=True) - description_ = Column('description', String, nullable=True) + description_ = Column("description", String, nullable=True) self.create_tables() document = Document() self.session.add(document) - document.name = 'Foo' + document.name = "Foo" self.session.commit() - document.name = 'Bar' + document.name = "Bar" self.session.commit() DocumentHistory = Document.__history_mapper__.class_ v2 = self.session.query(Document).one() v1 = self.session.query(DocumentHistory).one() self.assertEqual(v1.id, v2.id) - self.assertEqual(v2.name, 'Bar') - self.assertEqual(v1.name, 'Foo') + self.assertEqual(v2.name, "Bar") + self.assertEqual(v1.name, "Foo") def test_mutate_named_column(self): class Document(self.Base, Versioned): - __tablename__ = 'document' + __tablename__ = "document" id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String, nullable=True) - description_ = Column('description', String, nullable=True) + description_ = Column("description", String, nullable=True) self.create_tables() document = Document() self.session.add(document) - document.description_ = 'Foo' + document.description_ = "Foo" self.session.commit() - document.description_ = 'Bar' + document.description_ = "Bar" self.session.commit() DocumentHistory = Document.__history_mapper__.class_ v2 = self.session.query(Document).one() v1 = self.session.query(DocumentHistory).one() self.assertEqual(v1.id, v2.id) - self.assertEqual(v2.description_, 'Bar') - self.assertEqual(v1.description_, 'Foo') + self.assertEqual(v2.description_, "Bar") + self.assertEqual(v1.description_, "Foo") def test_unique_identifiers_across_deletes(self): """Ensure unique integer values are used for the primary table. @@ -690,21 +719,21 @@ class TestVersioning(TestCase, AssertsCompiledSQL): """ class SomeClass(Versioned, self.Base, ComparableEntity): - __tablename__ = 'sometable' + __tablename__ = "sometable" id = Column(Integer, primary_key=True) name = Column(String(50)) self.create_tables() sess = self.session - sc = SomeClass(name='sc1') + sc = SomeClass(name="sc1") sess.add(sc) sess.commit() sess.delete(sc) sess.commit() - sc2 = SomeClass(name='sc2') + sc2 = SomeClass(name="sc2") sess.add(sc2) sess.commit() @@ -721,5 +750,5 @@ class TestVersioning(TestCase, AssertsCompiledSQL): ne_(sc2.id, scdeleted.id) # If previous assertion fails, this will also fail: - sc2.name = 'sc2 modified' + sc2.name = "sc2 modified" sess.commit() diff --git a/examples/versioned_rows/__init__.py b/examples/versioned_rows/__init__.py index 637e1aca6..e5016a740 100644 --- a/examples/versioned_rows/__init__.py +++ b/examples/versioned_rows/__init__.py @@ -9,4 +9,4 @@ history row to a separate history table. .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/versioned_rows/versioned_map.py b/examples/versioned_rows/versioned_map.py index 6a5c86a3a..46bdbb783 100644 --- a/examples/versioned_rows/versioned_map.py +++ b/examples/versioned_rows/versioned_map.py @@ -28,11 +28,17 @@ those additional values. """ -from sqlalchemy import Column, String, Integer, ForeignKey, \ - create_engine +from sqlalchemy import Column, String, Integer, ForeignKey, create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import attributes, relationship, backref, \ - sessionmaker, make_transient, validates, Session +from sqlalchemy.orm import ( + attributes, + relationship, + backref, + sessionmaker, + make_transient, + validates, + Session, +) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy import event @@ -45,8 +51,9 @@ def before_flush(session, flush_context, instances): """ for instance in session.dirty: - if hasattr(instance, 'new_version') and \ - session.is_modified(instance, passive=True): + if hasattr(instance, "new_version") and session.is_modified( + instance, passive=True + ): # make it transient instance.new_version(session) @@ -54,6 +61,7 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() @@ -67,7 +75,8 @@ class ConfigData(Base): string name mapped to a string/int value. """ - __tablename__ = 'config' + + __tablename__ = "config" id = Column(Integer, primary_key=True) """Primary key column of this ConfigData.""" @@ -76,7 +85,7 @@ class ConfigData(Base): "ConfigValueAssociation", collection_class=attribute_mapped_collection("name"), backref=backref("config_data"), - lazy="subquery" + lazy="subquery", ) """Dictionary-backed collection of ConfigValueAssociation objects, keyed to the name of the associated ConfigValue. @@ -97,7 +106,7 @@ class ConfigData(Base): def __init__(self, data): self.data = data - @validates('elements') + @validates("elements") def _associate_with_element(self, key, element): """Associate incoming ConfigValues with this ConfigData, if not already associated. @@ -117,11 +126,11 @@ class ConfigData(Base): # history of the 'elements' collection. # this is a tuple of groups: (added, unchanged, deleted) - hist = attributes.get_history(self, 'elements') + hist = attributes.get_history(self, "elements") # rewrite the 'elements' collection # from scratch, removing all history - attributes.set_committed_value(self, 'elements', {}) + attributes.set_committed_value(self, "elements", {}) # new elements in the "added" group # are moved to our new collection. @@ -133,7 +142,8 @@ class ConfigData(Base): # the old ones stay associated with the old ConfigData for elem in hist.unchanged: self.elements[elem.name] = ConfigValueAssociation( - elem.config_value) + elem.config_value + ) # we also need to expire changes on each ConfigValueAssociation # that is to remain associated with the old ConfigData. @@ -144,12 +154,12 @@ class ConfigData(Base): class ConfigValueAssociation(Base): """Relate ConfigData objects to associated ConfigValue objects.""" - __tablename__ = 'config_value_association' + __tablename__ = "config_value_association" - config_id = Column(ForeignKey('config.id'), primary_key=True) + config_id = Column(ForeignKey("config.id"), primary_key=True) """Reference the primary key of the ConfigData object.""" - config_value_id = Column(ForeignKey('config_value.id'), primary_key=True) + config_value_id = Column(ForeignKey("config_value.id"), primary_key=True) """Reference the primary key of the ConfigValue object.""" config_value = relationship("ConfigValue", lazy="joined", innerjoin=True) @@ -182,8 +192,7 @@ class ConfigValueAssociation(Base): """ if value != self.config_value.value: - self.config_data.elements[self.name] = \ - ConfigValueAssociation( + self.config_data.elements[self.name] = ConfigValueAssociation( ConfigValue(self.config_value.name, value) ) @@ -194,13 +203,14 @@ class ConfigValue(Base): ConfigValue is immutable. """ - __tablename__ = 'config_value' + + __tablename__ = "config_value" id = Column(Integer, primary_key=True) name = Column(String(50), nullable=False) originating_config_id = Column( - Integer, ForeignKey('config.id'), - nullable=False) + Integer, ForeignKey("config.id"), nullable=False + ) int_value = Column(Integer) string_value = Column(String(255)) @@ -221,7 +231,7 @@ class ConfigValue(Base): @property def value(self): - for k in ('int_value', 'string_value'): + for k in ("int_value", "string_value"): v = getattr(self, k) if v is not None: return v @@ -237,25 +247,23 @@ class ConfigValue(Base): self.string_value = str(value) self.int_value = None -if __name__ == '__main__': - engine = create_engine('sqlite://', echo=True) + +if __name__ == "__main__": + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) Session = sessionmaker(engine) sess = Session() - config = ConfigData({ - 'user_name': 'twitter', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 - }) + config = ConfigData( + {"user_name": "twitter", "hash_id": "4fedffca37eaf", "x": 27, "y": 450} + ) sess.add(config) sess.commit() version_one = config.id - config.data['user_name'] = 'yahoo' + config.data["user_name"] = "yahoo" sess.commit() version_two = config.id @@ -265,27 +273,29 @@ if __name__ == '__main__': # two versions have been created. assert config.data == { - 'user_name': 'yahoo', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 + "user_name": "yahoo", + "hash_id": "4fedffca37eaf", + "x": 27, + "y": 450, } old_config = sess.query(ConfigData).get(version_one) assert old_config.data == { - 'user_name': 'twitter', - 'hash_id': '4fedffca37eaf', - 'x': 27, - 'y': 450 + "user_name": "twitter", + "hash_id": "4fedffca37eaf", + "x": 27, + "y": 450, } # the history of any key can be acquired using # the originating_config_id attribute - history = sess.query(ConfigValue).\ - filter(ConfigValue.name == 'user_name').\ - order_by(ConfigValue.originating_config_id).\ - all() + history = ( + sess.query(ConfigValue) + .filter(ConfigValue.name == "user_name") + .order_by(ConfigValue.originating_config_id) + .all() + ) assert [(h.value, h.originating_config_id) for h in history] == ( - [('twitter', version_one), ('yahoo', version_two)] + [("twitter", version_one), ("yahoo", version_two)] ) diff --git a/examples/versioned_rows/versioned_rows.py b/examples/versioned_rows/versioned_rows.py index ca896190d..03e1c3510 100644 --- a/examples/versioned_rows/versioned_rows.py +++ b/examples/versioned_rows/versioned_rows.py @@ -3,8 +3,13 @@ an UPDATE statement on a single row into an INSERT statement, so that a new row is inserted with the new data, keeping the old row intact. """ -from sqlalchemy.orm import sessionmaker, relationship, make_transient, \ - backref, Session +from sqlalchemy.orm import ( + sessionmaker, + relationship, + make_transient, + backref, + Session, +) from sqlalchemy import Column, ForeignKey, create_engine, Integer, String from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import attributes @@ -38,9 +43,10 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Session = sessionmaker(engine) @@ -48,43 +54,44 @@ Session = sessionmaker(engine) class Example(Versioned, Base): - __tablename__ = 'example' + __tablename__ = "example" id = Column(Integer, primary_key=True) data = Column(String) + Base.metadata.create_all(engine) session = Session() -e1 = Example(data='e1') +e1 = Example(data="e1") session.add(e1) session.commit() -e1.data = 'e2' +e1.data = "e2" session.commit() assert session.query(Example.id, Example.data).order_by(Example.id).all() == ( - [(1, 'e1'), (2, 'e2')] + [(1, "e1"), (2, "e2")] ) # example 2, versioning with a parent class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) - child_id = Column(Integer, ForeignKey('child.id')) - child = relationship("Child", backref=backref('parent', uselist=False)) + child_id = Column(Integer, ForeignKey("child.id")) + child = relationship("Child", backref=backref("parent", uselist=False)) class Child(Versioned, Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) data = Column(String) def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version Versioned.new_version(self, session) @@ -92,18 +99,19 @@ class Child(Versioned, Base): # re-add ourselves to the parent self.parent.child = self + Base.metadata.create_all(engine) session = Session() -p1 = Parent(child=Child(data='c1')) +p1 = Parent(child=Child(data="c1")) session.add(p1) session.commit() -p1.child.data = 'c2' +p1.child.data = "c2" session.commit() assert p1.child_id == 2 assert session.query(Child.id, Child.data).order_by(Child.id).all() == ( - [(1, 'c1'), (2, 'c2')] + [(1, "c1"), (2, "c2")] ) diff --git a/examples/versioned_rows/versioned_rows_w_versionid.py b/examples/versioned_rows/versioned_rows_w_versionid.py index 8445401c5..5fd6f9fc4 100644 --- a/examples/versioned_rows/versioned_rows_w_versionid.py +++ b/examples/versioned_rows/versioned_rows_w_versionid.py @@ -6,10 +6,24 @@ This example adds a numerical version_id to the Versioned class as well as the ability to see which row is the most "current" vesion. """ -from sqlalchemy.orm import sessionmaker, relationship, make_transient, \ - backref, Session, column_property -from sqlalchemy import Column, ForeignKeyConstraint, create_engine, \ - Integer, String, Boolean, select, func +from sqlalchemy.orm import ( + sessionmaker, + relationship, + make_transient, + backref, + Session, + column_property, +) +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + create_engine, + Integer, + String, + Boolean, + select, + func, +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import attributes from sqlalchemy import event @@ -38,7 +52,8 @@ class Versioned(object): # optional - set previous version to have is_current_version=False old_id = self.id session.query(self.__class__).filter_by(id=old_id).update( - values=dict(is_current_version=False), synchronize_session=False) + values=dict(is_current_version=False), synchronize_session=False + ) # make us transient (removes persistent # identity). @@ -65,9 +80,10 @@ def before_flush(session, flush_context, instances): # re-add session.add(instance) + Base = declarative_base() -engine = create_engine('sqlite://', echo=True) +engine = create_engine("sqlite://", echo=True) Session = sessionmaker(engine) @@ -75,17 +91,18 @@ Session = sessionmaker(engine) class Example(Versioned, Base): - __tablename__ = 'example' + __tablename__ = "example" data = Column(String) + Base.metadata.create_all(engine) session = Session() -e1 = Example(id=1, data='e1') +e1 = Example(id=1, data="e1") session.add(e1) session.commit() -e1.data = 'e2' +e1.data = "e2" session.commit() assert session.query( @@ -93,36 +110,36 @@ assert session.query( Example.version_id, Example.is_current_version, Example.calc_is_current_version, - Example.data).order_by(Example.id, Example.version_id).all() == ( - [(1, 1, False, False, 'e1'), (1, 2, True, True, 'e2')] + Example.data, +).order_by(Example.id, Example.version_id).all() == ( + [(1, 1, False, False, "e1"), (1, 2, True, True, "e2")] ) # example 2, versioning with a parent class Parent(Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) child_id = Column(Integer) child_version_id = Column(Integer) - child = relationship("Child", backref=backref('parent', uselist=False)) + child = relationship("Child", backref=backref("parent", uselist=False)) __table_args__ = ( ForeignKeyConstraint( - ['child_id', 'child_version_id'], - ['child.id', 'child.version_id'], + ["child_id", "child_version_id"], ["child.id", "child.version_id"] ), ) class Child(Versioned, Base): - __tablename__ = 'child' + __tablename__ = "child" data = Column(String) def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version Versioned.new_version(self, session) @@ -131,15 +148,16 @@ class Child(Versioned, Base): # parent foreign key to be updated also self.parent.child = self + Base.metadata.create_all(engine) session = Session() -p1 = Parent(child=Child(id=1, data='c1')) +p1 = Parent(child=Child(id=1, data="c1")) session.add(p1) session.commit() -p1.child.data = 'c2' +p1.child.data = "c2" session.commit() assert p1.child_id == 1 @@ -150,6 +168,7 @@ assert session.query( Child.version_id, Child.is_current_version, Child.calc_is_current_version, - Child.data).order_by(Child.id, Child.version_id).all() == ( - [(1, 1, False, False, 'c1'), (1, 2, True, True, 'c2')] + Child.data, +).order_by(Child.id, Child.version_id).all() == ( + [(1, 1, False, False, "c1"), (1, 2, True, True, "c2")] ) diff --git a/examples/versioned_rows/versioned_update_old_row.py b/examples/versioned_rows/versioned_update_old_row.py index 0159d2567..17c82fdc3 100644 --- a/examples/versioned_rows/versioned_update_old_row.py +++ b/examples/versioned_rows/versioned_update_old_row.py @@ -6,12 +6,23 @@ to only the most recent version. """ from sqlalchemy import ( - create_engine, Integer, String, event, Column, DateTime, - inspect, literal + create_engine, + Integer, + String, + event, + Column, + DateTime, + inspect, + literal, ) from sqlalchemy.orm import ( - make_transient, Session, relationship, attributes, backref, - make_transient_to_detached, Query + make_transient, + Session, + relationship, + attributes, + backref, + make_transient_to_detached, + Query, ) from sqlalchemy.ext.declarative import declarative_base import datetime @@ -50,7 +61,8 @@ class VersionedStartEnd(object): # make the "old" version of us, which we will turn into an # UPDATE old_copy_of_us = self.__class__( - id=self.id, start=self.start, end=self.end) + id=self.id, start=self.start, end=self.end + ) # turn old_copy_of_us into an UPDATE make_transient_to_detached(old_copy_of_us) @@ -95,11 +107,11 @@ def before_compile(query): """ensure all queries for VersionedStartEnd include criteria """ for ent in query.column_descriptions: - entity = ent['entity'] + entity = ent["entity"] if entity is None: continue - insp = inspect(ent['entity']) - mapper = getattr(insp, 'mapper', None) + insp = inspect(ent["entity"]) + mapper = getattr(insp, "mapper", None) if mapper and issubclass(mapper.class_, VersionedStartEnd): query = query.enable_assertions(False).filter( # using a literal "now" because SQLite's "between" @@ -107,14 +119,14 @@ def before_compile(query): # ``func.now()`` and we'd be using PostgreSQL literal( current_time() + datetime.timedelta(seconds=1) - ).between(ent['entity'].start, ent['entity'].end) + ).between(ent["entity"].start, ent["entity"].end) ) return query class Parent(VersionedStartEnd, Base): - __tablename__ = 'parent' + __tablename__ = "parent" id = Column(Integer, primary_key=True) start = Column(DateTime, primary_key=True) end = Column(DateTime, primary_key=True) @@ -124,10 +136,7 @@ class Parent(VersionedStartEnd, Base): child = relationship( "Child", - primaryjoin=( - "Child.id == foreign(Parent.child_n)" - ), - + primaryjoin=("Child.id == foreign(Parent.child_n)"), # note the primaryjoin can also be: # # "and_(Child.id == foreign(Parent.child_n), " @@ -138,14 +147,13 @@ class Parent(VersionedStartEnd, Base): # as well, it just means the criteria will be present twice for most # parent->child load operations # - uselist=False, - backref=backref('parent', uselist=False) + backref=backref("parent", uselist=False), ) class Child(VersionedStartEnd, Base): - __tablename__ = 'child' + __tablename__ = "child" id = Column(Integer, primary_key=True) start = Column(DateTime, primary_key=True) @@ -155,7 +163,7 @@ class Child(VersionedStartEnd, Base): def new_version(self, session): # expire parent's reference to us - session.expire(self.parent, ['child']) + session.expire(self.parent, ["child"]) # create new version VersionedStartEnd.new_version(self, session) @@ -163,6 +171,7 @@ class Child(VersionedStartEnd, Base): # re-add ourselves to the parent self.parent.child = self + times = [] @@ -185,27 +194,37 @@ def time_passes(s): assert times[-1] > times[-2] return times[-1] -e = create_engine("sqlite://", echo='debug') + +e = create_engine("sqlite://", echo="debug") Base.metadata.create_all(e) s = Session(e) now = time_passes(s) -c1 = Child(id=1, data='child 1') -p1 = Parent(id=1, data='c1', child=c1) +c1 = Child(id=1, data="child 1") +p1 = Parent(id=1, data="c1", child=c1) s.add(p1) s.commit() # assert raw DB data assert s.query(Parent.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'c1', 1) + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "c1", + 1, + ) ] assert s.query(Child.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'child 1') + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "child 1", + ) ] now = time_passes(s) @@ -214,7 +233,7 @@ p1_check = s.query(Parent).first() assert p1_check is p1 assert p1_check.child is c1 -p1.child.data = 'elvis presley' +p1.child.data = "elvis presley" s.commit() @@ -226,40 +245,51 @@ c2_check = p2_check.child assert p2_check.child is c1 # new data -assert c1.data == 'elvis presley' +assert c1.data == "elvis presley" # new end time assert c1.end == now + datetime.timedelta(days=2) # assert raw DB data assert s.query(Parent.__table__).all() == [ - (1, times[0] - datetime.timedelta(days=3), - times[0] + datetime.timedelta(days=3), 'c1', 1) + ( + 1, + times[0] - datetime.timedelta(days=3), + times[0] + datetime.timedelta(days=3), + "c1", + 1, + ) ] assert s.query(Child.__table__).order_by(Child.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[1], 'child 1'), - (1, times[1], times[1] + datetime.timedelta(days=2), 'elvis presley') + (1, times[0] - datetime.timedelta(days=3), times[1], "child 1"), + (1, times[1], times[1] + datetime.timedelta(days=2), "elvis presley"), ] now = time_passes(s) -p1.data = 'c2 elvis presley' +p1.data = "c2 elvis presley" s.commit() # assert raw DB data. now there are two parent rows. assert s.query(Parent.__table__).order_by(Parent.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[2], 'c1', 1), - (1, times[2], times[2] + datetime.timedelta(days=2), 'c2 elvis presley', 1) + (1, times[0] - datetime.timedelta(days=3), times[2], "c1", 1), + ( + 1, + times[2], + times[2] + datetime.timedelta(days=2), + "c2 elvis presley", + 1, + ), ] assert s.query(Child.__table__).order_by(Child.end).all() == [ - (1, times[0] - datetime.timedelta(days=3), times[1], 'child 1'), - (1, times[1], times[1] + datetime.timedelta(days=2), 'elvis presley') + (1, times[0] - datetime.timedelta(days=3), times[1], "child 1"), + (1, times[1], times[1] + datetime.timedelta(days=2), "elvis presley"), ] # add some more rows to test that these aren't coming back for # queries -s.add(Parent(id=2, data='unrelated', child=Child(id=2, data='unrelated'))) +s.add(Parent(id=2, data="unrelated", child=Child(id=2, data="unrelated"))) s.commit() @@ -274,6 +304,6 @@ c3_check = s.query(Child).filter(Child.parent == p3_check).one() assert c3_check is c1 # one child one parent.... -c3_check = s.query(Child).join(Parent.child).filter( - Parent.id == p3_check.id).one() - +c3_check = ( + s.query(Child).join(Parent.child).filter(Parent.id == p3_check.id).one() +) diff --git a/examples/vertical/__init__.py b/examples/vertical/__init__.py index 0b69f32ea..0e09b9a55 100644 --- a/examples/vertical/__init__.py +++ b/examples/vertical/__init__.py @@ -31,4 +31,4 @@ Example:: .. autosource:: -"""
\ No newline at end of file +""" diff --git a/examples/vertical/dictlike-polymorphic.py b/examples/vertical/dictlike-polymorphic.py index 7147ac40b..c000ff3cf 100644 --- a/examples/vertical/dictlike-polymorphic.py +++ b/examples/vertical/dictlike-polymorphic.py @@ -30,6 +30,7 @@ from sqlalchemy import event from sqlalchemy import literal_column from .dictlike import ProxiedDictMixin + class PolymorphicVerticalProperty(object): """A key/value pair with polymorphic value storage. @@ -69,6 +70,7 @@ class PolymorphicVerticalProperty(object): """A comparator for .value, builds a polymorphic comparison via CASE. """ + def __init__(self, cls): self.cls = cls @@ -77,40 +79,59 @@ class PolymorphicVerticalProperty(object): whens = [ ( literal_column("'%s'" % discriminator), - cast(getattr(self.cls, attribute), String) - ) for attribute, discriminator in pairs + cast(getattr(self.cls, attribute), String), + ) + for attribute, discriminator in pairs if attribute is not None ] return case(whens, self.cls.type, null()) + def __eq__(self, other): return self._case() == cast(other, String) + def __ne__(self, other): return self._case() != cast(other, String) def __repr__(self): - return '<%s %r=%r>' % (self.__class__.__name__, self.key, self.value) + return "<%s %r=%r>" % (self.__class__.__name__, self.key, self.value) + -@event.listens_for(PolymorphicVerticalProperty, "mapper_configured", propagate=True) +@event.listens_for( + PolymorphicVerticalProperty, "mapper_configured", propagate=True +) def on_new_class(mapper, cls_): """Look for Column objects with type info in them, and work up a lookup table.""" info_dict = {} - info_dict[type(None)] = (None, 'none') - info_dict['none'] = (None, 'none') + info_dict[type(None)] = (None, "none") + info_dict["none"] = (None, "none") for k in mapper.c.keys(): col = mapper.c[k] - if 'type' in col.info: - python_type, discriminator = col.info['type'] + if "type" in col.info: + python_type, discriminator = col.info["type"] info_dict[python_type] = (k, discriminator) info_dict[discriminator] = (k, discriminator) cls_.type_map = info_dict -if __name__ == '__main__': - from sqlalchemy import (Column, Integer, Unicode, - ForeignKey, UnicodeText, and_, or_, String, Boolean, cast, - null, case, create_engine) + +if __name__ == "__main__": + from sqlalchemy import ( + Column, + Integer, + Unicode, + ForeignKey, + UnicodeText, + and_, + or_, + String, + Boolean, + cast, + null, + case, + create_engine, + ) from sqlalchemy.orm import relationship, Session from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.ext.declarative import declarative_base @@ -118,36 +139,38 @@ if __name__ == '__main__': Base = declarative_base() - class AnimalFact(PolymorphicVerticalProperty, Base): """A fact about an animal.""" - __tablename__ = 'animal_fact' + __tablename__ = "animal_fact" - animal_id = Column(ForeignKey('animal.id'), primary_key=True) + animal_id = Column(ForeignKey("animal.id"), primary_key=True) key = Column(Unicode(64), primary_key=True) type = Column(Unicode(16)) # add information about storage for different types # in the info dictionary of Columns - int_value = Column(Integer, info={'type': (int, 'integer')}) - char_value = Column(UnicodeText, info={'type': (str, 'string')}) - boolean_value = Column(Boolean, info={'type': (bool, 'boolean')}) + int_value = Column(Integer, info={"type": (int, "integer")}) + char_value = Column(UnicodeText, info={"type": (str, "string")}) + boolean_value = Column(Boolean, info={"type": (bool, "boolean")}) class Animal(ProxiedDictMixin, Base): """an Animal""" - __tablename__ = 'animal' + __tablename__ = "animal" id = Column(Integer, primary_key=True) name = Column(Unicode(100)) - facts = relationship("AnimalFact", - collection_class=attribute_mapped_collection('key')) + facts = relationship( + "AnimalFact", collection_class=attribute_mapped_collection("key") + ) - _proxied = association_proxy("facts", "value", - creator= - lambda key, value: AnimalFact(key=key, value=value)) + _proxied = association_proxy( + "facts", + "value", + creator=lambda key, value: AnimalFact(key=key, value=value), + ) def __init__(self, name): self.name = name @@ -159,66 +182,66 @@ if __name__ == '__main__': def with_characteristic(self, key, value): return self.facts.any(key=key, value=value) - engine = create_engine('sqlite://', echo=True) + engine = create_engine("sqlite://", echo=True) Base.metadata.create_all(engine) session = Session(engine) - stoat = Animal('stoat') - stoat['color'] = 'red' - stoat['cuteness'] = 7 - stoat['weasel-like'] = True + stoat = Animal("stoat") + stoat["color"] = "red" + stoat["cuteness"] = 7 + stoat["weasel-like"] = True session.add(stoat) session.commit() - critter = session.query(Animal).filter(Animal.name == 'stoat').one() - print(critter['color']) - print(critter['cuteness']) + critter = session.query(Animal).filter(Animal.name == "stoat").one() + print(critter["color"]) + print(critter["cuteness"]) print("changing cuteness value and type:") - critter['cuteness'] = 'very cute' + critter["cuteness"] = "very cute" session.commit() - marten = Animal('marten') - marten['cuteness'] = 5 - marten['weasel-like'] = True - marten['poisonous'] = False + marten = Animal("marten") + marten["cuteness"] = 5 + marten["weasel-like"] = True + marten["poisonous"] = False session.add(marten) - shrew = Animal('shrew') - shrew['cuteness'] = 5 - shrew['weasel-like'] = False - shrew['poisonous'] = True + shrew = Animal("shrew") + shrew["cuteness"] = 5 + shrew["weasel-like"] = False + shrew["poisonous"] = True session.add(shrew) session.commit() - q = (session.query(Animal). - filter(Animal.facts.any( - and_(AnimalFact.key == 'weasel-like', - AnimalFact.value == True)))) - print('weasel-like animals', q.all()) - - q = (session.query(Animal). - filter(Animal.with_characteristic('weasel-like', True))) - print('weasel-like animals again', q.all()) - - q = (session.query(Animal). - filter(Animal.with_characteristic('poisonous', False))) - print('animals with poisonous=False', q.all()) - - q = (session.query(Animal). - filter(or_( - Animal.with_characteristic('poisonous', False), - ~Animal.facts.any(AnimalFact.key == 'poisonous') - ) - ) + q = session.query(Animal).filter( + Animal.facts.any( + and_(AnimalFact.key == "weasel-like", AnimalFact.value == True) ) - print('non-poisonous animals', q.all()) - - q = (session.query(Animal). - filter(Animal.facts.any(AnimalFact.value == 5))) - print('any animal with a .value of 5', q.all()) + ) + print("weasel-like animals", q.all()) + + q = session.query(Animal).filter( + Animal.with_characteristic("weasel-like", True) + ) + print("weasel-like animals again", q.all()) + + q = session.query(Animal).filter( + Animal.with_characteristic("poisonous", False) + ) + print("animals with poisonous=False", q.all()) + + q = session.query(Animal).filter( + or_( + Animal.with_characteristic("poisonous", False), + ~Animal.facts.any(AnimalFact.key == "poisonous"), + ) + ) + print("non-poisonous animals", q.all()) + q = session.query(Animal).filter(Animal.facts.any(AnimalFact.value == 5)) + print("any animal with a .value of 5", q.all()) diff --git a/examples/vertical/dictlike.py b/examples/vertical/dictlike.py index 08989d8c2..f1f364207 100644 --- a/examples/vertical/dictlike.py +++ b/examples/vertical/dictlike.py @@ -32,6 +32,7 @@ can be used with many common vertical schemas as-is or with minor adaptations. """ from __future__ import unicode_literals + class ProxiedDictMixin(object): """Adds obj[key] access to a mapped class. @@ -60,9 +61,16 @@ class ProxiedDictMixin(object): del self._proxied[key] -if __name__ == '__main__': - from sqlalchemy import (Column, Integer, Unicode, - ForeignKey, UnicodeText, and_, create_engine) +if __name__ == "__main__": + from sqlalchemy import ( + Column, + Integer, + Unicode, + ForeignKey, + UnicodeText, + and_, + create_engine, + ) from sqlalchemy.orm import relationship, Session from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.ext.declarative import declarative_base @@ -73,26 +81,29 @@ if __name__ == '__main__': class AnimalFact(Base): """A fact about an animal.""" - __tablename__ = 'animal_fact' + __tablename__ = "animal_fact" - animal_id = Column(ForeignKey('animal.id'), primary_key=True) + animal_id = Column(ForeignKey("animal.id"), primary_key=True) key = Column(Unicode(64), primary_key=True) value = Column(UnicodeText) class Animal(ProxiedDictMixin, Base): """an Animal""" - __tablename__ = 'animal' + __tablename__ = "animal" id = Column(Integer, primary_key=True) name = Column(Unicode(100)) - facts = relationship("AnimalFact", - collection_class=attribute_mapped_collection('key')) + facts = relationship( + "AnimalFact", collection_class=attribute_mapped_collection("key") + ) - _proxied = association_proxy("facts", "value", - creator= - lambda key, value: AnimalFact(key=key, value=value)) + _proxied = association_proxy( + "facts", + "value", + creator=lambda key, value: AnimalFact(key=key, value=value), + ) def __init__(self, name): self.name = name @@ -109,57 +120,56 @@ if __name__ == '__main__': session = Session(bind=engine) - stoat = Animal('stoat') - stoat['color'] = 'reddish' - stoat['cuteness'] = 'somewhat' + stoat = Animal("stoat") + stoat["color"] = "reddish" + stoat["cuteness"] = "somewhat" # dict-like assignment transparently creates entries in the # stoat.facts collection: - print(stoat.facts['color']) + print(stoat.facts["color"]) session.add(stoat) session.commit() - critter = session.query(Animal).filter(Animal.name == 'stoat').one() - print(critter['color']) - print(critter['cuteness']) + critter = session.query(Animal).filter(Animal.name == "stoat").one() + print(critter["color"]) + print(critter["cuteness"]) - critter['cuteness'] = 'very' + critter["cuteness"] = "very" - print('changing cuteness:') + print("changing cuteness:") - marten = Animal('marten') - marten['color'] = 'brown' - marten['cuteness'] = 'somewhat' + marten = Animal("marten") + marten["color"] = "brown" + marten["cuteness"] = "somewhat" session.add(marten) - shrew = Animal('shrew') - shrew['cuteness'] = 'somewhat' - shrew['poisonous-part'] = 'saliva' + shrew = Animal("shrew") + shrew["cuteness"] = "somewhat" + shrew["poisonous-part"] = "saliva" session.add(shrew) - loris = Animal('slow loris') - loris['cuteness'] = 'fairly' - loris['poisonous-part'] = 'elbows' + loris = Animal("slow loris") + loris["cuteness"] = "fairly" + loris["poisonous-part"] = "elbows" session.add(loris) - q = (session.query(Animal). - filter(Animal.facts.any( - and_(AnimalFact.key == 'color', - AnimalFact.value == 'reddish')))) - print('reddish animals', q.all()) + q = session.query(Animal).filter( + Animal.facts.any( + and_(AnimalFact.key == "color", AnimalFact.value == "reddish") + ) + ) + print("reddish animals", q.all()) - q = session.query(Animal).\ - filter(Animal.with_characteristic("color", 'brown')) - print('brown animals', q.all()) + q = session.query(Animal).filter( + Animal.with_characteristic("color", "brown") + ) + print("brown animals", q.all()) - q = session.query(Animal).\ - filter(~Animal.with_characteristic("poisonous-part", 'elbows')) - print('animals without poisonous-part == elbows', q.all()) + q = session.query(Animal).filter( + ~Animal.with_characteristic("poisonous-part", "elbows") + ) + print("animals without poisonous-part == elbows", q.all()) - q = (session.query(Animal). - filter(Animal.facts.any(value='somewhat'))) + q = session.query(Animal).filter(Animal.facts.any(value="somewhat")) print('any animal with any .value of "somewhat"', q.all()) - - - |
