summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
blob: d1ae9f796837dcfaaac89778b094c6225242a85f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# orm/util.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy import sql, util, exceptions

# Every cascade name that CascadeOptions accepts; anything else is rejected.
all_cascades = util.Set([
    "delete",
    "delete-orphan",
    "all",
    "merge",
    "expunge",
    "save-update",
    "refresh-expire",
    "none",
])

class CascadeOptions(object):
    """Keeps track of the options sent to relation().cascade.

    ``arg`` is a comma-separated string of cascade names, each of which
    must appear in ``all_cascades``.  Whitespace around names is ignored,
    and so are empty entries (the default ``""`` or a trailing comma).
    ``"all"`` turns on delete, save-update, merge, expunge and
    refresh-expire; ``"delete-orphan"`` also turns on delete.  Each flag
    is stored as a boolean attribute whose name uses underscores in place
    of hyphens (``save_update``, ``delete_orphan``, ...).

    Raises ``exceptions.ArgumentError`` for an unrecognized cascade name.
    """

    def __init__(self, arg=""):
        names = [c.strip() for c in arg.split(',')]
        # "".split(',') yields [''], so drop empty entries rather than
        # rejecting them as an invalid option.
        values = util.Set([c for c in names if c])

        # Validate up front, before any flag is assigned.
        for x in values:
            if x not in all_cascades:
                raise exceptions.ArgumentError("Invalid cascade option '%s'" % x)

        self.delete_orphan = "delete-orphan" in values
        self.delete = "delete" in values or self.delete_orphan or "all" in values
        self.save_update = "save-update" in values or "all" in values
        self.merge = "merge" in values or "all" in values
        self.expunge = "expunge" in values or "all" in values
        self.refresh_expire = "refresh-expire" in values or "all" in values

    def __contains__(self, item):
        """Support ``"save-update" in options``; hyphenated or underscored names both work."""
        return getattr(self, item.replace("-", "_"), False)

    def __repr__(self):
        # Flags live on underscore-named attributes, so translate each
        # hyphenated cascade name before the lookup.  The hyphenated names are
        # what __init__ accepts, so the repr can be passed back to the constructor.
        ordered = ('delete', 'save-update', 'merge', 'expunge',
                   'delete-orphan', 'refresh-expire')
        enabled = [name for name in ordered
                   if getattr(self, name.replace("-", "_"), False) is True]
        return "CascadeOptions(arg=%s)" % repr(",".join(enabled))

def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Create a UNION ALL statement used by a polymorphic mapper.

    ``table_map`` maps each discriminator value to a table or select.
    Every member SELECT produces the full set of column names seen across
    all the tables.  A column that a table lacks is filled in with a NULL
    cast to the type recorded for that name.  When ``typecolname`` is not
    None, each SELECT also carries the discriminator value as a literal
    column labeled ``typecolname``.  The union is returned as an alias
    named ``aliasname``.

    NOTE: ``sql.Select`` values in ``table_map`` are replaced in the
    caller's dict by their alias.

    See the SQLAlchemy advanced mapping docs for an example of how this is used."""
    all_names = util.Set()
    columns_by_table = {}
    type_by_name = {}

    for discriminator in table_map.keys():
        selectable = table_map[discriminator]

        # mysql doesn't like selecting from a select; make it an alias of the select
        if isinstance(selectable, sql.Select):
            selectable = selectable.alias()
            table_map[discriminator] = selectable

        by_name = {}
        for column in selectable.c:
            all_names.add(column.name)
            by_name[column.name] = column
            type_by_name[column.name] = column.type
        columns_by_table[selectable] = by_name

    def column_or_null(name, selectable):
        # Use the table's own column when it has one; otherwise pad with a typed NULL.
        try:
            return columns_by_table[selectable][name]
        except KeyError:
            return sql.cast(sql.null(), type_by_name[name]).label(name)

    selects = []
    for discriminator, selectable in table_map.iteritems():
        columns = [column_or_null(name, selectable) for name in all_names]
        if typecolname is not None:
            # NOTE(review): the discriminator is interpolated into SQL unescaped;
            # a value containing a single quote would yield invalid SQL.
            columns.append(sql.literal_column("'%s'" % discriminator).label(typecolname))
        selects.append(sql.select(columns, from_obj=[selectable]))
    return sql.union_all(*selects).alias(aliasname)

class TranslatingDict(dict):
    """A dictionary that stores ColumnElement objects as keys.

    Incoming ColumnElement keys are translated against those of an
    underlying FromClause.  This way the columns from any Selectable that
    is derived from or underlying this TranslatingDict's selectable can be
    used as keys.

    Translation is applied by item access, assignment and deletion, and by
    ``get()``, ``has_key()``/``in`` and ``setdefault()``.  Other ``dict``
    methods (``pop()``, ``update()``, iteration, ...) operate on the stored
    keys as-is.
    """

    def __init__(self, selectable):
        super(TranslatingDict, self).__init__()
        self.selectable = selectable

    def __translate_col(self, col):
        """Return ``self.selectable``'s column corresponding to ``col``.

        Falls back to ``col`` itself when the selectable has no counterpart.
        """
        ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
        if ourcol is None:
            return col
        return ourcol

    def __getitem__(self, col):
        return super(TranslatingDict, self).__getitem__(self.__translate_col(col))

    def __setitem__(self, col, value):
        return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value)

    def __delitem__(self, col):
        # Without this override, del would look up the untranslated key.
        return super(TranslatingDict, self).__delitem__(self.__translate_col(col))

    def has_key(self, col):
        # dict.__contains__ exists on every Python version; dict.has_key does not.
        return super(TranslatingDict, self).__contains__(self.__translate_col(col))

    def __contains__(self, col):
        return self.has_key(col)

    def get(self, col, default=None):
        # dict.get does not route through __getitem__, so translate explicitly.
        return super(TranslatingDict, self).get(self.__translate_col(col), default)

    def setdefault(self, col, value):
        return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)

class BinaryVisitor(sql.ClauseVisitor):
    """ClauseVisitor whose visit_binary hook forwards the clause to a callback."""

    def __init__(self, func):
        # callable that receives each binary expression handed to visit_binary
        self.func = func

    def visit_binary(self, binary):
        """Pass ``binary`` to the stored callback; its return value is discarded."""
        callback = self.func
        callback(binary)

def instance_str(instance):
    """Describe ``instance`` as ``<ClassName>@<hex id>``."""
    cls_name = instance.__class__.__name__
    address = hex(id(instance))
    return "%s@%s" % (cls_name, address)

def attribute_str(instance, attribute):
    """Describe ``attribute`` of ``instance`` as ``<ClassName>@<hex id>.<attribute>``."""
    owner = instance_str(instance)
    return owner + "." + attribute