diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index aa4fcc200d..3f9f9e2fbc 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -66,6 +66,7 @@ public R accept(AbstractNodeVisitor nodeVisitor, C context) { } public enum TrendlineType { - SMA + SMA, + WMA } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 7bf10964cf..69a5dc1038 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -6,6 +6,7 @@ package org.opensearch.sql.planner.physical; import static java.time.temporal.ChronoUnit.MILLIS; +import static java.util.stream.Collectors.toList; import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; @@ -14,17 +15,23 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Queue; +import java.util.function.Function; +import java.util.stream.IntStream; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; @@ 
-106,44 +113,43 @@ private Map consumeInputTuple(ExprValue inputValue) { private static TrendlineAccumulator createAccumulator( Pair computation) { - // Add a switch statement based on computation type to choose the accumulator when more - // types of computations are supported. - return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + return switch (computation.getKey().getComputationType()) { + case SMA -> new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + case WMA -> new WeightedMovingAverageAccumulator( + computation.getKey(), computation.getValue()); + }; } /** Maintains stateful information for calculating the trendline. */ - private interface TrendlineAccumulator { - void accumulate(ExprValue value); + protected abstract static class TrendlineAccumulator { - ExprValue calculate(); + protected final LiteralExpression dataPointsNeeded; - static ArithmeticEvaluator getEvaluator(ExprCoreType type) { - switch (type) { - case DOUBLE: - return NumericArithmeticEvaluator.INSTANCE; - case DATE: - return DateArithmeticEvaluator.INSTANCE; - case TIME: - return TimeArithmeticEvaluator.INSTANCE; - case TIMESTAMP: - return TimestampArithmeticEvaluator.INSTANCE; + protected final Queue receivedValues; + + private TrendlineAccumulator(Trendline.TrendlineComputation config) { + Integer numberOfDataPoints = config.getNumberOfDataPoints(); + if (numberOfDataPoints <= 0) { + throw new SemanticCheckException( + String.format("Invalid dataPoints [%d] value.", numberOfDataPoints)); } - throw new IllegalArgumentException( - String.format("Invalid type %s used for moving average.", type.typeName())); + this.dataPointsNeeded = DSL.literal(numberOfDataPoints); + this.receivedValues = EvictingQueue.create(numberOfDataPoints); } + + abstract void accumulate(ExprValue value); + + abstract ExprValue calculate(); } - private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { - private final 
LiteralExpression dataPointsNeeded; - private final EvictingQueue receivedValues; + private static class SimpleMovingAverageAccumulator extends TrendlineAccumulator { private final ArithmeticEvaluator evaluator; private Expression runningTotal = null; public SimpleMovingAverageAccumulator( Trendline.TrendlineComputation computation, ExprCoreType type) { - dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); - receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); - evaluator = TrendlineAccumulator.getEvaluator(type); + super(computation); + evaluator = getEvaluator(type); } @Override @@ -185,133 +191,233 @@ public ExprValue calculate() { } return evaluator.evaluate(runningTotal, dataPointsNeeded); } - } - private interface ArithmeticEvaluator { - Expression calculateFirstTotal(List dataPoints); + static ArithmeticEvaluator getEvaluator(ExprCoreType type) { + return switch (type) { + case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> NumericArithmeticEvaluator.INSTANCE; + case DATE -> DateArithmeticEvaluator.INSTANCE; + case TIME -> TimeArithmeticEvaluator.INSTANCE; + case TIMESTAMP -> TimestampArithmeticEvaluator.INSTANCE; + default -> throw new SemanticCheckException( + String.format("Invalid type %s used for moving average.", type.typeName())); + }; + } - Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); + private interface ArithmeticEvaluator { + Expression calculateFirstTotal(List dataPoints); - ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); - } + Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); - private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { - private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); + ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); + } - private NumericArithmeticEvaluator() {} + 
private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { + private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0.0D); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); + private NumericArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0.0D); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); + } + return DSL.literal(total.valueOf().doubleValue()); } - return DSL.literal(total.valueOf().doubleValue()); - } - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) - .valueOf() - .doubleValue()); - } + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) + .valueOf() + .doubleValue()); + } - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); + } } - } - private static class DateArithmeticEvaluator implements ArithmeticEvaluator { - private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); + private static class DateArithmeticEvaluator implements ArithmeticEvaluator { + private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); - 
private DateArithmeticEvaluator() {} + private DateArithmeticEvaluator() {} - @Override - public Expression calculateFirstTotal(List dataPoints) { - return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); - } + @Override + public Expression calculateFirstTotal(List dataPoints) { + return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); + } - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); - } + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); + } - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - final ExprValue timestampResult = - TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); - return ExprValueUtils.dateValue(timestampResult.dateValue()); + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + final ExprValue timestampResult = + TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); + return ExprValueUtils.dateValue(timestampResult.dateValue()); + } } - } - private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { - private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); + private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); - private TimeArithmeticEvaluator() {} + private TimeArithmeticEvaluator() {} - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, 
DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); + } + return DSL.literal(total.valueOf().longValue()); } - return DSL.literal(total.valueOf().longValue()); - } - @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add( - runningTotal, - DSL.subtract( - DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), - DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) - .valueOf()); + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), + DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timeValue( + LocalTime.MIN.plus( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); + } } - @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return ExprValueUtils.timeValue( - LocalTime.MIN.plus( - DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); + private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimestampArithmeticEvaluator INSTANCE = + new TimestampArithmeticEvaluator(); + + private TimestampArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = 
DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(incomingValue.timestampValue().toEpochMilli()), + DSL.literal(evictedValue.timestampValue().toEpochMilli()))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); + } } } - private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { - private static final TimestampArithmeticEvaluator INSTANCE = new TimestampArithmeticEvaluator(); + private static class WeightedMovingAverageAccumulator extends TrendlineAccumulator { - private TimestampArithmeticEvaluator() {} + private final Function, ExprValue> evaluator; + private final List weights; - @Override - public Expression calculateFirstTotal(List dataPoints) { - Expression total = DSL.literal(0); - for (ExprValue dataPoint : dataPoints) { - total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); - } - return DSL.literal(total.valueOf().longValue()); + public WeightedMovingAverageAccumulator( + Trendline.TrendlineComputation computation, ExprCoreType type) { + super(computation); + int dataPoints = computation.getNumberOfDataPoints(); + this.evaluator = getWmaEvaluator(type); + this.weights = + IntStream.rangeClosed(1, dataPoints) + .mapToObj(i -> i / ((dataPoints * (dataPoints + 1)) / 2d)) + .collect(toList()); + } + + Function, ExprValue> getWmaEvaluator(ExprCoreType type) { + return switch (type) { + case INTEGER, SHORT, LONG, FLOAT, DOUBLE -> WMA_NUMERIC_EVALUATOR; + case DATE, TIMESTAMP -> WMA_TIMESTAMP_EVALUATOR; + case TIME -> 
WMA_TIME_EVALUATOR; + default -> throw new SemanticCheckException( + String.format("Invalid type %s used for weighted moving average.", type.typeName())); + }; } @Override - public Expression add( - Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { - return DSL.literal( - DSL.add( - runningTotal, - DSL.subtract( - DSL.literal(incomingValue.timestampValue().toEpochMilli()), - DSL.literal(evictedValue.timestampValue().toEpochMilli()))) - .valueOf()); + public void accumulate(ExprValue value) { + receivedValues.add(value); } @Override - public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { - return ExprValueUtils.timestampValue( - Instant.ofEpochMilli(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); + public ExprValue calculate() { + if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { + return null; + } else if (dataPointsNeeded.valueOf().integerValue() == 1) { + return receivedValues.peek(); + } + return evaluator.apply(receivedValues); + } + + public final Function, ExprValue> WMA_NUMERIC_EVALUATOR = + (receivedValues) -> + new ExprDoubleValue(calculateWmaInDouble(receivedValues, ExprValue::doubleValue)); + ; + + public final Function, ExprValue> WMA_TIMESTAMP_EVALUATOR = + (receivedValues) -> { + Long wmaResult = + Math.round( + calculateWmaInDouble( + receivedValues, i -> (double) (i.timestampValue().toEpochMilli()))); + return ExprValueUtils.timestampValue(Instant.ofEpochMilli((wmaResult))); + }; + + public final Function, ExprValue> WMA_TIME_EVALUATOR = + (receivedValues) -> { + Long wmaResult = + Math.round( + calculateWmaInDouble( + receivedValues, + i -> (double) (MILLIS.between(LocalTime.MIN, i.timeValue())))); + return ExprValueUtils.timeValue(LocalTime.MIN.plus(wmaResult, MILLIS)); + }; + + /** + * Iterates the internal buffer, performs the necessary calculation, and returns the up-to-date + * WMA result as a Double. + * + * @param receivedValues internal 
buffer which stores all values in range. + * @param exprToDouble transformation function to convert incoming values to double for + * calculation. + * @return wma result in Double form. + */ + private Double calculateWmaInDouble( + Queue receivedValues, Function exprToDouble) { + double sum = 0D; + Iterator weightIter = weights.iterator(); + for (ExprValue next : receivedValues) { + sum += exprToDouble.apply(next) * (weightIter.next()); + } + return sum; } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index ef2c2907ce..e6a6d1e045 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -5,42 +5,106 @@ package org.opensearch.sql.planner.physical; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.WMA; +import static org.opensearch.sql.data.model.ExprValueUtils.tupleValue; import com.google.common.collect.ImmutableMap; import java.time.Instant; import java.time.LocalDate; import java.time.LocalTime; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Stream; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import 
org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.SemanticCheckException; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -public class TrendlineOperatorTest { +public class TrendlineOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; + static Stream supportedDataTypes() { + return Stream.of( + Arguments.of(ExprCoreType.SHORT), + Arguments.of(ExprCoreType.INTEGER), + Arguments.of(ExprCoreType.LONG), + Arguments.of(ExprCoreType.FLOAT), + Arguments.of(ExprCoreType.DOUBLE)); + } + + static Stream unSupportedDataTypes() { + return Stream.of(SMA, WMA) + .flatMap( + trendlineType -> + Stream.of( + Arguments.of(trendlineType, ExprCoreType.UNDEFINED), + Arguments.of(trendlineType, ExprCoreType.BYTE), + Arguments.of(trendlineType, ExprCoreType.STRING), + Arguments.of(trendlineType, ExprCoreType.BOOLEAN), + Arguments.of(trendlineType, ExprCoreType.INTERVAL), + Arguments.of(trendlineType, ExprCoreType.IP), + Arguments.of(trendlineType, ExprCoreType.STRUCT), + Arguments.of(trendlineType, ExprCoreType.ARRAY))); + } + + static Stream invalidArguments() { + return Stream.of(SMA, WMA) + .flatMap( + trendlineType -> + Stream.of( + // WMA + Arguments.of( + 2, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.ARRAY, + "DateType - 
Array"), + Arguments.of( + -100, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.INTEGER, + "DataPoints - Negative"), + Arguments.of( + 0, + AstDSL.field("distance"), + "distance_alias", + trendlineType, + ExprCoreType.INTEGER, + "DataPoints - zero"))); + } + @Test public void calculates_simple_moving_average_one_field_one_sample() { - when(inputPlan.hasNext()).thenReturn(true, false); - when(inputPlan.next()) - .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); - + mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( inputPlan, @@ -49,21 +113,20 @@ public void calculates_simple_moving_average_one_field_one_sample() { AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), - plan.next()); + List result = execute(plan); + assertEquals(1, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } @Test public void calculates_simple_moving_average_one_field_two_samples() { - when(inputPlan.hasNext()).thenReturn(true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -73,26 +136,22 @@ public void calculates_simple_moving_average_one_field_two_samples() { AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - 
ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(2, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)))); } @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -102,31 +161,23 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + 
result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)))); } @Test public void calculates_simple_moving_average_multiple_computations() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); var plan = new TrendlineOperator( @@ -139,33 +190,27 @@ public void calculates_simple_moving_average_multiple_computations() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 
20, "distance_alias", 200.0, "time_alias", 20.0)))); } @Test public void alias_overwrites_input_field() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); var plan = new TrendlineOperator( @@ -175,26 +220,23 @@ public void alias_overwrites_input_field() { AstDSL.computation(2, AstDSL.field("distance"), "time", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)), plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100)), + tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), + tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)))); } @Test public void calculates_simple_moving_average_one_field_two_samples_three_rows_null_value() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + 
mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); var plan = new TrendlineOperator( @@ -204,28 +246,23 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)))); } @Test public void use_null_value() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); var plan = new TrendlineOperator( @@ -235,46 +272,25 @@ public void use_null_value() { AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); - plan.open(); - assertTrue(plan.hasNext()); - 
assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), - plan.next()); - assertFalse(plan.hasNext()); - } - - @Test - public void use_illegal_core_type() { - assertThrows( - IllegalArgumentException.class, - () -> { - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), - ExprCoreType.ARRAY))); - }); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)))); } @Test public void calculates_simple_moving_average_date() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - ExprValueUtils.tupleValue( + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue( ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)))))); var plan = new TrendlineOperator( @@ -284,44 +300,35 @@ public void calculates_simple_moving_average_date() { AstDSL.computation(2, AstDSL.field("date"), "date_alias", SMA), ExprCoreType.DATE))); - plan.open(); - assertTrue(plan.hasNext()); 
- assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "date", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), - "date_alias", - ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), + tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))))); } @Test public void calculates_simple_moving_average_time() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), - ExprValueUtils.tupleValue( + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + tupleValue( ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), - ExprValueUtils.tupleValue( - ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12)))))); var plan = new TrendlineOperator( @@ -331,37 +338,34 @@ public void 
calculates_simple_moving_average_time() { AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.TIME))); - plan.open(); - assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9))), - plan.next()); - assertFalse(plan.hasNext()); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("time", LocalTime.MIN)), + tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), + tupleValue( + ImmutableMap.of( + "time", + LocalTime.MIN.plusHours(12), + "time_alias", + LocalTime.MIN.plusHours(9))))); } @Test public void calculates_simple_moving_average_timestamp() { - when(inputPlan.hasNext()).thenReturn(true, true, true, false); - when(inputPlan.next()) - .thenReturn( - ExprValueUtils.tupleValue( - ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), - ExprValueUtils.tupleValue( + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + tupleValue( ImmutableMap.of( "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), - ExprValueUtils.tupleValue( + tupleValue( ImmutableMap.of( - "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500)))))); var plan = new TrendlineOperator( @@ -371,28 +375,388 @@ public void calculates_simple_moving_average_timestamp() { AstDSL.computation(2, 
AstDSL.field("timestamp"), "timestamp_alias", SMA), ExprCoreType.TIMESTAMP))); + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), + tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(500))), + tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1250))))); + } + + @Test + public void calculates_weighted_moving_average_one_field_one_sample() { + mockPlanWithData(List.of(tupleValue(ImmutableMap.of("distance", 100, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); + plan.open(); + assertTrue(plan.hasNext()); assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1000), - "timestamp_alias", - Instant.EPOCH.plusMillis(500))), - plan.next()); - assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "timestamp", - Instant.EPOCH.plusMillis(1500), - "timestamp_alias", - Instant.EPOCH.plusMillis(1250))), + tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next()); - assertFalse(plan.hasNext()); + } + + @Test + public void calculates_weighted_moving_average_one_field_four_samples_four_rows() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( 
+ inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); + + List result = execute(plan); + assertEquals(4, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); + } + + @Test + public void calculates_weighted_moving_average_one_field_five_samples_four_rows() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); + + List result = execute(plan); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200)))); + } + + @Test + public void calculates_weighted_moving_average_multiple_computations() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)))); + + var plan = + 
new TrendlineOperator( + inputPlan, + Arrays.asList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE), + Pair.of( + AstDSL.computation(4, AstDSL.field("time"), "time_alias", WMA), + ExprCoreType.DOUBLE))); + + List result = execute(plan); + assertEquals(4, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 190, "time_alias", 19.0)))); + } + + @Test + public void calculates_weighted_moving_average_null_value() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + ExprCoreType.DOUBLE))); + + List result = execute(plan); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10)), + tupleValue(ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 290)))); + } + + @Test + public void calculates_weighted_moving_average_date() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), + tupleValue( + ImmutableMap.of("date", 
ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("date"), "date_alias", WMA), + ExprCoreType.DATE))); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(4)))), + tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(10)))))); + } + + @Test + public void calculates_weighted_moving_average_time() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), + tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12)))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", WMA), + ExprCoreType.TIME))); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("time", LocalTime.MIN)), + tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(4))), + tupleValue( + ImmutableMap.of( + "time", + LocalTime.MIN.plusHours(12), + "time_alias", + LocalTime.MIN.plusHours(10))))); + } + + @Test + public void calculates_weighted_moving_average_timestamp() { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + tupleValue( + ImmutableMap.of( + 
"timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), + tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500)))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", WMA), + ExprCoreType.TIMESTAMP))); + + List result = execute(plan); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), + tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(667))), + tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1333))))); + } + + @ParameterizedTest + @MethodSource("supportedDataTypes") + public void trendLine_dataType_support_sma(ExprCoreType supportedType) { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", SMA), + supportedType))); + + List result = execute(plan); + System.out.println(result); + assertEquals(4, result.size()); + assertThat( + String.format( + "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 175)))); + } + + @ParameterizedTest + 
@MethodSource("supportedDataTypes") + public void trendLine_dataType_support_wma(ExprCoreType supportedType) { + mockPlanWithData( + List.of( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(4, AstDSL.field("distance"), "distance_alias", WMA), + supportedType))); + + List result = execute(plan); + System.out.println(result); + assertEquals(4, result.size()); + assertThat( + String.format( + "Assertion error on TrendLine-WMA dataType support: %s", supportedType.typeName()), + result, + containsInAnyOrder( + tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 190)))); + } + + @ParameterizedTest + @MethodSource("unSupportedDataTypes") + public void trendLine_unsupported_dataType( + Trendline.TrendlineType trendlineType, ExprCoreType dataType) { + assertThrows( + SemanticCheckException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation( + 2, AstDSL.field("distance"), "distance_alias", trendlineType), + dataType)))); + } + + @ParameterizedTest + @MethodSource("invalidArguments") + public void use_invalid_configuration( + Integer dataPoints, + Field field, + String alias, + Trendline.TrendlineType trendlineType, + ExprCoreType dataType, + String errorMessage) { + assertThrows( + SemanticCheckException.class, + () -> + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(dataPoints, field, alias, trendlineType), dataType))), + "Unsupported arguments: " + 
errorMessage); + } + + private void mockPlanWithData(List inputs) { + List hasNextElements = new ArrayList<>(Collections.nCopies(inputs.size(), true)); + hasNextElements.add(false); + + Iterator hasNextIterator = hasNextElements.iterator(); + when(inputPlan.hasNext()) + .thenAnswer(i -> hasNextIterator.hasNext() ? hasNextIterator.next() : null); + Iterator iterator = inputs.iterator(); + when(inputPlan.next()).thenAnswer(i -> iterator.hasNext() ? iterator.next() : null); } } diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index e6df0d7a2c..923b634484 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -15,17 +15,16 @@ Description Syntax ============ -`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` +`TRENDLINE [sort <[+|-] sort-field>] <trendline-type>(number-of-datapoints, field) [AS alias] [<trendline-type>(number-of-datapoints, field) [AS alias]]...` * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. * sort-field: mandatory when sorting is used. The field used to sort. +* trendline-type: mandatory. The algorithm used for the calculation; only SMA (simple moving average) and WMA (weighted moving average) are currently supported. * number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). * field: mandatory. The name of the field the moving average should be calculated for. -* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). +* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_<trendline-type>_trendline"). -At the moment only the Simple Moving Average (SMA) type is supported.
- -It is calculated like +For the Simple Moving Average (SMA), the result is calculated using the formula below. f[i]: The value of field 'f' in the i-th data-point n: The number of data-points in the moving window (period) @@ -33,7 +32,7 @@ It is calculated like SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t -Example 1: Calculate the moving average on one field. +Example 1: Calculate the simple moving average on one field. ===================================================== The example shows how to calculate the moving average on one field. @@ -52,7 +51,7 @@ PPL query:: +------+ -Example 2: Calculate the moving average on multiple fields. +Example 2: Calculate the simple moving average on multiple fields. =========================================================== The example shows how to calculate the moving average on multiple fields. @@ -70,14 +69,14 @@ PPL query:: | 15.5 | 30.5 | +------+-----------+ -Example 4: Calculate the moving average on one field without specifying an alias. +Example 3: Calculate the simple moving average on one field without specifying an alias. ================================================================================= -The example shows how to calculate the moving average on one field. +The example shows how to calculate the simple moving average on one field without specifying an alias. PPL query:: - os> source=accounts | trendline sma(2, account_number) | fields account_number_trendline; + os> source=accounts | trendline sma(2, account_number) | fields account_number_sma_trendline; fetched rows / total rows = 4/4 +--------------------------+ | account_number_trendline | @@ -88,3 +87,69 @@ PPL query:: | 15.5 | +--------------------------+ + + +For the Weighted Moving Average (WMA), the result is calculated using the formula below.
+ + f[i]: The value of field 'f' in the i-th data point + n: The number of data points in the moving window (period) + t: The current time index + w[i]: The weight of the i-th data point: 1 for the oldest point in the window, increasing by one per step up to n for the most recent point. + + WMA(t) = ( Σ from i=t-n+1 to t of (w[i] * f[i]) ) / ( Σ from i=t-n+1 to t of w[i] ) + +Example 1: Calculate the weighted moving average on one field. +============================================================== + +The example shows how to calculate the weighted moving average on one field. + +PPL query:: + + os> source=accounts | trendline wma(2, account_number) as an | fields an; + fetched rows / total rows = 4/4 + +--------------------+ + | an | + |--------------------| + | null | + | 4.333333333333333 | + | 10.666666666666666 | + | 16.333333333333332 | + +--------------------+ + +Example 2: Calculate the weighted moving average on multiple fields. +==================================================================== + +The example shows how to calculate the weighted moving average on multiple fields. + +PPL query:: + + os> source=accounts | trendline wma(2, account_number) as an sma(2, age) as age_trend | fields an, age_trend; + fetched rows / total rows = 4/4 + +--------------------+-----------+ + | an | age_trend | + |--------------------+-----------| + | null | null | + | 4.333333333333333 | 34.0 | + | 10.666666666666666 | 32.0 | + | 16.333333333333332 | 30.5 | + +--------------------+-----------+ + + +Example 3: Calculate the weighted moving average on one field without specifying an alias. +========================================================================================== + +The example shows how to calculate the weighted moving average on one field without specifying an alias.
+ +PPL query:: + + os> source=accounts | trendline wma(2, account_number) | fields account_number_wma_trendline; + fetched rows / total rows = 4/4 + +------------------------------+ + | account_number_wma_trendline | + |------------------------------| + | null | + | 4.333333333333333 | + | 10.666666666666666 | + | 16.333333333333332 | + +------------------------------+ + diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 38baa0f01f..02860fe1e7 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -7,7 +7,9 @@ import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; import java.io.IOException; import org.json.JSONObject; @@ -55,13 +57,14 @@ public void testTrendlineOverwritesExistingField() throws IOException { } @Test - public void testTrendlineNoAlias() throws IOException { + public void testTrendlineNoAliasDefaultName() throws IOException { final JSONObject result = executeQuery( String.format( "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) |" - + " fields balance_trendline", + + " fields balance_sma_trendline", TEST_INDEX_BANK)); + verifySchema(result, schema("balance_sma_trendline", "double")); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } @@ -71,8 +74,95 @@ public void testTrendlineWithSort() throws IOException { executeQuery( String.format( "source=%s | where balance > 39000 | trendline sort balance sma(2, balance) |" - + " fields balance_trendline", + + " fields balance_sma_trendline", TEST_INDEX_BANK));
verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } + + @Test + public void testTrendlineWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | sort balance | head 4 | trendline wma(4, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); + verifyDataRows( + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8)); + } + + @Test + public void testTrendlineMultipleFieldsWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | sort balance | head 5 | trendline wma(4, balance) as" + + " balance_trend wma(5, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows( + result, + rows(null, null), + rows(null, null), + rows(null, null), + rows(19615.8, null), + rows(29393.6, 9.8)); + } + + @Test + public void testTrendlineOverwritesExistingFieldWma() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | sort balance | head 6 | trendline wma(4, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); + verifyDataRows( + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6), + rows(36192.9)); + } + + @Test + public void testTrendlineNoAliasWmaDefaultName() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | sort balance | head 5 | trendline wma(4, balance) |" + + " fields balance_wma_trendline", + TEST_INDEX_BANK)); + verifySchema(result, schema("balance_wma_trendline", "double")); + verifyDataRows( + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6)); + } + + @Test + public void testTrendlineWithSortWma() throws IOException { + final JSONObject result = + 
executeQuery( + String.format( + "source=%s | sort balance | head 5 | trendline sort balance wma(4, balance) |" + + " fields balance_wma_trendline", + TEST_INDEX_BANK)); + verifyDataRows( + result, + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(new Object[] {null}), + rows(19615.8), + rows(29393.6)); + } } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index c484f34a2a..badfed9a73 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -60,6 +60,7 @@ NUM: 'NUM'; // TRENDLINE KEYWORDS SMA: 'SMA'; +WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index acae54b7d9..3e7e88717f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -156,6 +156,7 @@ trendlineClause trendlineType : SMA + | WMA ; kmeansCommand diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 5a7522683a..e5b43bc834 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -89,13 +89,15 @@ public Trendline.TrendlineComputation visitTrendlineClause( } final Field dataField = (Field) this.visitFieldExpression(ctx.field); + final Trendline.TrendlineType computationType = + Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); final String alias = ctx.alias != null ? 
ctx.alias.getText() - : dataField.getChild().get(0).toString() + "_trendline"; - - final Trendline.TrendlineType computationType = - Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); + : dataField.getChild().getFirst().toString() + + "_" + + computationType.name().toLowerCase() + + "_trendline"; return new Trendline.TrendlineComputation( numberOfDataPoints, dataField, alias, computationType); }