Skip to content

Commit

Permalink
[BACKPORT] Swap parameters in Math.Log(double a, double newBase) tr…
Browse files Browse the repository at this point in the history
…anslation (#1894)

* Swap parameters in `Math.Log(double a, double newBase)` translation. (#1891)

(cherry picked from commit 5ca0fa1)
  • Loading branch information
lauxjpn authored Mar 18, 2024
1 parent ef39c34 commit 326b218
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 97 deletions.
195 changes: 99 additions & 96 deletions src/EFCore.MySql/Query/Internal/MySqlMathMethodTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Diagnostics;
Expand All @@ -13,97 +14,97 @@ namespace Pomelo.EntityFrameworkCore.MySql.Query.Internal
{
public class MySqlMathMethodTranslator : IMethodCallTranslator
{
private static readonly IDictionary<MethodInfo, (string Name, bool OnlyNullByArgs)> _methodToFunctionName = new Dictionary<MethodInfo, (string Name, bool OnlyNullByArgs)>
private static readonly IDictionary<MethodInfo, (string Name, bool OnlyNullByArgs, bool ReverseArgs)> _methodToFunctionName = new Dictionary<MethodInfo, (string Name, bool OnlyNullByArgs, bool ReverseArgs)>
{
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) }), ("ABS", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) }), ("ABS", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) }), ("ABS", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) }), ("ABS", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) }), ("ABS", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) }), ("ABS", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), new[] { typeof(float) }), ("ABS", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) }), ("ACOS", false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), new[] { typeof(float) }), ("ACOS", false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) }), ("ASIN", false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), new[] { typeof(float) }), ("ASIN", false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) }), ("ATAN", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), new[] { typeof(float) }), ("ATAN", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) }), ("ATAN2", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), new[] { typeof(float), typeof(float) }), ("ATAN2", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) }), ("CEILING", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) }), ("CEILING", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), new[] { typeof(float) }), ("CEILING", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) }), ("COS", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), new[] { typeof(float) }), ("COS", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) }), ("EXP", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), new[] { typeof(float) }), ("EXP", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) }), ("FLOOR", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) }), ("FLOOR", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), new[] { typeof(float) }), ("FLOOR", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) }), ("LOG", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double), typeof(double) }), ("LOG", false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float) }), ("LOG", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float), typeof(float) }), ("LOG", false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) }), ("LOG10", false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), new[] { typeof(float) }), ("LOG10", false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(decimal), typeof(decimal) }), ("GREATEST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) }), ("GREATEST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) }), ("GREATEST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(int), typeof(int) }), ("GREATEST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(long), typeof(long) }), ("GREATEST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(short), typeof(short) }), ("GREATEST", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) }), ("GREATEST", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(decimal), typeof(decimal) }), ("LEAST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(double), typeof(double) }), ("LEAST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(float), typeof(float) }), ("LEAST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(int), typeof(int) }), ("LEAST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(long), typeof(long) }), ("LEAST", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) }), ("LEAST", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Min), new[] { typeof(float), typeof(float) }), ("LEAST", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }), ("POWER", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), new[] { typeof(float), typeof(float) }), ("POWER", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(double) }), ("ROUND", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(double), typeof(int) }), ("ROUND", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal) }), ("ROUND", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal), typeof(int) }), ("ROUND", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), new[] { typeof(float) }), ("ROUND", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), new[] { typeof(float), typeof(int) }), ("ROUND", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) }), ("SIGN", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) }), ("SIGN", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), new[] { typeof(float) }), ("SIGN", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) }), ("SIN", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), new[] { typeof(float) }), ("SIN", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) }), ("SQRT", false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), new[] { typeof(float) }), ("SQRT", false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) }), ("TAN", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), new[] { typeof(float) }), ("TAN", true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), new[] { typeof(double) }), ("TRUNCATE", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), new[] { typeof(decimal) }), ("TRUNCATE", true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), new[] { typeof(float) }), ("TRUNCATE", true) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(decimal) }), ("ABS", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(double) }), ("ABS", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(float) }), ("ABS", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(int) }), ("ABS", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(long) }), ("ABS", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Abs), new[] { typeof(short) }), ("ABS", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Abs), new[] { typeof(float) }), ("ABS", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Acos), new[] { typeof(double) }), ("ACOS", false, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Acos), new[] { typeof(float) }), ("ACOS", false, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Asin), new[] { typeof(double) }), ("ASIN", false, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Asin), new[] { typeof(float) }), ("ASIN", false, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan), new[] { typeof(double) }), ("ATAN", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan), new[] { typeof(float) }), ("ATAN", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) }), ("ATAN2", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Atan2), new[] { typeof(float), typeof(float) }), ("ATAN2", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(decimal) }), ("CEILING", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Ceiling), new[] { typeof(double) }), ("CEILING", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Ceiling), new[] { typeof(float) }), ("CEILING", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Cos), new[] { typeof(double) }), ("COS", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Cos), new[] { typeof(float) }), ("COS", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Exp), new[] { typeof(double) }), ("EXP", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Exp), new[] { typeof(float) }), ("EXP", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(decimal) }), ("FLOOR", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Floor), new[] { typeof(double) }), ("FLOOR", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Floor), new[] { typeof(float) }), ("FLOOR", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double) }), ("LOG", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Log), new[] { typeof(double), typeof(double) }), ("LOG", false, true) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float) }), ("LOG", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log), new[] { typeof(float), typeof(float) }), ("LOG", false, true) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Log10), new[] { typeof(double) }), ("LOG10", false, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Log10), new[] { typeof(float) }), ("LOG10", false, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(decimal), typeof(decimal) }), ("GREATEST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) }), ("GREATEST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) }), ("GREATEST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(int), typeof(int) }), ("GREATEST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(long), typeof(long) }), ("GREATEST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Max), new[] { typeof(short), typeof(short) }), ("GREATEST", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Max), new[] { typeof(float), typeof(float) }), ("GREATEST", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(decimal), typeof(decimal) }), ("LEAST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(double), typeof(double) }), ("LEAST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(float), typeof(float) }), ("LEAST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(int), typeof(int) }), ("LEAST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(long), typeof(long) }), ("LEAST", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) }), ("LEAST", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Min), new[] { typeof(float), typeof(float) }), ("LEAST", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) }), ("POWER", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Pow), new[] { typeof(float), typeof(float) }), ("POWER", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(double) }), ("ROUND", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(double), typeof(int) }), ("ROUND", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal) }), ("ROUND", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Round), new[] { typeof(decimal), typeof(int) }), ("ROUND", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), new[] { typeof(float) }), ("ROUND", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Round), new[] { typeof(float), typeof(int) }), ("ROUND", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(decimal) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(double) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(float) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(int) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(long) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(sbyte) }), ("SIGN", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Sign), new[] { typeof(short) }), ("SIGN", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sign), new[] { typeof(float) }), ("SIGN", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sin), new[] { typeof(double) }), ("SIN", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sin), new[] { typeof(float) }), ("SIN", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Sqrt), new[] { typeof(double) }), ("SQRT", false, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Sqrt), new[] { typeof(float) }), ("SQRT", false, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Tan), new[] { typeof(double) }), ("TAN", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Tan), new[] { typeof(float) }), ("TAN", true, false) },

{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), new[] { typeof(double) }), ("TRUNCATE", true, false) },
{ typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), new[] { typeof(decimal) }), ("TRUNCATE", true, false) },
{ typeof(MathF).GetRuntimeMethod(nameof(MathF.Truncate), new[] { typeof(float) }), ("TRUNCATE", true, false) },
};

private readonly MySqlSqlExpressionFactory _sqlExpressionFactory;
Expand All @@ -127,18 +128,20 @@ public virtual SqlExpression Translate(
targetArgumentsCount = 2;
}

Debug.Assert(targetArgumentsCount is >= 1 and <= 2);

var newArguments = new SqlExpression[targetArgumentsCount];
newArguments[0] = arguments[0];

if (targetArgumentsCount == 2)
{
if (arguments.Count == 2)
{
newArguments[1] = arguments[1];
}
else
newArguments[1] = arguments.Count == 2
? arguments[1]
: _sqlExpressionFactory.Constant(0);

if (mapping.ReverseArgs)
{
newArguments[1] = _sqlExpressionFactory.Constant(0);
(newArguments[0], newArguments[1]) = (newArguments[1], newArguments[0]);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
using Pomelo.EntityFrameworkCore.MySql.FunctionalTests.TestUtilities;
using Xunit;

namespace Pomelo.EntityFrameworkCore.MySql.FunctionalTests.Query
Expand Down Expand Up @@ -788,5 +789,19 @@ await AssertQuery(
FROM `Customers` AS `c`
WHERE (LOCATE(CONVERT(LCASE('nt') USING utf8mb4) COLLATE utf8mb4_bin, LCASE(`c`.`CustomerID`)) - 1) = 1");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_math_log_new_base2(bool async)
{
// The original `Where_math_log_new_base` test will succeed even if the number and base are swapped by accident.
await AssertQueryScalar(
async,
ss => ss.Set<OrderDetail>().Where(od => od.OrderID == 11077 && od.Discount > 0).Where(od => Math.Log(od.Discount, 7) < -1).Select(od => Math.Log(od.Discount, 7)));

AssertSql($@"SELECT LOG(7.0, {CastAsDouble("`o`.`Discount`")})
FROM `Order Details` AS `o`
WHERE ((`o`.`OrderID` = 11077) AND (`o`.`Discount` > 0)) AND (LOG(7.0, {CastAsDouble("`o`.`Discount`")}) < -1.0)");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ public override async Task Where_math_log_new_base(bool async)
AssertSql(
$@"SELECT `o`.`OrderID`, `o`.`ProductID`, `o`.`Discount`, `o`.`Quantity`, `o`.`UnitPrice`
FROM `Order Details` AS `o`
WHERE ((`o`.`OrderID` = 11077) AND (`o`.`Discount` > 0)) AND (LOG({CastAsDouble("`o`.`Discount`")}, 7.0) < 0.0)");
WHERE ((`o`.`OrderID` = 11077) AND (`o`.`Discount` > 0)) AND (LOG(7.0, {CastAsDouble("`o`.`Discount`")}) < 0.0)");
}

public override async Task Where_math_sqrt(bool async)
Expand Down

0 comments on commit 326b218

Please sign in to comment.