diff --git a/pylegend/core/databse/sql_to_string/db_extension.py b/pylegend/core/databse/sql_to_string/db_extension.py index 4a0a3527..6f9a2524 100644 --- a/pylegend/core/databse/sql_to_string/db_extension.py +++ b/pylegend/core/databse/sql_to_string/db_extension.py @@ -87,6 +87,7 @@ CeilExpression, FloorExpression, SqrtExpression, + CbrtExpression, ) @@ -351,6 +352,8 @@ def expression_processor( return extension.process_floor_expression(expression, config) elif isinstance(expression, SqrtExpression): return extension.process_sqrt_expression(expression, config) + elif isinstance(expression, CbrtExpression): + return extension.process_cbrt_expression(expression, config) else: raise ValueError("Unsupported expression type: " + str(type(expression))) # pragma: no cover @@ -978,6 +981,11 @@ def process_sqrt_expression(self, expr: SqrtExpression, config: SqlToStringConfi value=self.process_expression(expr.value, config) ) + def process_cbrt_expression(self, expr: CbrtExpression, config: SqlToStringConfig) -> str: + return "CBRT({value})".format( + value=self.process_expression(expr.value, config) + ) + def process_qualified_name(self, qualified_name: QualifiedName, config: SqlToStringConfig) -> str: return qualified_name_processor(qualified_name, self, config) diff --git a/pylegend/core/language/operations/number_operation_expressions.py b/pylegend/core/language/operations/number_operation_expressions.py index 04a839ad..0bd7db12 100644 --- a/pylegend/core/language/operations/number_operation_expressions.py +++ b/pylegend/core/language/operations/number_operation_expressions.py @@ -39,6 +39,7 @@ CeilExpression, FloorExpression, SqrtExpression, + CbrtExpression, ) @@ -57,6 +58,7 @@ "PyLegendNumberCeilExpression", "PyLegendNumberFloorExpression", "PyLegendNumberSqrtExpression", + "PyLegendNumberCbrtExpression", ] @@ -342,3 +344,22 @@ def __init__(self, operand: PyLegendExpressionNumberReturn) -> None: operand, PyLegendNumberSqrtExpression.__to_sql_func ) + + +class PyLegendNumberCbrtExpression(PyLegendUnaryExpression, PyLegendExpressionNumberReturn): + + @staticmethod + def __to_sql_func( + expression: Expression, + frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification], + config: FrameToSqlConfig + ) -> Expression: + return CbrtExpression(expression) + + def __init__(self, operand: PyLegendExpressionNumberReturn) -> None: + PyLegendExpressionNumberReturn.__init__(self) + PyLegendUnaryExpression.__init__( + self, + operand, + PyLegendNumberCbrtExpression.__to_sql_func + ) diff --git a/pylegend/core/language/primitives/number.py b/pylegend/core/language/primitives/number.py index 9151b77e..c04cc424 100644 --- a/pylegend/core/language/primitives/number.py +++ b/pylegend/core/language/primitives/number.py @@ -42,6 +42,7 @@ PyLegendNumberCeilExpression, PyLegendNumberFloorExpression, PyLegendNumberSqrtExpression, + PyLegendNumberCbrtExpression, ) from pylegend.core.sql.metamodel import ( Expression, @@ -206,6 +207,9 @@ def floor(self) -> "PyLegendInteger": def sqrt(self) -> "PyLegendNumber": return PyLegendNumber(PyLegendNumberSqrtExpression(self.__value)) + def cbrt(self) -> "PyLegendNumber": + return PyLegendNumber(PyLegendNumberCbrtExpression(self.__value)) + @staticmethod def __convert_to_number_expr( val: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"] diff --git a/pylegend/core/sql/metamodel_extension.py b/pylegend/core/sql/metamodel_extension.py index 2aa31b4e..94fab2d3 100644 --- a/pylegend/core/sql/metamodel_extension.py +++ b/pylegend/core/sql/metamodel_extension.py @@ -34,6 +34,7 @@ "CeilExpression", "FloorExpression", "SqrtExpression", + "CbrtExpression", ] @@ -188,3 +189,14 @@ def __init__( ) -> None: super().__init__(_type="sqrtExpression") self.value = value + + +class CbrtExpression(Expression): + value: "Expression" + + def __init__( + self, + value: "Expression", + ) -> None: + super().__init__(_type="cbrtExpression") + self.value = value diff --git a/pylegend/tests/core/database/test_sql_to_string.py b/pylegend/tests/core/database/test_sql_to_string.py index 63abe86a..19390fc4 100644 --- a/pylegend/tests/core/database/test_sql_to_string.py +++ b/pylegend/tests/core/database/test_sql_to_string.py @@ -84,6 +84,7 @@ CeilExpression, FloorExpression, SqrtExpression, + CbrtExpression, ) @@ -1248,3 +1249,10 @@ def test_process_sqrt_expression(self) -> None: expr = SqrtExpression(IntegerLiteral(10)) assert extension.process_expression(expr, config) == "SQRT(10)" + + def test_process_cbrt_expression(self) -> None: + extension = SqlToStringDbExtension() + config = SqlToStringConfig(SqlToStringFormat(pretty=False)) + + expr = CbrtExpression(IntegerLiteral(10)) + assert extension.process_expression(expr, config) == "CBRT(10)" diff --git a/pylegend/tests/core/language/primitives/test_number.py b/pylegend/tests/core/language/primitives/test_number.py index eefcc7ca..3d36693e 100644 --- a/pylegend/tests/core/language/primitives/test_number.py +++ b/pylegend/tests/core/language/primitives/test_number.py @@ -145,6 +145,12 @@ def test_number_sqrt_expr(self) -> None: assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).sqrt()) == \ 'SQRT(("root".col2 + "root".col1))' + def test_number_cbrt_expr(self) -> None: + assert self.__generate_sql_string(lambda x: x.get_number("col2").cbrt()) == \ + 'CBRT("root".col2)' + assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).cbrt()) == \ + 'CBRT(("root".col2 + "root".col1))' + def __generate_sql_string(self, f) -> str: # type: ignore return self.db_extension.process_expression( f(self.tds_row).to_sql_expression({"t": self.base_query}, self.frame_to_sql_config),