summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author tiko-tiko <github.com@ticotico.crabdance.com> 2017-01-05 14:13:53 +0100
committer GitHub <noreply@github.com> 2017-01-05 14:13:53 +0100
commit 2a4b9ba087094ab90a29d4a6ace4d29f2f0f8bfa (patch)
tree 731f52a439b94a4d7f4a27490a0ffc61ebddb976
parent 2a869caa4a34197bc45e8c79302d559b026ef7b9 (diff)
download rdflib-2a4b9ba087094ab90a29d4a6ace4d29f2f0f8bfa.tar.gz
Implement "on the fly" aggregation
-rw-r--r-- rdflib/plugins/sparql/aggregates.py 327
1 file changed, 210 insertions, 117 deletions
diff --git a/rdflib/plugins/sparql/aggregates.py b/rdflib/plugins/sparql/aggregates.py
index 1ca9e3b8..8716adf1 100644
--- a/rdflib/plugins/sparql/aggregates.py
+++ b/rdflib/plugins/sparql/aggregates.py
@@ -5,6 +5,7 @@ from rdflib.plugins.sparql.operators import numeric
from rdflib.plugins.sparql.datatypes import type_promotion
from rdflib.plugins.sparql.compat import num_max, num_min
+from rdflib.plugins.sparql.sparql import SPARQLTypeError
from decimal import Decimal
@@ -12,157 +13,249 @@ from decimal import Decimal
Aggregation functions
"""
+class Accumulator(object):
+ """abstract base class for different aggregation functions """
-def _eval_rows(expr, group, distinct):
- seen = set()
- for row in group:
- try:
- val = _eval(expr, row)
- if not distinct or not val in seen:
- yield val
- seen.add(val)
- except:
- pass
+ def __init__(self, aggregation):
+ self.var = aggregation.res
+ self.expr = aggregation.vars
+ if not aggregation.distinct:
+ self.use_row = self.dont_care
+ self.distinct = False
+ else:
+ self.distinct = aggregation.distinct
+ self.seen = set()
+
+ def dont_care(self, row):
+ """skips distinct test """
+ return True
+
+ def use_row(self, row):
+ """tests distinct with set """
+ return _eval(self.expr, row) not in self.seen
+
+ def set_value(self, bindings):
+ """sets final value in bindings"""
+ bindings[self.var] = self.get_value()
-def agg_Sum(a, group, bindings):
- c = 0
+class Counter(Accumulator):
- dt = None
- for e in _eval_rows(a.vars, group, a.distinct):
+ def __init__(self, aggregation):
+ super(Counter, self).__init__(aggregation)
+ self.value = 0
+ if self.expr == "*":
+ # cannot eval "*" => always use the full row
+ self.eval_row = self.eval_full_row
+
+ def update(self, row, aggregator):
try:
- n = numeric(e)
- if dt == None:
- dt = e.datatype
- else:
- dt = type_promotion(dt, e.datatype)
+ val = self.eval_row(row)
+ except NotBoundError:
+ # skip UNDEF
+ return
+ self.value += 1
+ if self.distinct:
+ self.seen.add(val)
- if type(c) == float and type(n) == Decimal:
- c += float(n)
- elif type(n) == float and type(c) == Decimal:
- c = float(c) + n
- else:
- c += n
- except:
- pass # simply dont count
+ def get_value(self):
+ return Literal(self.value)
- bindings[a.res] = Literal(c, datatype=dt)
+ def eval_row(self, row):
+ return _eval(self.expr, row)
-# Perhaps TODO: keep datatype for max/min?
+ def eval_full_row(self, row):
+ return row
+ def use_row(self, row):
+ return self.eval_row(row) not in self.seen
-def agg_Min(a, group, bindings):
- m = None
- for v in _eval_rows(a.vars, group, None): # DISTINCT makes no difference for MIN
+def type_safe_numbers(*args):
+ types = map(type, args)
+ if float in types and Decimal in types:
+ return map(float, args)
+ return args
+
+
+class Sum(Accumulator):
+
+ def __init__(self, aggregation):
+ super(Sum, self).__init__(aggregation)
+ self.value = 0
+ self.datatype = None
+
+ def update(self, row, aggregator):
try:
- v = numeric(v)
- if m is None:
- m = v
+ value = _eval(self.expr, row)
+ dt = self.datatype
+ if dt is None:
+ dt = value.datatype
else:
- m = num_min(v, m)
- except:
- continue # try other values
+ dt = type_promotion(dt, value.datatype)
+ self.datatype = dt
+ self.value = sum(type_safe_numbers(self.value, numeric(value)))
+ if self.distinct:
+ self.seen.add(value)
+ except NotBoundError:
+ # skip UNDEF
+ pass
- if m is not None:
- bindings[a.res] = Literal(m)
+ def get_value(self):
+ return Literal(self.value, datatype=self.datatype)
+class Average(Accumulator):
-def agg_Max(a, group, bindings):
- m = None
+ def __init__(self, aggregation):
+ super(Average, self).__init__(aggregation)
+ self.counter = 0
+ self.sum = 0
+ self.datatype = None
- for v in _eval_rows(a.vars, group, None): # DISTINCT makes no difference for MAX
+ def update(self, row, aggregator):
try:
- v = numeric(v)
- if m is None:
- m = v
+ value = _eval(self.expr, row)
+ dt = self.datatype
+ self.sum = sum(type_safe_numbers(self.sum, numeric(value)))
+ if dt is None:
+ dt = value.datatype
else:
- m = num_max(v, m)
- except:
- return # error in aggregate => no binding
+ dt = type_promotion(dt, value.datatype)
+ self.datatype = dt
+ if self.distinct:
+ self.seen.add(value)
+ self.counter += 1
+ # skip UNDEF or BNode => SPARQLTypeError
+ except NotBoundError:
+ pass
+ except SPARQLTypeError:
+ pass
- if m is not None:
- bindings[a.res] = Literal(m)
+ def get_value(self):
+ if self.counter == 0:
+ return Literal(0)
+ if self.datatype in (XSD.float, XSD.double):
+ return Literal(self.sum / self.counter)
+ else:
+ return Literal(Decimal(self.sum) / Decimal(self.counter))
-def agg_Count(a, group, bindings):
- if a.vars == '*':
- c = len(group)
- else:
- c = 0
- for e in _eval_rows(a.vars, group, a.distinct):
- c += 1
+class Extremum(Accumulator):
+ """abstract base class for Minimum and Maximum"""
- bindings[a.res] = Literal(c)
+ def __init__(self, aggregation):
+ super(Extremum, self).__init__(aggregation)
+ self.value = None
+ # DISTINCT would not change the value for MIN or MAX
+ self.use_row = self.dont_care
+ def set_value(self, bindings):
+ if self.value is not None:
+ # simply do not set if self.value is still None
+ bindings[self.var] = Literal(self.value)
-def agg_Sample(a, group, bindings):
- for ctx in group:
+ def update(self, row, aggregator):
try:
- bindings[a.res] = _eval(a.vars, ctx)
- break
+ if self.value is None:
+ self.value = numeric(_eval(self.expr, row))
+ else:
+ # self.compare is implemented by Minimum/Maximum
+ self.value = self.compare(self.value, numeric(_eval(self.expr, row)))
+ # skip UNDEF or BNode => SPARQLTypeError
except NotBoundError:
pass
+ except SPARQLTypeError:
+ pass
+
+
+class Minimum(Extremum):
+
+ def compare(self, val1, val2):
+ return num_min(val1, val2)
-def agg_GroupConcat(a, group, bindings):
+class Maximum(Extremum):
- sep = a.separator or " "
- if a.distinct:
- agg = lambda x: x
- else:
- add = set
+ def compare(self, val1, val2):
+ return num_max(val1, val2)
- bindings[a.res] = Literal(
- sep.join(unicode(x) for x in _eval_rows(a.vars, group, a.distinct)))
+class Sample(Accumulator):
+ """takes the first eligable value"""
-def agg_Avg(a, group, bindings):
+ def __init__(self, aggregation):
+ super(Sample, self).__init__(aggregation)
+ # DISTINCT would not change the value
+ self.use_row = self.dont_care
- c = 0
- s = 0
- dt = None
- for e in _eval_rows(a.vars, group, a.distinct):
+ def update(self, row, aggregator):
try:
- n = numeric(e)
- if dt == None:
- dt = e.datatype
- else:
- dt = type_promotion(dt, e.datatype)
+ # set the value now
+ aggregator.bindings[self.var] = _eval(self.expr, row)
+ # and skip this accumulator for future rows
+ del aggregator.accumulators[self.var]
+ except NotBoundError:
+ pass
+
+ def get_value(self):
+ # set None if no value was set
+ return None
+
+class GroupConcat(Accumulator):
+
+ def __init__(self, aggregation):
+ super(GroupConcat, self).__init__(aggregation)
+ # only GROUPCONCAT needs to have a list as accumlator
+ self.value = []
+ self.separator = aggregation.separator or " "
+
+ def update(self, row, aggregator):
+ try:
+ value = _eval(self.expr, row)
+ self.value.append(value)
+ if self.distinct:
+ self.seen.add(value)
+ # skip UNDEF
+ except NotBoundError:
+ pass
+
+ def get_value(self):
+ return Literal(self.separator.join(unicode(v) for v in self.value))
+
+
+class Aggregator(object):
+ """combines different Accumulator objects"""
+
+ accumulator_classes = {
+ "Aggregate_Count": Counter,
+ "Aggregate_Sample": Sample,
+ "Aggregate_Sum": Sum,
+ "Aggregate_Avg": Average,
+ "Aggregate_Min": Minimum,
+ "Aggregate_Max": Maximum,
+ "Aggregate_GroupConcat": GroupConcat,
+ }
+
+ def __init__(self, aggregations):
+ self.bindings = {}
+ self.accumulators = {}
+ for a in aggregations:
+ accumulator_class = self.accumulator_classes.get(a.name)
+ if accumulator_class is None:
+ raise Exception("Unknown aggregate function " + a.name)
+ self.accumulators[a.res] = accumulator_class(a)
+
+ def update(self, row):
+ """update all own accumulators"""
+ # SAMPLE accumulators may delete themselves
+ # => iterate over list not generator
+ for acc in self.accumulators.values():
+ if acc.use_row(row):
+ acc.update(row, self)
+
+ def get_bindings(self):
+ """calculate and set last values"""
+ for acc in self.accumulators.itervalues():
+ acc.set_value(self.bindings)
+ return self.bindings
- if type(s) == float and type(n) == Decimal:
- s += float(n)
- elif type(n) == float and type(s) == Decimal:
- s = float(s) + n
- else:
- s += n
- c += 1
- except:
- return # error in aggregate => no binding
-
- if c == 0:
- bindings[a.res] = Literal(0)
- if dt == XSD.float or dt == XSD.double:
- bindings[a.res] = Literal(s / c)
- else:
- bindings[a.res] = Literal(Decimal(s) / Decimal(c))
-
-
-def evalAgg(a, group, bindings):
- if a.name == 'Aggregate_Count':
- return agg_Count(a, group, bindings)
- elif a.name == 'Aggregate_Sum':
- return agg_Sum(a, group, bindings)
- elif a.name == 'Aggregate_Sample':
- return agg_Sample(a, group, bindings)
- elif a.name == 'Aggregate_GroupConcat':
- return agg_GroupConcat(a, group, bindings)
- elif a.name == 'Aggregate_Avg':
- return agg_Avg(a, group, bindings)
- elif a.name == 'Aggregate_Min':
- return agg_Min(a, group, bindings)
- elif a.name == 'Aggregate_Max':
- return agg_Max(a, group, bindings)
-
- else:
- raise Exception("Unknown aggregate function " + a.name)