diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 64ed61626b..379f4bf846 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -6,6 +6,7 @@ import unittest import uuid import logging +import textwrap from hyperscript import h from rich.console import Console as RichConsole @@ -35,7 +36,7 @@ from sqlmesh.utils import rich as srich from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.date import time_like_to_str, to_date, yesterday_ds -from sqlmesh.utils.errors import PythonModelEvalError +from sqlmesh.utils.errors import PythonModelEvalError, NodeAuditsErrors if t.TYPE_CHECKING: import ipywidgets as widgets @@ -2405,9 +2406,13 @@ def _format_node_error(ex: NodeExecutionFailedError) -> str: error_msg = str(cause) - if not isinstance(cause, (NodeExecutionFailedError, PythonModelEvalError)): + if isinstance(cause, NodeAuditsErrors): + error_msg = _format_audits_errors(cause) + elif not isinstance(cause, (NodeExecutionFailedError, PythonModelEvalError)): error_msg = " " + error_msg.replace("\n", "\n ") - error_msg = f" {cause.__class__.__name__}:\n{error_msg}" + error_msg = ( + f" {cause.__class__.__name__}:\n{error_msg}" # include error class name in msg + ) error_msg = error_msg.replace("\n", "\n ") error_msg = error_msg + "\n" if not error_msg.rstrip(" ").endswith("\n") else error_msg @@ -2431,3 +2436,20 @@ def _format_node_error(ex: NodeExecutionFailedError) -> str: error_messages[node_name] = msg return error_messages + + +def _format_audits_errors(error: NodeAuditsErrors) -> str: + error_messages = [] + for err in error.errors: + audit_args_sql = [] + for arg_name, arg_value in err.audit_args.items(): + audit_args_sql.append(f"{arg_name} := {arg_value.sql(dialect=err.adapter_dialect)}") + audit_args_sql_msg = ("\n".join(audit_args_sql) + "\n\n") if audit_args_sql else "" + + err_msg = f"'{err.audit_name}' audit error\n\n{err.count} {'row' if err.count == 1 else 'rows'} failed" + + query = "\n ".join(textwrap.wrap(err.sql(err.adapter_dialect), width=100)) + msg = f"{err_msg}\n\nAudit arguments\n {audit_args_sql_msg}Audit query\n {query}\n\n" + msg = msg.replace("\n", "\n ") + error_messages.append(msg) + return " " + "\n".join(error_messages) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 0b0615a3d8..54b77e3444 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1771,6 +1771,8 @@ def depends_on_self(self) -> bool: class AuditResult(PydanticModel): audit: Audit """The audit this result is for.""" + audit_args: t.Dict[t.Any, t.Any] + """Arguments passed to the audit.""" model: t.Optional[_Model] = None """The model this audit is for.""" count: t.Optional[int] = None diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 29e29caff8..f349609ea3 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -37,7 +37,7 @@ to_timestamp, validate_date_range, ) -from sqlmesh.utils.errors import AuditError, CircuitBreakerError, SQLMeshError +from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError logger = logging.getLogger(__name__) SnapshotToIntervals = t.Dict[Snapshot, Intervals] @@ -205,10 +205,11 @@ def evaluate( **kwargs, ) - audit_error_to_raise: t.Optional[AuditError] = None + audit_errors_to_raise: t.List[AuditError] = [] for audit_result in (result for result in audit_results if result.count): error = AuditError( audit_name=audit_result.audit.name, + audit_args=audit_result.audit_args, model=snapshot.model_or_none, count=t.cast(int, audit_result.count), query=t.cast(exp.Query, audit_result.query), @@ -220,14 +221,14 @@ def evaluate( NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error ) if audit_result.blocking: - audit_error_to_raise = error + audit_errors_to_raise.append(error) else: get_console().log_warning( - f"{error}\nAudit is non-blocking so proceeding with execution." + f"{error}. Audit is non-blocking so proceeding with execution.\n{error.query.sql(error.adapter_dialect)}\n" ) - if audit_error_to_raise: - raise audit_error_to_raise + if audit_errors_to_raise: + raise NodeAuditsErrors(audit_errors_to_raise) self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 6e3d23ec24..78253d2e6f 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -1013,6 +1013,7 @@ def _audit( if audit.skip: return AuditResult( audit=audit, + audit_args=audit_args, model=snapshot.model_or_none, skipped=True, ) @@ -1049,6 +1050,7 @@ def _audit( return AuditResult( audit=audit, + audit_args=audit_args, model=snapshot.model_or_none, count=count, query=query, diff --git a/sqlmesh/engines/commands.py b/sqlmesh/engines/commands.py index 602eec4763..6b144ee9e5 100644 --- a/sqlmesh/engines/commands.py +++ b/sqlmesh/engines/commands.py @@ -103,6 +103,7 @@ def evaluate( if failed_audit_result: raise AuditError( audit_name=failed_audit_result.audit.name, + audit_args=failed_audit_result.audit_args, model=command_payload.snapshot.model_or_none, count=t.cast(int, failed_audit_result.count), query=t.cast(exp.Query, failed_audit_result.query), diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index 000dbd8dc9..68d3549a08 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -71,6 +71,7 @@ class AuditError(SQLMeshError): def __init__( self, audit_name: str, + audit_args: t.Dict[t.Any, t.Any], count: int, query: exp.Query, model: t.Optional[Model] = None, @@ -78,14 +79,15 @@ def __init__( adapter_dialect: t.Optional[str] = None, ) -> None: self.audit_name = audit_name + self.audit_args = audit_args self.model = model self.count = count self.query = query self.adapter_dialect = adapter_dialect - def __str__(self) -> str: - model_str = f" for model '{self.model_name}'" if self.model_name else "" - return f"Audit '{self.audit_name}'{model_str} failed.\nGot {self.count} results, expected 0.\n{self.sql()}" + super().__init__( + f"'{self.audit_name}' audit error: {self.count} {'row' if self.count == 1 else 'rows'} failed" + ) @property def model_name(self) -> t.Optional[str]: @@ -106,6 +108,13 @@ def sql(self, dialect: t.Optional[str] = None, **opts: t.Any) -> str: return self.query.sql(dialect=dialect or self.adapter_dialect, **opts) +class NodeAuditsErrors(SQLMeshError): + def __init__(self, errors: t.List[AuditError]) -> None: + self.errors = errors + + super().__init__(f"Audits failed: {', '.join([e.audit_name for e in errors])}") + + class TestError(SQLMeshError): __test__ = False # prevent pytest trying to collect this as a test class pass diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 532722cd6b..f81b8b9e46 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -465,14 +465,10 @@ def test_override_builtin_audit_blocking_mode(): new_snapshot = next(iter(plan.context_diff.new_snapshots.values())) version = new_snapshot.fingerprint.to_version() - assert mock_logger.mock_calls == [ - call( - "Audit 'not_null' for model 'db.x' failed.\n" - "Got 1 results, expected 0.\n" - f'SELECT * FROM (SELECT * FROM "sqlmesh__db"."db__x__{version}" AS "db__x__{version}") AS "_q_0" WHERE "c" IS NULL AND TRUE\n' - "Audit is non-blocking so proceeding with execution." - ) - ] + assert ( + mock_logger.call_args_list[0][0][0] + == f'\'not_null\' audit error: 1 row failed. Audit is non-blocking so proceeding with execution.\nSELECT * FROM (SELECT * FROM "sqlmesh__db"."db__x__{version}" AS "db__x__{version}") AS "_q_0" WHERE "c" IS NULL AND TRUE\n' + ) # Even though there are two builtin audits referenced in the above definition, we only # store the one that overrides `blocking` in the snapshot; the other one isn't needed @@ -1372,6 +1368,8 @@ def test_plan_runs_audits_on_dev_previews(sushi_context: Context, capsys, caplog # we only see audit results if they fail stdout = capsys.readouterr().out log = caplog.text - assert "Audit 'not_null' for model 'sushi.test_audit_model' failed" in log - assert "Audit is non-blocking so proceeding with execution" in log + assert ( + "'not_null' audit error: 17 rows failed. Audit is non-blocking so proceeding with execution.\nSELECT" + in log + ) assert "Target environment updated successfully" in stdout diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index d558ff7e4a..7e3ed0f1ad 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -29,7 +29,7 @@ DeployabilityIndex, ) from sqlmesh.utils.date import to_datetime, to_timestamp, DatetimeRanges, TimeLike -from sqlmesh.utils.errors import CircuitBreakerError, AuditError +from sqlmesh.utils.errors import CircuitBreakerError, NodeAuditsErrors @pytest.fixture @@ -531,14 +531,28 @@ def _evaluate(): ) evaluator_audit_mock.return_value = [ - AuditResult(audit=audit, model=waiter_names.model, query=query, count=0, skipped=False) + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=0, + skipped=False, + ) ] _evaluate() assert notify_user_mock.call_count == 0 assert notify_mock.call_count == 0 evaluator_audit_mock.return_value = [ - AuditResult(audit=audit, model=waiter_names.model, query=query, count=None, skipped=True) + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=None, + skipped=True, + ) ] _evaluate() assert notify_user_mock.call_count == 0 @@ -547,6 +561,7 @@ def _evaluate(): evaluator_audit_mock.return_value = [ AuditResult( audit=audit, + audit_args={}, model=waiter_names.model, query=query, count=1, @@ -561,9 +576,16 @@ def _evaluate(): notify_mock.reset_mock() evaluator_audit_mock.return_value = [ - AuditResult(audit=audit, model=waiter_names.model, query=query, count=1, skipped=False) + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=1, + skipped=False, + ) ] - with pytest.raises(AuditError): + with pytest.raises(NodeAuditsErrors): _evaluate() assert notify_user_mock.call_count == 1 assert notify_mock.call_count == 1