Feat: improve audit error message formatting (#3818)
treysp authored Feb 12, 2025
1 parent 8a5b5f4 commit eb4b264
Showing 8 changed files with 88 additions and 28 deletions.
32 changes: 29 additions & 3 deletions sqlmesh/core/console.py
@@ -6,6 +6,7 @@
import unittest
import uuid
import logging
import textwrap

from hyperscript import h
from rich.console import Console as RichConsole
@@ -35,7 +36,11 @@
from sqlmesh.utils import rich as srich
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.date import time_like_to_str, to_date, yesterday_ds
from sqlmesh.utils.errors import PythonModelEvalError, format_destructive_change_msg
from sqlmesh.utils.errors import (
PythonModelEvalError,
NodeAuditsErrors,
format_destructive_change_msg,
)

if t.TYPE_CHECKING:
import ipywidgets as widgets
@@ -2448,9 +2453,13 @@ def _format_node_error(ex: NodeExecutionFailedError) -> str:

error_msg = str(cause)

if not isinstance(cause, (NodeExecutionFailedError, PythonModelEvalError)):
if isinstance(cause, NodeAuditsErrors):
error_msg = _format_audits_errors(cause)
elif not isinstance(cause, (NodeExecutionFailedError, PythonModelEvalError)):
error_msg = " " + error_msg.replace("\n", "\n ")
error_msg = f" {cause.__class__.__name__}:\n{error_msg}"
error_msg = (
f" {cause.__class__.__name__}:\n{error_msg}" # include error class name in msg
)
error_msg = error_msg.replace("\n", "\n ")
error_msg = error_msg + "\n" if not error_msg.rstrip(" ").endswith("\n") else error_msg

@@ -2474,3 +2483,20 @@ def _format_node_error(ex: NodeExecutionFailedError) -> str:
error_messages[node_name] = msg

return error_messages


def _format_audits_errors(error: NodeAuditsErrors) -> str:
error_messages = []
for err in error.errors:
audit_args_sql = []
for arg_name, arg_value in err.audit_args.items():
audit_args_sql.append(f"{arg_name} := {arg_value.sql(dialect=err.adapter_dialect)}")
audit_args_sql_msg = ("\n".join(audit_args_sql) + "\n\n") if audit_args_sql else ""

err_msg = f"'{err.audit_name}' audit error: {err.count} {'row' if err.count == 1 else 'rows'} failed"

query = "\n ".join(textwrap.wrap(err.sql(err.adapter_dialect), width=100))
msg = f"{err_msg}\n\nAudit arguments\n {audit_args_sql_msg}Audit query\n {query}\n\n"
msg = msg.replace("\n", "\n ")
error_messages.append(msg)
return " " + "\n".join(error_messages)
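
A minimal sketch of how the new _format_audits_errors helper behaves, assuming sqlglot is installed and the AuditError and NodeAuditsErrors signatures match this diff; the audit name, argument, and query below are invented for illustration:

from sqlglot import exp, parse_one

from sqlmesh.core.console import _format_audits_errors
from sqlmesh.utils.errors import AuditError, NodeAuditsErrors

# Build a single failing audit result wrapped in the new aggregate error type.
error = NodeAuditsErrors(
    [
        AuditError(
            audit_name="not_null",  # hypothetical audit name
            audit_args={"columns": exp.convert(["c"])},  # hypothetical argument
            count=2,
            query=parse_one('SELECT * FROM "db"."x" WHERE "c" IS NULL'),
            adapter_dialect="duckdb",
        )
    ]
)

print(_format_audits_errors(error))
# Expected to print something along these lines (indentation approximate):
#   'not_null' audit error: 2 rows failed
#
#   Audit arguments
#     columns := ['c']
#
#   Audit query
#     SELECT * FROM "db"."x" WHERE "c" IS NULL
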
2 changes: 2 additions & 0 deletions sqlmesh/core/model/definition.py
@@ -1771,6 +1771,8 @@ def depends_on_self(self) -> bool:
class AuditResult(PydanticModel):
audit: Audit
"""The audit this result is for."""
audit_args: t.Dict[t.Any, t.Any]
"""Arguments passed to the audit."""
model: t.Optional[_Model] = None
"""The model this audit is for."""
count: t.Optional[int] = None
14 changes: 7 additions & 7 deletions sqlmesh/core/scheduler.py
@@ -3,7 +3,6 @@
import logging
import typing as t
from sqlglot import exp

from sqlmesh.core import constants as c
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo
@@ -37,7 +36,7 @@
to_timestamp,
validate_date_range,
)
from sqlmesh.utils.errors import AuditError, CircuitBreakerError, SQLMeshError
from sqlmesh.utils.errors import AuditError, NodeAuditsErrors, CircuitBreakerError, SQLMeshError

logger = logging.getLogger(__name__)
SnapshotToIntervals = t.Dict[Snapshot, Intervals]
@@ -205,10 +204,11 @@ def evaluate(
**kwargs,
)

audit_error_to_raise: t.Optional[AuditError] = None
audit_errors_to_raise: t.List[AuditError] = []
for audit_result in (result for result in audit_results if result.count):
error = AuditError(
audit_name=audit_result.audit.name,
audit_args=audit_result.audit_args,
model=snapshot.model_or_none,
count=t.cast(int, audit_result.count),
query=t.cast(exp.Query, audit_result.query),
@@ -220,14 +220,14 @@
NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error
)
if audit_result.blocking:
audit_error_to_raise = error
audit_errors_to_raise.append(error)
else:
get_console().log_warning(
f"{error}\nAudit is non-blocking so proceeding with execution."
f"\n{error}. Audit is non-blocking so proceeding with execution. Audit query:\n{error.query.sql(error.adapter_dialect)}\n"
)

if audit_error_to_raise:
raise audit_error_to_raise
if audit_errors_to_raise:
raise NodeAuditsErrors(audit_errors_to_raise)

self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable)

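
For clarity, here is a simplified, standalone version of the loop added to Scheduler.evaluate above; notifications, model context, and the non-blocking warning path are omitted, and the helper name is made up for illustration:

import typing as t

from sqlmesh.core.model.definition import AuditResult
from sqlmesh.utils.errors import AuditError, NodeAuditsErrors

def raise_blocking_audit_failures(audit_results: t.List[AuditResult]) -> None:
    # Collect an AuditError for every failing blocking audit instead of raising
    # on the first one, then raise them together as a single NodeAuditsErrors.
    errors: t.List[AuditError] = []
    for result in (r for r in audit_results if r.count):
        error = AuditError(
            audit_name=result.audit.name,
            audit_args=result.audit_args,
            count=result.count,
            query=result.query,
        )
        if result.blocking:
            errors.append(error)
    if errors:
        raise NodeAuditsErrors(errors)
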
2 changes: 2 additions & 0 deletions sqlmesh/core/snapshot/evaluator.py
@@ -1018,6 +1018,7 @@ def _audit(
if audit.skip:
return AuditResult(
audit=audit,
audit_args=audit_args,
model=snapshot.model_or_none,
skipped=True,
)
@@ -1054,6 +1055,7 @@ def _audit(

return AuditResult(
audit=audit,
audit_args=audit_args,
model=snapshot.model_or_none,
count=count,
query=query,
1 change: 1 addition & 0 deletions sqlmesh/engines/commands.py
@@ -103,6 +103,7 @@ def evaluate(
if failed_audit_result:
raise AuditError(
audit_name=failed_audit_result.audit.name,
audit_args=failed_audit_result.audit_args,
model=command_payload.snapshot.model_or_none,
count=t.cast(int, failed_audit_result.count),
query=t.cast(exp.Query, failed_audit_result.query),
15 changes: 12 additions & 3 deletions sqlmesh/utils/errors.py
@@ -71,21 +71,23 @@ class AuditError(SQLMeshError):
def __init__(
self,
audit_name: str,
audit_args: t.Dict[t.Any, t.Any],
count: int,
query: exp.Query,
model: t.Optional[Model] = None,
# the dialect of the engine adapter that evaluated the audit query
adapter_dialect: t.Optional[str] = None,
) -> None:
self.audit_name = audit_name
self.audit_args = audit_args
self.model = model
self.count = count
self.query = query
self.adapter_dialect = adapter_dialect

def __str__(self) -> str:
model_str = f" for model '{self.model_name}'" if self.model_name else ""
return f"Audit '{self.audit_name}'{model_str} failed.\nGot {self.count} results, expected 0.\n{self.sql()}"
super().__init__(
f"'{self.audit_name}' audit error: {self.count} {'row' if self.count == 1 else 'rows'} failed"
)

@property
def model_name(self) -> t.Optional[str]:
@@ -106,6 +108,13 @@ def sql(self, dialect: t.Optional[str] = None, **opts: t.Any) -> str:
return self.query.sql(dialect=dialect or self.adapter_dialect, **opts)


class NodeAuditsErrors(SQLMeshError):
def __init__(self, errors: t.List[AuditError]) -> None:
self.errors = errors

super().__init__(f"Audits failed: {', '.join([e.audit_name for e in errors])}")


class TestError(SQLMeshError):
__test__ = False # prevent pytest trying to collect this as a test class
pass
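
A quick sketch of the resulting exception messages, assuming the constructors behave exactly as shown in the diff above (the audit name and query are invented for the example):

from sqlglot import parse_one

from sqlmesh.utils.errors import AuditError, NodeAuditsErrors

failed = AuditError(
    audit_name="unique_values",  # hypothetical audit name
    audit_args={},
    count=1,
    query=parse_one("SELECT 1"),
    adapter_dialect="duckdb",
)
print(str(failed))  # 'unique_values' audit error: 1 row failed

grouped = NodeAuditsErrors([failed, failed])
print(str(grouped))  # Audits failed: unique_values, unique_values
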
18 changes: 8 additions & 10 deletions tests/core/test_context.py
@@ -465,14 +465,10 @@ def test_override_builtin_audit_blocking_mode():
new_snapshot = next(iter(plan.context_diff.new_snapshots.values()))

version = new_snapshot.fingerprint.to_version()
assert mock_logger.mock_calls == [
call(
"Audit 'not_null' for model 'db.x' failed.\n"
"Got 1 results, expected 0.\n"
f'SELECT * FROM (SELECT * FROM "sqlmesh__db"."db__x__{version}" AS "db__x__{version}") AS "_q_0" WHERE "c" IS NULL AND TRUE\n'
"Audit is non-blocking so proceeding with execution."
)
]
assert (
mock_logger.call_args_list[0][0][0]
== f'\n\'not_null\' audit error: 1 row failed. Audit is non-blocking so proceeding with execution. Audit query:\nSELECT * FROM (SELECT * FROM "sqlmesh__db"."db__x__{version}" AS "db__x__{version}") AS "_q_0" WHERE "c" IS NULL AND TRUE\n'
)

# Even though there are two builtin audits referenced in the above definition, we only
# store the one that overrides `blocking` in the snapshot; the other one isn't needed
@@ -1372,6 +1368,8 @@ def test_plan_runs_audits_on_dev_previews(sushi_context: Context, capsys, caplog
# we only see audit results if they fail
stdout = capsys.readouterr().out
log = caplog.text
assert "Audit 'not_null' for model 'sushi.test_audit_model' failed" in log
assert "Audit is non-blocking so proceeding with execution" in log
assert (
"\n'not_null' audit error: 22 rows failed. Audit is non-blocking so proceeding with execution. Audit query:\nSELECT"
in log
)
assert "Target environment updated successfully" in stdout
32 changes: 27 additions & 5 deletions tests/core/test_scheduler.py
@@ -29,7 +29,7 @@
DeployabilityIndex,
)
from sqlmesh.utils.date import to_datetime, to_timestamp, DatetimeRanges, TimeLike
from sqlmesh.utils.errors import CircuitBreakerError, AuditError
from sqlmesh.utils.errors import CircuitBreakerError, NodeAuditsErrors


@pytest.fixture
@@ -531,14 +531,28 @@ def _evaluate():
)

evaluator_audit_mock.return_value = [
AuditResult(audit=audit, model=waiter_names.model, query=query, count=0, skipped=False)
AuditResult(
audit=audit,
audit_args={},
model=waiter_names.model,
query=query,
count=0,
skipped=False,
)
]
_evaluate()
assert notify_user_mock.call_count == 0
assert notify_mock.call_count == 0

evaluator_audit_mock.return_value = [
AuditResult(audit=audit, model=waiter_names.model, query=query, count=None, skipped=True)
AuditResult(
audit=audit,
audit_args={},
model=waiter_names.model,
query=query,
count=None,
skipped=True,
)
]
_evaluate()
assert notify_user_mock.call_count == 0
@@ -547,6 +561,7 @@ def _evaluate():
evaluator_audit_mock.return_value = [
AuditResult(
audit=audit,
audit_args={},
model=waiter_names.model,
query=query,
count=1,
@@ -561,9 +576,16 @@ def _evaluate():
notify_mock.reset_mock()

evaluator_audit_mock.return_value = [
AuditResult(audit=audit, model=waiter_names.model, query=query, count=1, skipped=False)
AuditResult(
audit=audit,
audit_args={},
model=waiter_names.model,
query=query,
count=1,
skipped=False,
)
]
with pytest.raises(AuditError):
with pytest.raises(NodeAuditsErrors):
_evaluate()
assert notify_user_mock.call_count == 1
assert notify_mock.call_count == 1
