Skip to content

Commit

Permalink
Number ceil/floor/sqrt operations
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-ssh16 authored and gs-ssh16 committed Jul 26, 2023
1 parent 935dad7 commit 98df2a0
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 0 deletions.
24 changes: 24 additions & 0 deletions pylegend/core/databse/sql_to_string/db_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
StringConcatExpression,
AbsoluteExpression,
PowerExpression,
CeilExpression,
FloorExpression,
SqrtExpression,
)


Expand Down Expand Up @@ -342,6 +345,12 @@ def expression_processor(
return extension.process_absolute_expression(expression, config)
elif isinstance(expression, PowerExpression):
return extension.process_power_expression(expression, config)
elif isinstance(expression, CeilExpression):
return extension.process_ceil_expression(expression, config)
elif isinstance(expression, FloorExpression):
return extension.process_floor_expression(expression, config)
elif isinstance(expression, SqrtExpression):
return extension.process_sqrt_expression(expression, config)
else:
raise ValueError("Unsupported expression type: " + str(type(expression))) # pragma: no cover

Expand Down Expand Up @@ -954,6 +963,21 @@ def process_power_expression(self, expr: PowerExpression, config: SqlToStringCon
second=self.process_expression(expr.second, config)
)

def process_ceil_expression(self, expr: CeilExpression, config: SqlToStringConfig) -> str:
return "CEIL({value})".format(
value=self.process_expression(expr.value, config)
)

def process_floor_expression(self, expr: FloorExpression, config: SqlToStringConfig) -> str:
return "FLOOR({value})".format(
value=self.process_expression(expr.value, config)
)

def process_sqrt_expression(self, expr: SqrtExpression, config: SqlToStringConfig) -> str:
return "SQRT({value})".format(
value=self.process_expression(expr.value, config)
)

def process_qualified_name(self, qualified_name: QualifiedName, config: SqlToStringConfig) -> str:
return qualified_name_processor(qualified_name, self, config)

Expand Down
64 changes: 64 additions & 0 deletions pylegend/core/language/operations/number_operation_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from pylegend.core.language.expression import (
PyLegendExpressionNumberReturn,
PyLegendExpressionIntegerReturn,
PyLegendExpressionBooleanReturn,
)
from pylegend.core.language.operations.binary_expression import PyLegendBinaryExpression
Expand All @@ -35,6 +36,9 @@
from pylegend.core.sql.metamodel_extension import (
AbsoluteExpression,
PowerExpression,
CeilExpression,
FloorExpression,
SqrtExpression,
)


Expand All @@ -50,6 +54,9 @@
"PyLegendNumberNegativeExpression",
"PyLegendNumberAbsoluteExpression",
"PyLegendNumberPowerExpression",
"PyLegendNumberCeilExpression",
"PyLegendNumberFloorExpression",
"PyLegendNumberSqrtExpression",
]


Expand Down Expand Up @@ -278,3 +285,60 @@ def __init__(self, operand1: PyLegendExpressionNumberReturn, operand2: PyLegendE
operand2,
PyLegendNumberPowerExpression.__to_sql_func
)


class PyLegendNumberCeilExpression(PyLegendUnaryExpression, PyLegendExpressionIntegerReturn):

@staticmethod
def __to_sql_func(
expression: Expression,
frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification],
config: FrameToSqlConfig
) -> Expression:
return CeilExpression(expression)

def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
PyLegendExpressionIntegerReturn.__init__(self)
PyLegendUnaryExpression.__init__(
self,
operand,
PyLegendNumberCeilExpression.__to_sql_func
)


class PyLegendNumberFloorExpression(PyLegendUnaryExpression, PyLegendExpressionIntegerReturn):

@staticmethod
def __to_sql_func(
expression: Expression,
frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification],
config: FrameToSqlConfig
) -> Expression:
return FloorExpression(expression)

def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
PyLegendExpressionIntegerReturn.__init__(self)
PyLegendUnaryExpression.__init__(
self,
operand,
PyLegendNumberFloorExpression.__to_sql_func
)


class PyLegendNumberSqrtExpression(PyLegendUnaryExpression, PyLegendExpressionNumberReturn):

@staticmethod
def __to_sql_func(
expression: Expression,
frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification],
config: FrameToSqlConfig
) -> Expression:
return SqrtExpression(expression)

def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
PyLegendExpressionNumberReturn.__init__(self)
PyLegendUnaryExpression.__init__(
self,
operand,
PyLegendNumberSqrtExpression.__to_sql_func
)
14 changes: 14 additions & 0 deletions pylegend/core/language/primitives/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
PyLegendNumberNegativeExpression,
PyLegendNumberAbsoluteExpression,
PyLegendNumberPowerExpression,
PyLegendNumberCeilExpression,
PyLegendNumberFloorExpression,
PyLegendNumberSqrtExpression,
)
from pylegend.core.sql.metamodel import (
Expression,
Expand Down Expand Up @@ -192,6 +195,17 @@ def __rpow__(
other_op = PyLegendNumber.__convert_to_number_expr(other)
return PyLegendNumber(PyLegendNumberPowerExpression(other_op, self.__value))

def ceil(self) -> "PyLegendInteger":
from pylegend.core.language.primitives.integer import PyLegendInteger
return PyLegendInteger(PyLegendNumberCeilExpression(self.__value))

def floor(self) -> "PyLegendInteger":
from pylegend.core.language.primitives.integer import PyLegendInteger
return PyLegendInteger(PyLegendNumberFloorExpression(self.__value))

def sqrt(self) -> "PyLegendNumber":
return PyLegendNumber(PyLegendNumberSqrtExpression(self.__value))

@staticmethod
def __convert_to_number_expr(
val: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"]
Expand Down
36 changes: 36 additions & 0 deletions pylegend/core/sql/metamodel_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
"StringConcatExpression",
"AbsoluteExpression",
"PowerExpression",
"CeilExpression",
"FloorExpression",
"SqrtExpression",
]


Expand Down Expand Up @@ -152,3 +155,36 @@ def __init__(
super().__init__(_type="powerExpression")
self.first = first
self.second = second


class CeilExpression(Expression):
value: "Expression"

def __init__(
self,
value: "Expression",
) -> None:
super().__init__(_type="ceilExpression")
self.value = value


class FloorExpression(Expression):
value: "Expression"

def __init__(
self,
value: "Expression",
) -> None:
super().__init__(_type="floorExpression")
self.value = value


class SqrtExpression(Expression):
value: "Expression"

def __init__(
self,
value: "Expression",
) -> None:
super().__init__(_type="sqrtExpression")
self.value = value
24 changes: 24 additions & 0 deletions pylegend/tests/core/database/test_sql_to_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
StringConcatExpression,
AbsoluteExpression,
PowerExpression,
CeilExpression,
FloorExpression,
SqrtExpression,
)


Expand Down Expand Up @@ -1224,3 +1227,24 @@ def test_process_power_expression(self) -> None:

expr = PowerExpression(IntegerLiteral(9), IntegerLiteral(3))
assert extension.process_expression(expr, config) == "POWER(9, 3)"

def test_process_ceil_expression(self) -> None:
extension = SqlToStringDbExtension()
config = SqlToStringConfig(SqlToStringFormat(pretty=False))

expr = CeilExpression(DoubleLiteral(2.3))
assert extension.process_expression(expr, config) == "CEIL(2.3)"

def test_process_floor_expression(self) -> None:
extension = SqlToStringDbExtension()
config = SqlToStringConfig(SqlToStringFormat(pretty=False))

expr = FloorExpression(DoubleLiteral(2.3))
assert extension.process_expression(expr, config) == "FLOOR(2.3)"

def test_process_sqrt_expression(self) -> None:
extension = SqlToStringDbExtension()
config = SqlToStringConfig(SqlToStringFormat(pretty=False))

expr = SqrtExpression(IntegerLiteral(10))
assert extension.process_expression(expr, config) == "SQRT(10)"
18 changes: 18 additions & 0 deletions pylegend/tests/core/language/primitives/test_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,24 @@ def test_number_power_expr(self) -> None:
assert self.__generate_sql_string(lambda x: 1.2 ** x.get_number("col2")) == \
'POWER(1.2, "root".col2)'

def test_number_ceil_expr(self) -> None:
assert self.__generate_sql_string(lambda x: x.get_number("col2").ceil()) == \
'CEIL("root".col2)'
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).ceil()) == \
'CEIL(("root".col2 + "root".col1))'

def test_number_floor_expr(self) -> None:
assert self.__generate_sql_string(lambda x: x.get_number("col2").floor()) == \
'FLOOR("root".col2)'
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).floor()) == \
'FLOOR(("root".col2 + "root".col1))'

def test_number_sqrt_expr(self) -> None:
assert self.__generate_sql_string(lambda x: x.get_number("col2").sqrt()) == \
'SQRT("root".col2)'
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).sqrt()) == \
'SQRT(("root".col2 + "root".col1))'

def __generate_sql_string(self, f) -> str: # type: ignore
return self.db_extension.process_expression(
f(self.tds_row).to_sql_expression({"t": self.base_query}, self.frame_to_sql_config),
Expand Down

0 comments on commit 98df2a0

Please sign in to comment.