/**
 * Renders a set-style call as {@code name(first) ON second, third, ...}.
 *
 * <p>The first argument always goes inside the parentheses. Any remaining arguments follow an
 * {@code ON} keyword as a comma-separated list. At least one argument is required.
 *
 * @param formatter formatter used to render each argument
 * @param node the function call to render
 * @return the SQL text for the call
 */
protected static String processFuncSet(Formatter formatter, FunctionCall node)
{
    String name = getFunctionName(node);
    int argumentCount = node.getArguments().size();

    StringBuilder sql = new StringBuilder();
    sql.append(name)
            .append('(')
            .append(formatter.process(node.getArguments().get(0), null))
            .append(')');

    for (int index = 1; index < argumentCount; index++) {
        // the first trailing argument is introduced by ON, the rest by a comma
        sql.append(index == 1 ? " ON " : ", ");
        sql.append(formatter.process(node.getArguments().get(index), null));
    }
    return sql.toString();
}
@Override protected String visitArithmeticExpression(ArithmeticExpression node, Void context) { if (node.getType().equals(ArithmeticExpression.Type.DIVIDE)) { if (_outputDivideByZeroGuard == true) { if (node.getRight() instanceof FunctionCall) { if (getFunctionName((FunctionCall) node.getRight()).equals("nullifzero")) { // bypass appending nullifzero return formatBinaryExpression(node.getType().getValue(), node.getLeft(), node.getRight()); } } else if (node.getRight() instanceof Literal) { // purely literal return formatBinaryExpression(node.getType().getValue(), node.getLeft(), node.getRight()); } List<Expression> arguments = new ArrayList<Expression>(); arguments.add(node.getRight()); FunctionCall nullifzeroFunc = new FunctionCall(new QualifiedName("nullifzero"), arguments); return formatBinaryExpression(node.getType().getValue(), node.getLeft(), nullifzeroFunc); } else { return formatBinaryExpression(node.getType().getValue(), node.getLeft(), node.getRight()); } } else { return formatBinaryExpression(node.getType().getValue(), node.getLeft(), node.getRight()); } }
/**
 * Formats a function call as {@code name([DISTINCT] args)}. An optional {@code FILTER} clause and
 * an optional {@code OVER} window clause are appended. A zero-argument {@code count} is rendered
 * as {@code count(*)}.
 */
@Override
protected String visitFunctionCall(FunctionCall node, StackableAstVisitorContext<Integer> indent)
{
    String argumentList = joinExpressions(node.getArguments(), indent);
    boolean bareCount = node.getArguments().isEmpty() && "count".equalsIgnoreCase(node.getName().getSuffix());
    if (bareCount) {
        argumentList = "*";
    }
    if (node.isDistinct()) {
        argumentList = "DISTINCT " + argumentList;
    }

    StringBuilder sql = new StringBuilder();
    sql.append(formatQualifiedName(node.getName()))
            .append('(')
            .append(argumentList)
            .append(')');
    node.getFilter().ifPresent(filter -> sql.append(" FILTER ").append(visitFilter(filter, indent)));
    node.getWindow().ifPresent(window -> sql.append(" OVER ").append(visitWindow(window, indent)));
    return sql.toString();
}
/**
 * Translates a function call into a {@link RowExpression} call. The arguments are translated
 * first. A {@link Signature} is then built from the call name's suffix, {@code functionKind},
 * the call's resolved type, and the translated argument types.
 */
@Override
protected RowExpression visitFunctionCall(FunctionCall node, Void context)
{
    List<RowExpression> arguments = node.getArguments().stream()
            .map(argument -> process(argument, context))
            .collect(toImmutableList());
    List<TypeSignature> argumentTypes = arguments.stream()
            .map(argument -> argument.getType().getTypeSignature())
            .collect(toImmutableList());

    Type returnType = types.get(node);
    Signature signature = new Signature(node.getName().getSuffix(), functionKind, returnType.getTypeSignature(), argumentTypes);
    return call(signature, returnType, arguments);
}
/**
 * Plans SELECT DISTINCT. If the select is distinct, the sub-plan is wrapped in a SINGLE-step
 * aggregation that groups by every output symbol of the current root and has no aggregate
 * functions. Otherwise the sub-plan is returned unchanged.
 *
 * @throws IllegalStateException if some ORDER BY term is not among the SELECT outputs
 */
private PlanBuilder distinct(PlanBuilder subPlan, QuerySpecification node, List<FieldOrExpression> outputs, List<FieldOrExpression> orderBy)
{
    if (!node.getSelect().isDistinct()) {
        return subPlan;
    }

    checkState(outputs.containsAll(orderBy), "Expected ORDER BY terms to be in SELECT. Broken analysis");

    AggregationNode aggregation = new AggregationNode(
            idAllocator.getNextId(),
            subPlan.getRoot(),
            subPlan.getRoot().getOutputSymbols(),   // group by every output
            ImmutableMap.<Symbol, FunctionCall>of(), // no aggregations
            ImmutableMap.<Symbol, Signature>of(),
            ImmutableMap.<Symbol, Symbol>of(),       // no masks
            AggregationNode.Step.SINGLE,
            Optional.empty(),                        // no sample weight
            1.0,
            Optional.empty());                       // no hash symbol
    return new PlanBuilder(subPlan.getTranslations(), aggregation, subPlan.getSampleWeight());
}
private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode node, PhysicalOperation source) { int outputChannel = 0; ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); List<AccumulatorFactory> accumulatorFactories = new ArrayList<>(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); accumulatorFactories.add(buildAccumulatorFactory(source, node.getFunctions().get(symbol), entry.getValue(), node.getMasks().get(entry.getKey()), Optional.<Integer>empty(), node.getSampleWeight(), node.getConfidence())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } OperatorFactory operatorFactory = new AggregationOperatorFactory(operatorId, node.getId(), node.getStep(), accumulatorFactories); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); }
@Override public Void visitAggregation(AggregationNode node, Void context) { PlanNode source = node.getSource(); source.accept(this, context); // visit child verifyUniqueId(node); Set<Symbol> inputs = ImmutableSet.copyOf(source.getOutputSymbols()); checkDependencies(inputs, node.getGroupBy(), "Invalid node. Group by symbols (%s) not in source plan output (%s)", node.getGroupBy(), node.getSource().getOutputSymbols()); if (node.getSampleWeight().isPresent()) { checkArgument(inputs.contains(node.getSampleWeight().get()), "Invalid node. Sample weight symbol (%s) is not in source plan output (%s)", node.getSampleWeight().get(), node.getSource().getOutputSymbols()); } for (FunctionCall call : node.getAggregations().values()) { Set<Symbol> dependencies = DependencyExtractor.extractUnique(call); checkDependencies(inputs, dependencies, "Invalid node. Aggregation dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } return null; }
/**
 * Evaluates an ARRAY[...] constructor. If every element evaluates to a value, the result is a
 * built block. If any element stays an unevaluated {@link Expression}, evaluation falls back to
 * {@code visitFunctionCall} on an array-constructor call over the original element expressions.
 */
@Override
protected Object visitArrayConstructor(ArrayConstructor node, Object context)
{
    ArrayType arrayType = (ArrayType) expressionTypes.get(node);
    Type elementType = arrayType.getElementType();
    BlockBuilder elements = elementType.createBlockBuilder(new BlockBuilderStatus(), node.getValues().size());

    for (Expression element : node.getValues()) {
        Object evaluated = process(element, context);
        if (evaluated instanceof Expression) {
            FunctionCall constructorCall = new FunctionCall(QualifiedName.of(ArrayConstructor.ARRAY_CONSTRUCTOR), node.getValues());
            return visitFunctionCall(constructorCall, context);
        }
        writeNativeValue(elementType, elements, evaluated);
    }
    return elements.build();
}
/**
 * Rewrites CURRENT_DATE / CURRENT_TIME / LOCALTIME / CURRENT_TIMESTAMP / LOCALTIMESTAMP into the
 * matching zero-argument function call.
 *
 * @throws UnsupportedOperationException if a precision is given or the type is not handled
 */
@Override
public Expression rewriteCurrentTime(CurrentTime node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    if (node.getPrecision() != null) {
        throw new UnsupportedOperationException("not yet implemented: non-default precision");
    }

    String functionName;
    switch (node.getType()) {
        case DATE:
            functionName = "current_date";
            break;
        case TIME:
            functionName = "current_time";
            break;
        case LOCALTIME:
            functionName = "localtime";
            break;
        case TIMESTAMP:
            functionName = "current_timestamp";
            break;
        case LOCALTIMESTAMP:
            functionName = "localtimestamp";
            break;
        default:
            throw new UnsupportedOperationException("not yet implemented: " + node.getType());
    }
    return new FunctionCall(new QualifiedName(functionName), ImmutableList.<Expression>of());
}
/**
 * Returns true when the call is a one-argument {@code count} returning BIGINT and its argument is
 * a non-null literal. The literal may appear directly, or as a symbol that {@code projectNode}
 * assigns a non-null literal.
 */
public static boolean isCountConstant(ProjectNode projectNode, FunctionCall functionCall, Signature signature)
{
    boolean singleArgumentBigintCount = "count".equals(signature.getName())
            && signature.getArgumentTypes().size() == 1
            && signature.getReturnType().getBase().equals(StandardTypes.BIGINT);
    if (!singleArgumentBigintCount) {
        return false;
    }

    Expression argument = functionCall.getArguments().get(0);
    Expression candidate = argument;
    if (argument instanceof QualifiedNameReference) {
        // follow the symbol to the expression the projection assigns to it
        QualifiedName name = ((QualifiedNameReference) argument).getName();
        candidate = projectNode.getAssignments().get(Symbol.fromQualifiedName(name));
    }
    return candidate instanceof Literal && !(candidate instanceof NullLiteral);
}
/**
 * Rewrites a BERNOULLI sample into a filter {@code rand() < sampleRatio}. POISSONIZED and SYSTEM
 * samples get the default rewrite.
 *
 * @throws UnsupportedOperationException for any other sample type
 */
@Override
public PlanNode visitSample(SampleNode node, RewriteContext<Void> context)
{
    SampleNode.Type sampleType = node.getSampleType();
    if (sampleType == SampleNode.Type.POISSONIZED || sampleType == SampleNode.Type.SYSTEM) {
        return context.defaultRewrite(node);
    }
    if (sampleType != SampleNode.Type.BERNOULLI) {
        throw new UnsupportedOperationException("not yet implemented");
    }

    PlanNode source = context.rewrite(node.getSource());
    Expression randomValue = new FunctionCall(QualifiedName.of("rand"), ImmutableList.<Expression>of());
    Expression ratio = new DoubleLiteral(Double.toString(node.getSampleRatio()));
    ComparisonExpression keepRow = new ComparisonExpression(ComparisonExpression.Type.LESS_THAN, randomValue, ratio);
    return new FilterNode(node.getId(), source, keepRow);
}
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<Symbol>> context) { // optimize if and only if // all aggregation functions have a single common distinct mask symbol // AND all aggregation functions have mask Set<Symbol> masks = ImmutableSet.copyOf(node.getMasks().values()); if (masks.size() != 1 || node.getMasks().size() != node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); } PlanNode source = context.rewrite(node.getSource(), Optional.of(Iterables.getOnlyElement(masks))); Map<Symbol, FunctionCall> aggregations = ImmutableMap.copyOf(Maps.transformValues(node.getAggregations(), call -> new FunctionCall(call.getName(), call.getWindow(), false, call.getArguments()))); return new AggregationNode(idAllocator.getNextId(), source, node.getGroupBy(), aggregations, node.getFunctions(), Collections.emptyMap(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()); }
/**
 * Emulates {@code quarter(d)} as {@code floor((month(d) - 1) / 3) + 1}.
 * All arguments of {@code node} are passed through to {@code month()}.
 */
protected static String processFuncQuarter(Formatter formatter, FunctionCall node)
{
    FunctionCall monthOf = new FunctionCall(new QualifiedName("month"), node.getArguments());
    ArithmeticExpression zeroBasedMonth = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, monthOf, new LongLiteral("1"));
    ArithmeticExpression quarterIndex = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, zeroBasedMonth, new LongLiteral("3"));
    FunctionCall wholeQuarters = new FunctionCall(new QualifiedName("floor"), Arrays.asList(quarterIndex));
    ArithmeticExpression oneBasedQuarter = new ArithmeticExpression(ArithmeticExpression.Type.ADD, wholeQuarters, new LongLiteral("1"));
    return formatter.process(oneBasedQuarter, null);
}
/**
 * Emulates {@code sinh(x)} as {@code (exp(x) - exp(-x)) / 2}.
 * x is the first argument; the first {@code exp()} receives the call's full argument list.
 */
protected static String processFuncSinh(Formatter formatter, FunctionCall node)
{
    NegativeExpression minusX = new NegativeExpression(node.getArguments().get(0));
    FunctionCall expX = new FunctionCall(new QualifiedName("exp"), node.getArguments());
    FunctionCall expMinusX = new FunctionCall(new QualifiedName("exp"), Arrays.asList(minusX));
    ArithmeticExpression numerator = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, expX, expMinusX);
    ArithmeticExpression half = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, numerator, new LongLiteral("2"));
    return formatter.process(half, null);
}
/**
 * Emulates {@code cosh(x)} as {@code (exp(x) + exp(-x)) / 2}.
 * x is the first argument; the first {@code exp()} receives the call's full argument list.
 */
protected static String processFuncCosh(Formatter formatter, FunctionCall node)
{
    NegativeExpression minusX = new NegativeExpression(node.getArguments().get(0));
    FunctionCall expX = new FunctionCall(new QualifiedName("exp"), node.getArguments());
    FunctionCall expMinusX = new FunctionCall(new QualifiedName("exp"), Arrays.asList(minusX));
    ArithmeticExpression numerator = new ArithmeticExpression(ArithmeticExpression.Type.ADD, expX, expMinusX);
    ArithmeticExpression half = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, numerator, new LongLiteral("2"));
    return formatter.process(half, null);
}
/**
 * Emulates {@code tanh(x)} as {@code (exp(x) - exp(-x)) / (exp(x) + exp(-x))}.
 *
 * <p>TODO: an Access-specific variant that avoided iif()-based division guards used to live here
 * as commented-out code. It was removed and can be recovered from version control if needed.
 *
 * @param formatter formatter used to render the rewritten expression
 * @param node the tanh call; x is its first argument, and the first {@code exp()} receives the
 *        call's full argument list
 * @param dbType target database; currently unused, so every dialect gets the same formula
 * @return the SQL text of the rewritten expression
 */
protected static String processFuncTanh(Formatter formatter, FunctionCall node, DBType dbType)
{
    NegativeExpression negExp = new NegativeExpression(node.getArguments().get(0));
    FunctionCall termA = new FunctionCall(new QualifiedName("exp"), node.getArguments());
    FunctionCall termB = new FunctionCall(new QualifiedName("exp"), Arrays.asList(negExp));
    ArithmeticExpression subtract = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, termA, termB);
    ArithmeticExpression add = new ArithmeticExpression(ArithmeticExpression.Type.ADD, termA, termB);
    ArithmeticExpression divide = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, subtract, add);
    return formatter.process(divide, null);
}
/**
 * Emulates {@code asin(x)} as {@code atan(x / sqrt(1 - power(x, 2)))}.
 *
 * <p>The identity only holds for |x| &lt; 1. At x = ±1 the denominator is zero, so the result
 * depends on how the target dialect handles division by zero. For |x| &gt; 1, sqrt receives a
 * negative value.
 *
 * <p>TODO: an Access-specific variant that avoided iif()-based division guards used to live here
 * as commented-out code. It was removed and can be recovered from version control if needed.
 *
 * @param formatter formatter used to render the rewritten expression
 * @param node the asin call; x is its first argument
 * @param dbType target database; currently unused, so every dialect gets the same formula
 * @return the SQL text of the rewritten expression
 */
protected static String processFuncAsin(Formatter formatter, FunctionCall node, DBType dbType)
{
    FunctionCall xx = new FunctionCall(new QualifiedName("power"), Arrays.asList(node.getArguments().get(0), new LongLiteral("2")));
    ArithmeticExpression subtract = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, new LongLiteral("1"), xx);
    FunctionCall sqrt = new FunctionCall(new QualifiedName("sqrt"), Arrays.asList(subtract));
    ArithmeticExpression divide = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, node.getArguments().get(0), sqrt);
    FunctionCall atan = new FunctionCall(new QualifiedName("atan"), Arrays.asList(divide));
    return formatter.process(atan, null);
}
/**
 * Emulates {@code asinh(z)} as {@code ln(z + sqrt(z * z + 1))}, where z is the first argument.
 */
protected static String processFuncAsinh(Formatter formatter, FunctionCall node)
{
    Expression z = node.getArguments().get(0);
    ArithmeticExpression zSquared = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, z, z);
    ArithmeticExpression radicand = new ArithmeticExpression(ArithmeticExpression.Type.ADD, zSquared, new LongLiteral("1"));
    FunctionCall root = new FunctionCall(new QualifiedName("sqrt"), Arrays.asList(radicand));
    ArithmeticExpression logArgument = new ArithmeticExpression(ArithmeticExpression.Type.ADD, z, root);
    FunctionCall result = new FunctionCall(new QualifiedName("ln"), Arrays.asList(logArgument));
    return formatter.process(result, null);
}
/**
 * Emulates {@code acosh(z)} as {@code ln(z + sqrt(z + 1) * sqrt(z - 1))}, where z is the first argument.
 */
protected static String processFuncAcosh(Formatter formatter, FunctionCall node)
{
    Expression z = node.getArguments().get(0);
    ArithmeticExpression zPlusOne = new ArithmeticExpression(ArithmeticExpression.Type.ADD, z, new LongLiteral("1"));
    FunctionCall rootPlus = new FunctionCall(new QualifiedName("sqrt"), Arrays.asList(zPlusOne));
    ArithmeticExpression zMinusOne = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, z, new LongLiteral("1"));
    FunctionCall rootMinus = new FunctionCall(new QualifiedName("sqrt"), Arrays.asList(zMinusOne));
    ArithmeticExpression rootProduct = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, rootPlus, rootMinus);
    ArithmeticExpression logArgument = new ArithmeticExpression(ArithmeticExpression.Type.ADD, z, rootProduct);
    FunctionCall result = new FunctionCall(new QualifiedName("ln"), Arrays.asList(logArgument));
    return formatter.process(result, null);
}
/**
 * Emulates {@code atanh(z)} as {@code 0.5 * ln((1 + z) / (1 - z))}, where z is the first argument.
 */
protected static String processFuncAtanh(Formatter formatter, FunctionCall node)
{
    Expression z = node.getArguments().get(0);
    ArithmeticExpression numerator = new ArithmeticExpression(ArithmeticExpression.Type.ADD, new LongLiteral("1"), z);
    ArithmeticExpression denominator = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, new LongLiteral("1"), z);
    ArithmeticExpression ratio = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, numerator, denominator);
    FunctionCall logOfRatio = new FunctionCall(new QualifiedName("ln"), Arrays.asList(ratio));
    ArithmeticExpression halved = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, new DoubleLiteral("0.5"), logOfRatio);
    return formatter.process(halved, null);
}
/**
 * Emulates {@code power(a, b)} as {@code exp(b * ln(a))}.
 *
 * <p>NOTE(review): ln() is only defined for a &gt; 0. This rewrite therefore does not reproduce
 * power() for a zero or negative base, e.g. power(-2, 2).
 */
protected static String processFuncPower(Formatter formatter, FunctionCall node)
{
    Expression base = node.getArguments().get(0);
    Expression exponent = node.getArguments().get(1);
    FunctionCall logBase = new FunctionCall(new QualifiedName("ln"), Arrays.asList(base));
    ArithmeticExpression scaledLog = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, exponent, logBase);
    FunctionCall result = new FunctionCall(new QualifiedName("exp"), Arrays.asList(scaledLog));
    return formatter.process(result, null);
}
/**
 * Emulates atan2 with the half-angle identity
 * {@code 2 * atan((sqrt(a^2 + b^2) - a) / b)}, where a is the first argument and b the second.
 *
 * <p>NOTE(review): this equals the mathematical atan2(b, a), so the FIRST argument is treated as
 * the x-coordinate. Many SQL dialects define atan2(y, x) instead; confirm which order callers
 * expect. The formula also divides by b, so b = 0 yields a division by zero.
 */
protected static String processFuncAtan2(Formatter formatter, FunctionCall node)
{
    Expression a = node.getArguments().get(0);
    Expression b = node.getArguments().get(1);

    FunctionCall aSquared = new FunctionCall(new QualifiedName("power"), Arrays.asList(a, new LongLiteral("2")));
    FunctionCall bSquared = new FunctionCall(new QualifiedName("power"), Arrays.asList(b, new LongLiteral("2")));
    ArithmeticExpression sumOfSquares = new ArithmeticExpression(ArithmeticExpression.Type.ADD, aSquared, bSquared);
    FunctionCall radius = new FunctionCall(new QualifiedName("sqrt"), Arrays.asList(sumOfSquares));

    ArithmeticExpression radiusMinusA = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, radius, a);
    ArithmeticExpression halfAngleTangent = new ArithmeticExpression(ArithmeticExpression.Type.DIVIDE, radiusMinusA, b);
    FunctionCall halfAngle = new FunctionCall(new QualifiedName("atan"), Arrays.asList(halfAngleTangent));
    ArithmeticExpression angle = new ArithmeticExpression(ArithmeticExpression.Type.MULTIPLY, new DoubleLiteral("2"), halfAngle);
    return formatter.process(angle, null);
}
/**
 * Emulates {@code nullifzero(x)} as {@code CASE WHEN x = 0 THEN NULL ELSE x END}.
 */
protected static String processFuncNullifzero(Formatter formatter, FunctionCall node)
{
    Expression value = node.getArguments().get(0);
    ComparisonExpression isZero = new ComparisonExpression(ComparisonExpression.Type.EQUAL, value, new LongLiteral("0"));

    List<WhenClause> whenClauses = new ArrayList<WhenClause>();
    whenClauses.add(new WhenClause(isZero, new NullLiteral()));

    SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses, value);
    return formatter.process(caseExpression, null);
}
protected static Expression processFuncLast(ComparisonExpression node) { System.out.println("Processing last()"); Expression rightNode = node.getRight(); Expression leftNode = node.getLeft(); FunctionCall last = (FunctionCall) rightNode; // # of arguments are already checked outside 1 or 2 String number = last.getArguments().get(0).toString(); String format = "DAY"; // default if (last.getArguments().size() == 2) { format = last.getArguments().get(1).toString().replaceAll("\"", ""); } IntervalLiteral.Sign sign; if (number.startsWith("-")) { sign = IntervalLiteral.Sign.NEGATIVE; number = number.substring(1); } else { sign = IntervalLiteral.Sign.POSITIVE; } CurrentTime cTime = new CurrentTime(CurrentTime.Type.DATE); IntervalLiteral interval = new IntervalLiteral(number, sign, format); ArithmeticExpression arithmOp = new ArithmeticExpression(ArithmeticExpression.Type.SUBTRACT, cTime, interval); BetweenPredicate bPredicate = new BetweenPredicate(leftNode, arithmOp, cTime); return bPredicate; }
@Override protected String visitFunctionCall(FunctionCall node, Void context) { StringBuilder builder = new StringBuilder(); String functionName = VeroFunctions.getFunctionName(node); //int numArguments = node.getArguments().size(); String arguments = joinExpressions(node.getArguments()); if (node.getArguments().isEmpty() && "count".equalsIgnoreCase(node.getName().getSuffix())) { arguments = "*"; } if (node.isDistinct()) { arguments = "DISTINCT " + arguments; } if (functionName.equals("quarter")) { builder.append(processFuncQuarter(this, node)); } else if (functionName.equals("concat")) { builder.append(joinExpressions(" || ", node.getArguments())); } else { // use super return super.visitFunctionCall(node, context); } if (node.getWindow().isPresent()) { builder.append(" OVER ").append(visitWindow(node.getWindow().get(), null)); } return builder.toString(); }
/** * Extracts the literal value from an expression (if expression is supported) * @param expression * @param state * @return a Long, Boolean, Double or String object */ private Object getLiteralValue(Expression expression, QueryState state){ if(expression instanceof LongLiteral) return ((LongLiteral)expression).getValue(); else if(expression instanceof BooleanLiteral) return ((BooleanLiteral)expression).getValue(); else if(expression instanceof DoubleLiteral) return ((DoubleLiteral)expression).getValue(); else if(expression instanceof StringLiteral) return ((StringLiteral)expression).getValue(); else if(expression instanceof ArithmeticUnaryExpression){ ArithmeticUnaryExpression unaryExp = (ArithmeticUnaryExpression)expression; Sign sign = unaryExp.getSign(); Number num = (Number)getLiteralValue(unaryExp.getValue(), state); if(sign == Sign.MINUS){ if(num instanceof Long) return -1*num.longValue(); else if(num instanceof Double) return -1*num.doubleValue(); else { state.addException("Unsupported numeric literal expression encountered : "+num.getClass()); return null; } } return num; } else if(expression instanceof FunctionCall){ FunctionCall fc = (FunctionCall)expression; if(fc.getName().toString().equals("now")) return new Date(); else state.addException("Function '"+fc.getName()+"' is not supported"); }else if(expression instanceof CurrentTime){ CurrentTime ct = (CurrentTime)expression; if(ct.getType() == CurrentTime.Type.DATE) return new LocalDate().toDate(); else if(ct.getType() == CurrentTime.Type.TIME) return new Date(new LocalTime(DateTimeZone.UTC).getMillisOfDay()); else if(ct.getType() == CurrentTime.Type.TIMESTAMP) return new Date(); else if(ct.getType() == CurrentTime.Type.LOCALTIME) return new Date(new LocalTime(DateTimeZone.UTC).getMillisOfDay()); else if(ct.getType() == CurrentTime.Type.LOCALTIMESTAMP) return new Date(); else state.addException("CurrentTime function '"+ct.getType()+"' is not supported"); }else state.addException("Literal type 
"+expression.getClass().getSimpleName()+" is not supported"); return null; }
/**
 * Prints an aggregation node header with its step (when not SINGLE), group-by keys and sample
 * weight, then one line per aggregation (with its mask, if any), then the children.
 */
@Override
public Void visitAggregation(AggregationNode node, Integer indent)
{
    String type = node.getStep() == AggregationNode.Step.SINGLE ? "" : format("(%s)", node.getStep().toString());
    String key = node.getGroupBy().isEmpty() ? "" : node.getGroupBy().toString();
    String sampleWeight = node.getSampleWeight().isPresent()
            ? format("[sampleWeight = %s]", node.getSampleWeight().get())
            : "";

    print(indent, "- Aggregate%s%s%s => [%s]", type, key, sampleWeight, formatOutputs(node.getOutputSymbols()));

    for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        FunctionCall call = aggregation.getValue();
        if (!node.getMasks().containsKey(output)) {
            print(indent + 2, "%s := %s", output, call);
            continue;
        }
        print(indent + 2, "%s := %s (mask = %s)", output, call, node.getMasks().get(output));
    }
    return processChildren(node, indent + 1);
}
/**
 * Allocates a new symbol, using the expression to derive a name hint. A name reference or a
 * function call contributes its name's suffix; anything else uses {@code "expr"}.
 */
public Symbol newSymbol(Expression expression, Type type, String suffix)
{
    String nameHint;
    if (expression instanceof QualifiedNameReference) {
        nameHint = ((QualifiedNameReference) expression).getName().getSuffix();
    }
    else if (expression instanceof FunctionCall) {
        nameHint = ((FunctionCall) expression).getName().getSuffix();
    }
    else {
        nameHint = "expr";
    }
    return newSymbol(nameHint, type, suffix);
}
@Override protected Void visitFunctionCall(FunctionCall node, AtomicBoolean deterministic) { // TODO: total hack to figure out if a function is deterministic. martint should fix this when he refactors the planning code if (node.getName().equals(new QualifiedName("rand")) || node.getName().equals(new QualifiedName("random"))) { deterministic.set(false); } return super.visitFunctionCall(node, deterministic); }
/**
 * Binds an aggregate function implementation to input channels from the source's layout.
 * Every call argument is assumed to be a {@link QualifiedNameReference}. The mask channel comes
 * from {@code mask} when non-null, otherwise it is {@code defaultMaskChannel}.
 */
private AccumulatorFactory buildAccumulatorFactory(
        PhysicalOperation source,
        Signature function,
        FunctionCall call,
        @Nullable Symbol mask,
        Optional<Integer> defaultMaskChannel,
        Optional<Symbol> sampleWeight,
        double confidence)
{
    List<Integer> argumentChannels = new ArrayList<>();
    for (Expression argument : call.getArguments()) {
        Symbol argumentSymbol = Symbol.fromQualifiedName(((QualifiedNameReference) argument).getName());
        argumentChannels.add(source.getLayout().get(argumentSymbol));
    }

    Optional<Integer> maskChannel = (mask == null) ? defaultMaskChannel : Optional.of(source.getLayout().get(mask));
    Optional<Integer> sampleWeightChannel = sampleWeight.isPresent()
            ? Optional.of(source.getLayout().get(sampleWeight.get()))
            : Optional.<Integer>empty();

    return metadata.getFunctionRegistry()
            .getAggregateFunctionImplementation(function)
            .bind(argumentChannels, maskChannel, sampleWeightChannel, confidence);
}
/**
 * Wraps {@code node} in a SINGLE-step aggregation that groups by all of its output symbols and
 * has no aggregate functions. This removes duplicate rows.
 */
private PlanNode distinct(PlanNode node)
{
    return new AggregationNode(
            idAllocator.getNextId(),
            node,
            node.getOutputSymbols(),                 // group by every output
            ImmutableMap.<Symbol, FunctionCall>of(), // no aggregations
            ImmutableMap.<Symbol, Signature>of(),
            ImmutableMap.<Symbol, Symbol>of(),       // no masks
            AggregationNode.Step.SINGLE,
            Optional.empty(),                        // no sample weight
            1.0,
            Optional.empty());                       // no hash symbol
}
@Override public Void visitWindow(WindowNode node, Void context) { PlanNode source = node.getSource(); source.accept(this, context); // visit child verifyUniqueId(node); Set<Symbol> inputs = ImmutableSet.copyOf(source.getOutputSymbols()); checkDependencies(inputs, node.getPartitionBy(), "Invalid node. Partition by symbols (%s) not in source plan output (%s)", node.getPartitionBy(), node.getSource().getOutputSymbols()); checkDependencies(inputs, node.getOrderBy(), "Invalid node. Order by symbols (%s) not in source plan output (%s)", node.getOrderBy(), node.getSource().getOutputSymbols()); ImmutableList.Builder<Symbol> bounds = ImmutableList.builder(); if (node.getFrame().getStartValue().isPresent()) { bounds.add(node.getFrame().getStartValue().get()); } if (node.getFrame().getEndValue().isPresent()) { bounds.add(node.getFrame().getEndValue().get()); } checkDependencies(inputs, bounds.build(), "Invalid node. Frame bounds (%s) not in source plan output (%s)", bounds.build(), node.getSource().getOutputSymbols()); for (FunctionCall call : node.getWindowFunctions().values()) { Set<Symbol> dependencies = DependencyExtractor.extractUnique(call); checkDependencies(inputs, dependencies, "Invalid node. Window function dependencies (%s) not in source plan output (%s)", dependencies, node.getSource().getOutputSymbols()); } return null; }
/**
 * Builds an expression that fails with the given exception when evaluated:
 * {@code CAST(fail(json_parse('<failure info JSON>')) AS <type>)}.
 *
 * @throws NullPointerException if {@code exception} is null
 */
@VisibleForTesting
@NotNull
public static Expression createFailureFunction(RuntimeException exception, Type type)
{
    requireNonNull(exception, "Exception is null");

    FailureInfo failure = Failures.toFailure(exception).toFailureInfo();
    String failureJson = JsonCodec.jsonCodec(FailureInfo.class).toJson(failure);

    FunctionCall parsedFailure = new FunctionCall(QualifiedName.of("json_parse"), ImmutableList.of(new StringLiteral(failureJson)));
    FunctionCall failCall = new FunctionCall(QualifiedName.of("fail"), ImmutableList.of(parsedFailure));
    return new Cast(failCall, type.getTypeSignature().toString());
}
/**
 * Rewrites {@code EXTRACT(field FROM expr)} into the equivalent one-argument scalar function
 * applied to the rewritten operand, e.g. {@code EXTRACT(DOW FROM x)} becomes {@code day_of_week(x)}.
 *
 * @throws UnsupportedOperationException for fields without a mapping
 */
@Override
public Expression rewriteExtract(Extract node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
    Expression value = treeRewriter.rewrite(node.getExpression(), context);

    String functionName;
    switch (node.getField()) {
        case YEAR:
            functionName = "year";
            break;
        case QUARTER:
            functionName = "quarter";
            break;
        case MONTH:
            functionName = "month";
            break;
        case WEEK:
            functionName = "week";
            break;
        case DAY:
        case DAY_OF_MONTH:
            functionName = "day";
            break;
        case DAY_OF_WEEK:
        case DOW:
            functionName = "day_of_week";
            break;
        case DAY_OF_YEAR:
        case DOY:
            functionName = "day_of_year";
            break;
        case YEAR_OF_WEEK:
        case YOW:
            functionName = "year_of_week";
            break;
        case HOUR:
            functionName = "hour";
            break;
        case MINUTE:
            functionName = "minute";
            break;
        case SECOND:
            functionName = "second";
            break;
        case TIMEZONE_MINUTE:
            functionName = "timezone_minute";
            break;
        case TIMEZONE_HOUR:
            functionName = "timezone_hour";
            break;
        default:
            throw new UnsupportedOperationException("not yet implemented: " + node.getField());
    }
    return new FunctionCall(new QualifiedName(functionName), ImmutableList.of(value));
}
/**
 * Replaces {@code count(<constant>)} with the zero-argument {@code count()} when the rewritten
 * source is a projection. The constant test is {@code isCountConstant}: a non-null literal,
 * either directly or via a projected symbol. All other aggregations are copied unchanged.
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
    Map<Symbol, FunctionCall> aggregations = new LinkedHashMap<>(node.getAggregations());
    Map<Symbol, Signature> functions = new LinkedHashMap<>(node.getFunctions());
    PlanNode source = context.rewrite(node.getSource());

    if (source instanceof ProjectNode) {
        ProjectNode projectNode = (ProjectNode) source;
        node.getAggregations().forEach((symbol, functionCall) -> {
            Signature signature = node.getFunctions().get(symbol);
            if (isCountConstant(projectNode, functionCall, signature)) {
                aggregations.put(symbol, new FunctionCall(functionCall.getName(), functionCall.isDistinct(), ImmutableList.<Expression>of()));
                functions.put(symbol, new Signature("count", AGGREGATE, StandardTypes.BIGINT));
            }
        });
    }

    return new AggregationNode(
            node.getId(),
            source,
            node.getGroupBy(),
            aggregations,
            functions,
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
/**
 * Creates an aggregation plan node. The group-by list and the aggregation, function and mask
 * maps are copied into immutable collections. Every key of {@code masks} must also be a key of
 * {@code aggregations}. {@code confidence} must lie in [0, 1]. {@code source}, {@code step} and
 * {@code hashSymbol} are stored without null checks.
 *
 * @throws NullPointerException if groupByKeys, aggregations, functions, masks or sampleWeight is null
 * @throws IllegalArgumentException if a mask has no matching aggregation or confidence is out of range
 */
@JsonCreator
public AggregationNode(
        @JsonProperty("id") PlanNodeId id,
        @JsonProperty("source") PlanNode source,
        @JsonProperty("groupBy") List<Symbol> groupByKeys,
        @JsonProperty("aggregations") Map<Symbol, FunctionCall> aggregations,
        @JsonProperty("functions") Map<Symbol, Signature> functions,
        @JsonProperty("masks") Map<Symbol, Symbol> masks,
        @JsonProperty("step") Step step,
        @JsonProperty("sampleWeight") Optional<Symbol> sampleWeight,
        @JsonProperty("confidence") double confidence,
        @JsonProperty("hashSymbol") Optional<Symbol> hashSymbol)
{
    super(id);

    this.source = source;
    this.groupByKeys = ImmutableList.copyOf(requireNonNull(groupByKeys, "groupByKeys is null"));
    this.aggregations = ImmutableMap.copyOf(requireNonNull(aggregations, "aggregations is null"));
    this.functions = ImmutableMap.copyOf(requireNonNull(functions, "functions is null"));
    this.masks = ImmutableMap.copyOf(requireNonNull(masks, "masks is null"));
    masks.keySet().forEach(maskedSymbol ->
            checkArgument(aggregations.containsKey(maskedSymbol), "mask does not match any aggregations"));

    this.step = step;
    this.sampleWeight = requireNonNull(sampleWeight, "sampleWeight is null");
    checkArgument(0 <= confidence && confidence <= 1, "confidence must be in [0, 1]");
    this.confidence = confidence;
    this.hashSymbol = hashSymbol;
}
/**
 * Collects non-windowed aggregation calls without descending into their arguments. Every other
 * call is traversed normally.
 */
@Override
protected Void visitFunctionCall(FunctionCall node, Void context)
{
    boolean plainAggregate = metadata.isAggregationFunction(node.getName()) && !node.getWindow().isPresent();
    if (!plainAggregate) {
        return super.visitFunctionCall(node, null);
    }
    aggregates.add(node);
    return null;
}
/**
 * Collects calls that have a window specification without descending into them. Every other
 * call is traversed normally.
 */
@Override
protected Void visitFunctionCall(FunctionCall node, Void context)
{
    if (!node.getWindow().isPresent()) {
        return super.visitFunctionCall(node, null);
    }
    windowFunctions.add(node);
    return null;
}
/**
 * For a non-windowed aggregation call, rejects nested aggregations and nested window functions
 * in its arguments, then returns true. For any other call, returns true only if the window
 * specification (if present) and every argument are accepted by this visitor.
 *
 * @throws SemanticException NESTED_AGGREGATION or NESTED_WINDOW for nested constructs
 */
@Override
protected Boolean visitFunctionCall(FunctionCall node, Void context)
{
    boolean hasWindow = node.getWindow().isPresent();
    if (hasWindow || !metadata.isAggregationFunction(node.getName())) {
        if (hasWindow && !process(node.getWindow().get(), context)) {
            return false;
        }
        return node.getArguments().stream().allMatch(argument -> process(argument, context));
    }

    AggregateExtractor nestedAggregates = new AggregateExtractor(metadata);
    WindowFunctionExtractor nestedWindows = new WindowFunctionExtractor();
    for (Expression argument : node.getArguments()) {
        nestedAggregates.process(argument, null);
        nestedWindows.process(argument, null);
    }

    if (!nestedAggregates.getAggregates().isEmpty()) {
        throw new SemanticException(NESTED_AGGREGATION, node, "Cannot nest aggregations inside aggregation '%s': %s", node.getName(), nestedAggregates.getAggregates());
    }
    if (!nestedWindows.getWindowFunctions().isEmpty()) {
        throw new SemanticException(NESTED_WINDOW, node, "Cannot nest window functions inside aggregation '%s': %s", node.getName(), nestedWindows.getWindowFunctions());
    }
    return true;
}
/**
 * Rejects a clause expression that contains any aggregation or window function call.
 *
 * @throws SemanticException CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS listing the offending calls
 */
static void verifyNoAggregatesOrWindowFunctions(Metadata metadata, Expression predicate, String clause)
{
    AggregateExtractor aggregates = new AggregateExtractor(metadata);
    aggregates.process(predicate, null);

    WindowFunctionExtractor windowFunctions = new WindowFunctionExtractor();
    windowFunctions.process(predicate, null);

    ImmutableList<FunctionCall> offending = ImmutableList.<FunctionCall>builder()
            .addAll(aggregates.getAggregates())
            .addAll(windowFunctions.getWindowFunctions())
            .build();
    if (offending.isEmpty()) {
        return;
    }
    throw new SemanticException(CANNOT_HAVE_AGGREGATIONS_OR_WINDOWS, predicate, "%s clause cannot contain aggregations or window functions: %s", clause, offending);
}