summaryrefslogtreecommitdiff
path: root/rdflib/plugins/sparql/aggregates.py
blob: 571bac4f77f968537e029091f4544610dcc558dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from rdflib import Literal, XSD

from rdflib.plugins.sparql.evalutils import _eval, NotBoundError
from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.datatypes import type_promotion

from rdflib.plugins.sparql.compat import num_max, num_min

from decimal import Decimal

"""
Aggregation functions
"""


def _eval_rows(expr, group):
    """Yield ``_eval(expr, row)`` for every row of ``group``.

    Rows whose evaluation raises an ``Exception`` are silently skipped, so
    callers only ever see successfully evaluated values.
    """
    for row in group:
        try:
            value = _eval(expr, row)
        except Exception:
            # Skip rows that fail to evaluate. The yield is kept outside the
            # try so that GeneratorExit / KeyboardInterrupt raised at the
            # yield point are never swallowed (a bare except there makes
            # generator.close() fail with "generator ignored GeneratorExit").
            continue
        yield value


def agg_Sum(a, group, bindings):
    """Evaluate a SPARQL SUM aggregate over ``group``.

    ``a.vars`` is evaluated for each row of ``group`` and converted with
    ``numeric``; the values are added into a running total. A row is skipped
    ("simply don't count") when evaluating it, converting it, or promoting
    its datatype raises; such a row does not fail the whole aggregate.

    The result datatype is the ``type_promotion`` of the datatypes of the
    counted values (``None`` when nothing was counted, leaving the choice to
    ``Literal``). The result is stored in ``bindings[a.res]``; nothing is
    returned.
    """
    c = 0

    dt = None
    for x in group:
        try:
            e = _eval(a.vars, x)
            n = numeric(e)
            if dt is None:
                dt = e.datatype
            else:
                dt = type_promotion(dt, e.datatype)

            # float and Decimal cannot be added directly, so coerce the
            # Decimal side to float whenever the two meet.
            if type(c) == float and type(n) == Decimal:
                c += float(n)
            elif type(n) == float and type(c) == Decimal:
                c = float(c) + n
            else:
                c += n
        except Exception:
            pass  # simply don't count this row

    bindings[a.res] = Literal(c, datatype=dt)

# Possible TODO: preserve the input datatype in MIN/MAX results (they are currently built as Literal(m) with no datatype).


def agg_Min(a, group, bindings):
    """Evaluate a SPARQL MIN aggregate over ``group``.

    Each row's value of ``a.vars`` is converted with ``numeric`` and the
    smallest is kept using ``num_min``. If evaluating, converting or
    comparing any row raises, the aggregate is an error and ``bindings`` is
    left untouched. An empty group also produces no binding.

    The result is built as ``Literal(m)`` without an explicit datatype, so
    the datatypes of the input terms are not carried over.
    """
    m = None

    for x in group:
        try:
            v = numeric(_eval(a.vars, x))
            if m is None:
                m = v
            else:
                m = num_min(v, m)
        except Exception:
            return  # error in aggregate => no binding

    if m is not None:
        bindings[a.res] = Literal(m)


def agg_Max(a, group, bindings):
    """Evaluate a SPARQL MAX aggregate over ``group``.

    Each row's value of ``a.vars`` is converted with ``numeric`` and the
    largest is kept using ``num_max``. If evaluating, converting or
    comparing any row raises, the aggregate is an error and ``bindings`` is
    left untouched. An empty group also produces no binding.

    The result is built as ``Literal(m)`` without an explicit datatype, so
    the datatypes of the input terms are not carried over.
    """
    m = None

    for x in group:
        try:
            v = numeric(_eval(a.vars, x))
            if m is None:
                m = v
            else:
                m = num_max(v, m)
        except Exception:
            return  # error in aggregate => no binding

    if m is not None:
        bindings[a.res] = Literal(m)


def agg_Count(a, group, bindings):
    """Evaluate a SPARQL COUNT aggregate over ``group``.

    With ``COUNT(*)`` (``a.vars == '*'``) every row is counted. Otherwise
    ``a.vars`` is evaluated per row:
      - rows where it raises ``NotBoundError`` are not counted;
      - any other evaluation error aborts the aggregate and leaves
        ``bindings`` untouched.

    The count is stored in ``bindings[a.res]`` as a ``Literal``.
    """
    c = 0
    for x in group:
        if a.vars != '*':
            try:
                _eval(a.vars, x)
            except NotBoundError:
                continue  # unbound value: not counted
            except Exception:
                return  # error in aggregate => no binding
        c += 1

    bindings[a.res] = Literal(c)


def agg_Sample(a, group, bindings):
    """Evaluate a SPARQL SAMPLE aggregate over ``group``.

    Binds ``bindings[a.res]`` to the value of ``a.vars`` in the first row
    where it is bound; rows raising ``NotBoundError`` are passed over. If no
    row yields a value, ``bindings`` is left unchanged.
    """
    for row in group:
        try:
            sampled = _eval(a.vars, row)
        except NotBoundError:
            continue
        bindings[a.res] = sampled
        return


def agg_GroupConcat(a, group, bindings):
    """Evaluate a SPARQL GROUP_CONCAT aggregate over ``group``.

    The values of ``a.vars`` (rows that fail to evaluate are skipped by
    ``_eval_rows``) are converted with ``unicode`` and joined with the
    query's separator. The resulting string is stored in ``bindings[a.res]``
    as a ``Literal``.
    """
    sep = a.separator
    if sep is None:
        # Presumably None means SEPARATOR was omitted (assumes the parsed
        # aggregate returns None for a missing key; verify against the
        # parser), so use the default single space. An explicit empty
        # separator is falsy and must NOT be replaced, which the previous
        # `a.separator or " "` did.
        sep = " "

    bindings[a.res] = Literal(
        sep.join(unicode(x) for x in _eval_rows(a.vars, group)))


def agg_Avg(a, group, bindings):
    """Evaluate a SPARQL AVG aggregate over ``group``.

    Each row's value of ``a.vars`` is converted with ``numeric`` and summed,
    while the datatypes are combined with ``type_promotion``. Any row that
    raises during evaluation, conversion or promotion makes the aggregate an
    error, and ``bindings`` is left untouched.

    Result stored in ``bindings[a.res]``:
      - an empty group gives ``Literal(0)``;
      - a promoted datatype of xsd:float / xsd:double uses plain float
        division;
      - otherwise the average is computed with ``Decimal`` division.
    """
    c = 0
    s = 0
    dt = None
    for x in group:
        try:
            e = _eval(a.vars, x)
            n = numeric(e)
            if dt is None:
                dt = e.datatype
            else:
                dt = type_promotion(dt, e.datatype)

            # float and Decimal cannot be added directly, so coerce the
            # Decimal side to float whenever the two meet.
            if type(s) == float and type(n) == Decimal:
                s += float(n)
            elif type(n) == float and type(s) == Decimal:
                s = float(s) + n
            else:
                s += n
            c += 1
        except Exception:
            return  # error in aggregate => no binding

    if c == 0:
        # Empty group: the average is 0. Return here; falling through would
        # evaluate Decimal(0) / Decimal(0), which raises InvalidOperation.
        bindings[a.res] = Literal(0)
        return

    if dt == XSD.float or dt == XSD.double:
        bindings[a.res] = Literal(s / c)
    else:
        bindings[a.res] = Literal(Decimal(s) / Decimal(c))


def evalAgg(a, group, bindings):
    """Run the aggregate function selected by ``a.name``.

    ``a.name`` is looked up in a table mapping aggregate names to the
    ``agg_*`` implementations. The chosen handler is called as
    ``handler(a, group, bindings)`` and its return value is passed back.
    Unrecognised names raise ``Exception``.
    """
    dispatch = {
        'Aggregate_Count': agg_Count,
        'Aggregate_Sum': agg_Sum,
        'Aggregate_Sample': agg_Sample,
        'Aggregate_GroupConcat': agg_GroupConcat,
        'Aggregate_Avg': agg_Avg,
        'Aggregate_Min': agg_Min,
        'Aggregate_Max': agg_Max,
    }

    handler = dispatch.get(a.name)
    if handler is None:
        raise Exception("Unknown aggregate function " + a.name)
    return handler(a, group, bindings)