Skip to content

Commit

Permalink
Number exp/log operations
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-ssh16 authored and gs-ssh16 committed Jul 26, 2023
1 parent b01cee7 commit d89086e
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pylegend/core/databse/sql_to_string/db_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
FloorExpression,
SqrtExpression,
CbrtExpression,
ExpExpression,
LogExpression,
)


Expand Down Expand Up @@ -354,6 +356,10 @@ def expression_processor(
return extension.process_sqrt_expression(expression, config)
elif isinstance(expression, CbrtExpression):
return extension.process_cbrt_expression(expression, config)
elif isinstance(expression, ExpExpression):
return extension.process_exp_expression(expression, config)
elif isinstance(expression, LogExpression):
return extension.process_log_expression(expression, config)
else:
raise ValueError("Unsupported expression type: " + str(type(expression))) # pragma: no cover

Expand Down Expand Up @@ -986,6 +992,16 @@ def process_cbrt_expression(self, expr: CbrtExpression, config: SqlToStringConfi
value=self.process_expression(expr.value, config)
)

def process_exp_expression(self, expr: ExpExpression, config: SqlToStringConfig) -> str:
return "EXP({value})".format(
value=self.process_expression(expr.value, config)
)

def process_log_expression(self, expr: LogExpression, config: SqlToStringConfig) -> str:
return "LN({value})".format(
value=self.process_expression(expr.value, config)
)

def process_qualified_name(self, qualified_name: QualifiedName, config: SqlToStringConfig) -> str:
return qualified_name_processor(qualified_name, self, config)

Expand Down
42 changes: 42 additions & 0 deletions pylegend/core/language/operations/number_operation_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
FloorExpression,
SqrtExpression,
CbrtExpression,
ExpExpression,
LogExpression,
)


Expand All @@ -59,6 +61,8 @@
"PyLegendNumberFloorExpression",
"PyLegendNumberSqrtExpression",
"PyLegendNumberCbrtExpression",
"PyLegendNumberExpExpression",
"PyLegendNumberLogExpression",
]


Expand Down Expand Up @@ -363,3 +367,41 @@ def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
operand,
PyLegendNumberCbrtExpression.__to_sql_func
)


class PyLegendNumberExpExpression(PyLegendUnaryExpression, PyLegendExpressionNumberReturn):

@staticmethod
def __to_sql_func(
expression: Expression,
frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification],
config: FrameToSqlConfig
) -> Expression:
return ExpExpression(expression)

def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
PyLegendExpressionNumberReturn.__init__(self)
PyLegendUnaryExpression.__init__(
self,
operand,
PyLegendNumberExpExpression.__to_sql_func
)


class PyLegendNumberLogExpression(PyLegendUnaryExpression, PyLegendExpressionNumberReturn):

@staticmethod
def __to_sql_func(
expression: Expression,
frame_name_to_base_query_map: PyLegendDict[str, QuerySpecification],
config: FrameToSqlConfig
) -> Expression:
return LogExpression(expression)

def __init__(self, operand: PyLegendExpressionNumberReturn) -> None:
PyLegendExpressionNumberReturn.__init__(self)
PyLegendUnaryExpression.__init__(
self,
operand,
PyLegendNumberLogExpression.__to_sql_func
)
8 changes: 8 additions & 0 deletions pylegend/core/language/primitives/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
PyLegendNumberFloorExpression,
PyLegendNumberSqrtExpression,
PyLegendNumberCbrtExpression,
PyLegendNumberExpExpression,
PyLegendNumberLogExpression,
)
from pylegend.core.sql.metamodel import (
Expression,
Expand Down Expand Up @@ -210,6 +212,12 @@ def sqrt(self) -> "PyLegendNumber":
def cbrt(self) -> "PyLegendNumber":
return PyLegendNumber(PyLegendNumberCbrtExpression(self.__value))

def exp(self) -> "PyLegendNumber":
return PyLegendNumber(PyLegendNumberExpExpression(self.__value))

def log(self) -> "PyLegendNumber":
return PyLegendNumber(PyLegendNumberLogExpression(self.__value))

@staticmethod
def __convert_to_number_expr(
val: PyLegendUnion[int, float, "PyLegendInteger", "PyLegendFloat", "PyLegendNumber"]
Expand Down
24 changes: 24 additions & 0 deletions pylegend/core/sql/metamodel_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
"FloorExpression",
"SqrtExpression",
"CbrtExpression",
"ExpExpression",
"LogExpression",
]


Expand Down Expand Up @@ -200,3 +202,25 @@ def __init__(
) -> None:
super().__init__(_type="cbrtExpression")
self.value = value


class ExpExpression(Expression):
value: "Expression"

def __init__(
self,
value: "Expression",
) -> None:
super().__init__(_type="expExpression")
self.value = value


class LogExpression(Expression):
value: "Expression"

def __init__(
self,
value: "Expression",
) -> None:
super().__init__(_type="logExpression")
self.value = value
16 changes: 16 additions & 0 deletions pylegend/tests/core/database/test_sql_to_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
FloorExpression,
SqrtExpression,
CbrtExpression,
ExpExpression,
LogExpression,
)


Expand Down Expand Up @@ -1256,3 +1258,17 @@ def test_process_cbrt_expression(self) -> None:

expr = CbrtExpression(IntegerLiteral(10))
assert extension.process_expression(expr, config) == "CBRT(10)"

def test_process_exp_expression(self) -> None:
extension = SqlToStringDbExtension()
config = SqlToStringConfig(SqlToStringFormat(pretty=False))

expr = ExpExpression(IntegerLiteral(10))
assert extension.process_expression(expr, config) == "EXP(10)"

def test_process_log_expression(self) -> None:
extension = SqlToStringDbExtension()
config = SqlToStringConfig(SqlToStringFormat(pretty=False))

expr = LogExpression(IntegerLiteral(10))
assert extension.process_expression(expr, config) == "LN(10)"
12 changes: 12 additions & 0 deletions pylegend/tests/core/language/primitives/test_number.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def test_number_cbrt_expr(self) -> None:
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).cbrt()) == \
'CBRT(("root".col2 + "root".col1))'

def test_number_exp_expr(self) -> None:
assert self.__generate_sql_string(lambda x: x.get_number("col2").exp()) == \
'EXP("root".col2)'
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).exp()) == \
'EXP(("root".col2 + "root".col1))'

def test_number_log_expr(self) -> None:
assert self.__generate_sql_string(lambda x: x.get_number("col2").log()) == \
'LN("root".col2)'
assert self.__generate_sql_string(lambda x: (x.get_number("col2") + x.get_number("col1")).log()) == \
'LN(("root".col2 + "root".col1))'

def __generate_sql_string(self, f) -> str: # type: ignore
return self.db_extension.process_expression(
f(self.tds_row).to_sql_expression({"t": self.base_query}, self.frame_to_sql_config),
Expand Down

0 comments on commit d89086e

Please sign in to comment.