diff --git a/pylegend/core/databse/sql_to_string/db_extension.py b/pylegend/core/databse/sql_to_string/db_extension.py index 46e10ca9..79876976 100644 --- a/pylegend/core/databse/sql_to_string/db_extension.py +++ b/pylegend/core/databse/sql_to_string/db_extension.py @@ -83,6 +83,7 @@ StringPosExpression, StringConcatExpression, AbsoluteExpression, + PowerExpression, ) @@ -339,6 +340,8 @@ def expression_processor( return extension.process_string_concat_expression(expression, config) elif isinstance(expression, AbsoluteExpression): return extension.process_absolute_expression(expression, config) + elif isinstance(expression, PowerExpression): + return extension.process_power_expression(expression, config) else: raise ValueError("Unsupported expression type: " + str(type(expression))) # pragma: no cover @@ -945,6 +948,12 @@ def process_absolute_expression(self, expr: AbsoluteExpression, config: SqlToStr value=self.process_expression(expr.value, config) ) + def process_power_expression(self, expr: PowerExpression, config: SqlToStringConfig) -> str: + return "POWER({first}, {second})".format( + first=self.process_expression(expr.first, config), + second=self.process_expression(expr.second, config) + ) + def process_qualified_name(self, qualified_name: QualifiedName, config: SqlToStringConfig) -> str: return qualified_name_processor(qualified_name, self, config) diff --git a/pylegend/core/language/operations/number_operation_expressions.py b/pylegend/core/language/operations/number_operation_expressions.py index 5d79be6d..0acab09a 100644 --- a/pylegend/core/language/operations/number_operation_expressions.py +++ b/pylegend/core/language/operations/number_operation_expressions.py @@ -34,6 +34,7 @@ ) from pylegend.core.sql.metamodel_extension import ( AbsoluteExpression, + PowerExpression, ) @@ -48,6 +49,7 @@ "PyLegendNumberGreaterThanEqualExpression", "PyLegendNumberNegativeExpression", "PyLegendNumberAbsoluteExpression", + "PyLegendNumberPowerExpression", ] @@ -255,3 +257,24 @@ def __init__(self, operand: PyLegendExpressionNumberReturn) -> None: operand, PyLegendNumberAbsoluteExpression.__to_sql_func ) + + +class PyLegendNumberPowerExpression(PyLegendBinaryExpression, PyLegendExpressionNumberReturn): + + @staticmethod + def __to_sql_func( + expression1: Expression, + expression2: Expression, + frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification], + config: FrameToSqlConfig + ) -> Expression: + return PowerExpression(expression1, expression2) + + def __init__(self, operand1: PyLegendExpressionNumberReturn, operand2: PyLegendExpressionNumberReturn) -> None: + PyLegendExpressionNumberReturn.__init__(self) + PyLegendBinaryExpression.__init__( + self, + operand1, + operand2, + PyLegendNumberPowerExpression.__to_sql_func + ) diff --git a/pylegend/core/language/primitives/number.py b/pylegend/core/language/primitives/number.py index b2c4d820..5e0d03c1 100644 --- a/pylegend/core/language/primitives/number.py +++ b/pylegend/core/language/primitives/number.py @@ -38,6 +38,7 @@ PyLegendNumberGreaterThanEqualExpression, PyLegendNumberNegativeExpression, PyLegendNumberAbsoluteExpression, + PyLegendNumberPowerExpression, ) from pylegend.core.sql.metamodel import ( Expression, @@ -175,6 +176,22 @@ def __neg__(self) -> "PyLegendNumber": def __abs__(self) -> "PyLegendNumber": return PyLegendNumber(PyLegendNumberAbsoluteExpression(self.__value)) + def __pow__( + self, + other: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"] + ) -> "PyLegendNumber": + PyLegendNumber.validate_param_to_be_number(other, "Number power (**) parameter") + other_op = PyLegendNumber.__convert_to_number_expr(other) + return PyLegendNumber(PyLegendNumberPowerExpression(self.__value, other_op)) + + def __rpow__( + self, + other: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"] + ) -> "PyLegendNumber": + PyLegendNumber.validate_param_to_be_number(other, "Number power (**) parameter") + other_op = PyLegendNumber.__convert_to_number_expr(other) + return PyLegendNumber(PyLegendNumberPowerExpression(other_op, self.__value)) + @staticmethod def __convert_to_number_expr( val: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"] diff --git a/pylegend/core/sql/metamodel_extension.py b/pylegend/core/sql/metamodel_extension.py index 70c41c52..c6bd9275 100644 --- a/pylegend/core/sql/metamodel_extension.py +++ b/pylegend/core/sql/metamodel_extension.py @@ -30,6 +30,7 @@ "StringPosExpression", "StringConcatExpression", "AbsoluteExpression", + "PowerExpression", ] @@ -137,3 +138,17 @@ def __init__( ) -> None: super().__init__(_type="absoluteExpression") self.value = value + + +class PowerExpression(Expression): + first: "Expression" + second: "Expression" + + def __init__( + self, + first: "Expression", + second: "Expression" + ) -> None: + super().__init__(_type="powerExpression") + self.first = first + self.second = second diff --git a/pylegend/tests/core/database/test_sql_to_string.py b/pylegend/tests/core/database/test_sql_to_string.py index 53ad225a..d5ff7565 100644 --- a/pylegend/tests/core/database/test_sql_to_string.py +++ b/pylegend/tests/core/database/test_sql_to_string.py @@ -80,6 +80,7 @@ StringPosExpression, StringConcatExpression, AbsoluteExpression, + PowerExpression, ) @@ -1216,3 +1217,10 @@ def test_process_absolute_expression(self) -> None: expr = AbsoluteExpression(IntegerLiteral(-1)) assert extension.process_expression(expr, config) == "ABS(-1)" + + def test_process_power_expression(self) -> None: + extension = SqlToStringDbExtension() + config = SqlToStringConfig(SqlToStringFormat(pretty=False)) + + expr = PowerExpression(IntegerLiteral(9), IntegerLiteral(3)) + assert extension.process_expression(expr, config) == "POWER(9, 3)" diff --git a/pylegend/tests/core/language/primitives/test_number.py b/pylegend/tests/core/language/primitives/test_number.py index 1761113e..825c6481 100644 --- a/pylegend/tests/core/language/primitives/test_number.py +++ b/pylegend/tests/core/language/primitives/test_number.py @@ -119,6 +119,14 @@ def test_number_abs_expr(self) -> None: assert self.__generate_sql_string(lambda x: abs(x.get_number("col2") + x.get_number("col1"))) == \ 'ABS(("root".col2 + "root".col1))' + def test_number_power_expr(self) -> None: + assert self.__generate_sql_string(lambda x: x.get_number("col2") ** x.get_number("col1")) == \ + 'POWER("root".col2, "root".col1)' + assert self.__generate_sql_string(lambda x: x.get_number("col2") ** 10) == \ + 'POWER("root".col2, 10)' + assert self.__generate_sql_string(lambda x: 1.2 ** x.get_number("col2")) == \ + 'POWER(1.2, "root".col2)' + def __generate_sql_string(self, f) -> str: # type: ignore return self.db_extension.process_expression( f(self.tds_row).to_sql_expression({"t": self.base_query}, self.frame_to_sql_config),