diff options
author | tiko-tiko <github.com@ticotico.crabdance.com> | 2017-01-05 14:13:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-05 14:13:53 +0100 |
commit | 2a4b9ba087094ab90a29d4a6ace4d29f2f0f8bfa (patch) | |
tree | 731f52a439b94a4d7f4a27490a0ffc61ebddb976 | |
parent | 2a869caa4a34197bc45e8c79302d559b026ef7b9 (diff) | |
download | rdflib-2a4b9ba087094ab90a29d4a6ace4d29f2f0f8bfa.tar.gz |
Implement "on the fly" aggregation
-rw-r--r-- | rdflib/plugins/sparql/aggregates.py | 327 |
1 files changed, 210 insertions, 117 deletions
diff --git a/rdflib/plugins/sparql/aggregates.py b/rdflib/plugins/sparql/aggregates.py index 1ca9e3b8..8716adf1 100644 --- a/rdflib/plugins/sparql/aggregates.py +++ b/rdflib/plugins/sparql/aggregates.py @@ -5,6 +5,7 @@ from rdflib.plugins.sparql.operators import numeric from rdflib.plugins.sparql.datatypes import type_promotion from rdflib.plugins.sparql.compat import num_max, num_min +from rdflib.plugins.sparql.sparql import SPARQLTypeError from decimal import Decimal @@ -12,157 +13,249 @@ from decimal import Decimal Aggregation functions """ +class Accumulator(object): + """abstract base class for different aggregation functions """ -def _eval_rows(expr, group, distinct): - seen = set() - for row in group: - try: - val = _eval(expr, row) - if not distinct or not val in seen: - yield val - seen.add(val) - except: - pass + def __init__(self, aggregation): + self.var = aggregation.res + self.expr = aggregation.vars + if not aggregation.distinct: + self.use_row = self.dont_care + self.distinct = False + else: + self.distinct = aggregation.distinct + self.seen = set() + + def dont_care(self, row): + """skips distinct test """ + return True + + def use_row(self, row): + """tests distinct with set """ + return _eval(self.expr, row) not in self.seen + + def set_value(self, bindings): + """sets final value in bindings""" + bindings[self.var] = self.get_value() -def agg_Sum(a, group, bindings): - c = 0 +class Counter(Accumulator): - dt = None - for e in _eval_rows(a.vars, group, a.distinct): + def __init__(self, aggregation): + super(Counter, self).__init__(aggregation) + self.value = 0 + if self.expr == "*": + # cannot eval "*" => always use the full row + self.eval_row = self.eval_full_row + + def update(self, row, aggregator): try: - n = numeric(e) - if dt == None: - dt = e.datatype - else: - dt = type_promotion(dt, e.datatype) + val = self.eval_row(row) + except NotBoundError: + # skip UNDEF + return + self.value += 1 + if self.distinct: + self.seen.add(val) - if type(c) == float and type(n) == Decimal: - c += float(n) - elif type(n) == float and type(c) == Decimal: - c = float(c) + n - else: - c += n - except: - pass # simply dont count + def get_value(self): + return Literal(self.value) - bindings[a.res] = Literal(c, datatype=dt) + def eval_row(self, row): + return _eval(self.expr, row) -# Perhaps TODO: keep datatype for max/min? + def eval_full_row(self, row): + return row + def use_row(self, row): + return self.eval_row(row) not in self.seen -def agg_Min(a, group, bindings): - m = None - for v in _eval_rows(a.vars, group, None): # DISTINCT makes no difference for MIN +def type_safe_numbers(*args): + types = map(type, args) + if float in types and Decimal in types: + return map(float, args) + return args + + +class Sum(Accumulator): + + def __init__(self, aggregation): + super(Sum, self).__init__(aggregation) + self.value = 0 + self.datatype = None + + def update(self, row, aggregator): try: - v = numeric(v) - if m is None: - m = v + value = _eval(self.expr, row) + dt = self.datatype + if dt is None: + dt = value.datatype else: - m = num_min(v, m) - except: - continue # try other values + dt = type_promotion(dt, value.datatype) + self.datatype = dt + self.value = sum(type_safe_numbers(self.value, numeric(value))) + if self.distinct: + self.seen.add(value) + except NotBoundError: + # skip UNDEF + pass - if m is not None: - bindings[a.res] = Literal(m) + def get_value(self): + return Literal(self.value, datatype=self.datatype) +class Average(Accumulator): -def agg_Max(a, group, bindings): - m = None + def __init__(self, aggregation): + super(Average, self).__init__(aggregation) + self.counter = 0 + self.sum = 0 + self.datatype = None - for v in _eval_rows(a.vars, group, None): # DISTINCT makes no difference for MAX + def update(self, row, aggregator): try: - v = numeric(v) - if m is None: - m = v + value = _eval(self.expr, row) + dt = self.datatype + self.sum = sum(type_safe_numbers(self.sum, numeric(value))) + if dt is None: + dt = value.datatype else: - m = num_max(v, m) - except: - return # error in aggregate => no binding + dt = type_promotion(dt, value.datatype) + self.datatype = dt + if self.distinct: + self.seen.add(value) + self.counter += 1 + # skip UNDEF or BNode => SPARQLTypeError + except NotBoundError: + pass + except SPARQLTypeError: + pass - if m is not None: - bindings[a.res] = Literal(m) + def get_value(self): + if self.counter == 0: + return Literal(0) + if self.datatype in (XSD.float, XSD.double): + return Literal(self.sum / self.counter) + else: + return Literal(Decimal(self.sum) / Decimal(self.counter)) -def agg_Count(a, group, bindings): - if a.vars == '*': - c = len(group) - else: - c = 0 - for e in _eval_rows(a.vars, group, a.distinct): - c += 1 +class Extremum(Accumulator): + """abstract base class for Minimum and Maximum""" - bindings[a.res] = Literal(c) + def __init__(self, aggregation): + super(Extremum, self).__init__(aggregation) + self.value = None + # DISTINCT would not change the value for MIN or MAX + self.use_row = self.dont_care + def set_value(self, bindings): + if self.value is not None: + # simply do not set if self.value is still None + bindings[self.var] = Literal(self.value) -def agg_Sample(a, group, bindings): - for ctx in group: + def update(self, row, aggregator): try: - bindings[a.res] = _eval(a.vars, ctx) - break + if self.value is None: + self.value = numeric(_eval(self.expr, row)) + else: + # self.compare is implemented by Minimum/Maximum + self.value = self.compare(self.value, numeric(_eval(self.expr, row))) + # skip UNDEF or BNode => SPARQLTypeError except NotBoundError: pass + except SPARQLTypeError: + pass + + +class Minimum(Extremum): + + def compare(self, val1, val2): + return num_min(val1, val2) -def agg_GroupConcat(a, group, bindings): +class Maximum(Extremum): - sep = a.separator or " " - if a.distinct: - agg = lambda x: x - else: - add = set + def compare(self, val1, val2): + return num_max(val1, val2) - bindings[a.res] = Literal( - sep.join(unicode(x) for x in _eval_rows(a.vars, group, a.distinct))) +class Sample(Accumulator): + """takes the first eligable value""" -def agg_Avg(a, group, bindings): + def __init__(self, aggregation): + super(Sample, self).__init__(aggregation) + # DISTINCT would not change the value + self.use_row = self.dont_care - c = 0 - s = 0 - dt = None - for e in _eval_rows(a.vars, group, a.distinct): + def update(self, row, aggregator): try: - n = numeric(e) - if dt == None: - dt = e.datatype - else: - dt = type_promotion(dt, e.datatype) + # set the value now + aggregator.bindings[self.var] = _eval(self.expr, row) + # and skip this accumulator for future rows + del aggregator.accumulators[self.var] + except NotBoundError: + pass + + def get_value(self): + # set None if no value was set + return None + +class GroupConcat(Accumulator): + + def __init__(self, aggregation): + super(GroupConcat, self).__init__(aggregation) + # only GROUPCONCAT needs to have a list as accumlator + self.value = [] + self.separator = aggregation.separator or " " + + def update(self, row, aggregator): + try: + value = _eval(self.expr, row) + self.value.append(value) + if self.distinct: + self.seen.add(value) + # skip UNDEF + except NotBoundError: + pass + + def get_value(self): + return Literal(self.separator.join(unicode(v) for v in self.value)) + + +class Aggregator(object): + """combines different Accumulator objects""" + + accumulator_classes = { + "Aggregate_Count": Counter, + "Aggregate_Sample": Sample, + "Aggregate_Sum": Sum, + "Aggregate_Avg": Average, + "Aggregate_Min": Minimum, + "Aggregate_Max": Maximum, + "Aggregate_GroupConcat": GroupConcat, + } + + def __init__(self, aggregations): + self.bindings = {} + self.accumulators = {} + for a in aggregations: + accumulator_class = self.accumulator_classes.get(a.name) + if accumulator_class is None: + raise Exception("Unknown aggregate function " + a.name) + self.accumulators[a.res] = accumulator_class(a) + + def update(self, row): + """update all own accumulators""" + # SAMPLE accumulators may delete themselves + # => iterate over list not generator + for acc in self.accumulators.values(): + if acc.use_row(row): + acc.update(row, self) + + def get_bindings(self): + """calculate and set last values""" + for acc in self.accumulators.itervalues(): + acc.set_value(self.bindings) + return self.bindings - if type(s) == float and type(n) == Decimal: - s += float(n) - elif type(n) == float and type(s) == Decimal: - s = float(s) + n - else: - s += n - c += 1 - except: - return # error in aggregate => no binding - - if c == 0: - bindings[a.res] = Literal(0) - if dt == XSD.float or dt == XSD.double: - bindings[a.res] = Literal(s / c) - else: - bindings[a.res] = Literal(Decimal(s) / Decimal(c)) - - -def evalAgg(a, group, bindings): - if a.name == 'Aggregate_Count': - return agg_Count(a, group, bindings) - elif a.name == 'Aggregate_Sum': - return agg_Sum(a, group, bindings) - elif a.name == 'Aggregate_Sample': - return agg_Sample(a, group, bindings) - elif a.name == 'Aggregate_GroupConcat': - return agg_GroupConcat(a, group, bindings) - elif a.name == 'Aggregate_Avg': - return agg_Avg(a, group, bindings) - elif a.name == 'Aggregate_Min': - return agg_Min(a, group, bindings) - elif a.name == 'Aggregate_Max': - return agg_Max(a, group, bindings) - - else: - raise Exception("Unknown aggregate function " + a.name) |