Skip to content

Commit

Permalink
Fix bug where we don't notice modules with generic attributes.
Browse files Browse the repository at this point in the history
Refactor and unit tests so this won't happen in the future.
  • Loading branch information
YairHalberstadt committed Nov 24, 2021
1 parent 5aae204 commit 8d4fcf6
Show file tree
Hide file tree
Showing 15 changed files with 431 additions and 167 deletions.
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<LangVersion>preview</LangVersion>
<Nullable>enable</Nullable>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<NoWarn>CS1591</NoWarn>
<NoWarn>CS1591;RS1024</NoWarn>
<WarningsNotAsErrors>AD0001</WarningsNotAsErrors>
<EnforceCodeStyleInBuild>true</EnforceCodeStyleInBuild>
<AnalysisMode>AllEnabledByDefault</AnalysisMode>
Expand Down
10 changes: 5 additions & 5 deletions StrongInject.Generator/ContainerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ private ContainerGenerator(

_containerInterfaces = _container.AllInterfaces
.Where(x
=> x.OriginalDefinition.Equals(_wellKnownTypes.IContainer, SymbolEqualityComparer.Default)
|| x.OriginalDefinition.Equals(_wellKnownTypes.IAsyncContainer, SymbolEqualityComparer.Default))
.Select(x => (containerInterface: x, isAsync: x.OriginalDefinition.Equals(_wellKnownTypes.IAsyncContainer, SymbolEqualityComparer.Default)))
=> x.OriginalDefinition.Equals(_wellKnownTypes.IContainer)
|| x.OriginalDefinition.Equals(_wellKnownTypes.IAsyncContainer))
.Select(x => (containerInterface: x, isAsync: x.OriginalDefinition.Equals(_wellKnownTypes.IAsyncContainer)))
.ToList();

foreach (var (_, isAsync) in _containerInterfaces)
Expand Down Expand Up @@ -422,9 +422,9 @@ private void CreateVariables(
{
FactorySource => true,
FactoryMethod { Method: { ReturnType: var returnType } }
=> returnType.OriginalDefinition.Equals(_wellKnownTypes.ValueTask1, SymbolEqualityComparer.Default),
=> returnType.OriginalDefinition.Equals(_wellKnownTypes.ValueTask1),
WrappedDecoratorInstanceSource { Decorator: DecoratorFactoryMethod { Method: { ReturnType: var returnType } } }
=> returnType.OriginalDefinition.Equals(_wellKnownTypes.ValueTask1, SymbolEqualityComparer.Default),
=> returnType.OriginalDefinition.Equals(_wellKnownTypes.ValueTask1),
_ => throw new NotImplementedException(source.GetType().ToString())
},
_ => throw new NotImplementedException(operation.Statement.GetType().ToString()),
Expand Down
4 changes: 2 additions & 2 deletions StrongInject.Generator/GenericDecoratorsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal class GenericDecoratorsResolver

public GenericDecoratorsResolver(Compilation compilation, IEnumerable<DecoratorSource> decoratorFactoryMethods)
{
_namedTypeDecoratorSources = new Dictionary<INamedTypeSymbol, List<DecoratorSource>>(SymbolEqualityComparer.Default);
_namedTypeDecoratorSources = new Dictionary<INamedTypeSymbol, List<DecoratorSource>>();
_arrayDecoratorSources = new List<DecoratorSource>();
_typeParameterDecoratorSources = new List<DecoratorSource>();

Expand Down Expand Up @@ -55,7 +55,7 @@ public IEnumerable<DecoratorSource> ResolveDecorators(ITypeSymbol type)
var constructed = decoratorRegistration.Type.Construct(namedType.TypeArguments.ToArray());
var originalConstructor = decoratorRegistration.Constructor;
var constructor = constructed.InstanceConstructors.First(
x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, originalConstructor));
x => x.OriginalDefinition.Equals(originalConstructor));

yield return decoratorRegistration with
{
Expand Down
12 changes: 6 additions & 6 deletions StrongInject.Generator/GenericRegistrationsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ public GenericRegistrationsResolver Build(Compilation compilation)

static (Dictionary<INamedTypeSymbol, Bucket> namedTypeBuckets, Bucket? otherTypesBucket, Bucket? typeParameterBucket) Partition(Builder builder, Compilation compilation)
{
Dictionary<INamedTypeSymbol, BucketBuilder> namedTypeBucketBuilders = new(SymbolEqualityComparer.Default);
Dictionary<INamedTypeSymbol, BucketBuilder> namedTypeBucketBuilders = new();
BucketBuilder? otherTypesBucketBuilder = null;
BucketBuilder? typeParameterBucketBuilder = null;

Expand Down Expand Up @@ -320,11 +320,11 @@ public bool TryResolve(ITypeSymbol type, out InstanceSource instanceSource, out

foreach (var registration in _registrations)
{
if (registration.Type.OriginalDefinition.Equals(type.OriginalDefinition, SymbolEqualityComparer.Default))
if (registration.Type.OriginalDefinition.Equals(type.OriginalDefinition))
{
var originalConstructor = registration.Constructor.OriginalDefinition;
var constructor = ((INamedTypeSymbol)type).InstanceConstructors.First(
x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, originalConstructor));
x => x.OriginalDefinition.Equals(originalConstructor));

var updatedRegistration = registration with
{
Expand All @@ -346,7 +346,7 @@ public bool TryResolve(ITypeSymbol type, out InstanceSource instanceSource, out

foreach (var forwardedInstanceSource in _forwardedInstanceSources)
{
if (forwardedInstanceSource.AsType.OriginalDefinition.Equals(type.OriginalDefinition, SymbolEqualityComparer.Default))
if (forwardedInstanceSource.AsType.OriginalDefinition.Equals(type.OriginalDefinition))
{
if (forwardedInstanceSource.Underlying is Registration registration)
{
Expand All @@ -357,7 +357,7 @@ public bool TryResolve(ITypeSymbol type, out InstanceSource instanceSource, out
var constructedRegistrationType =
registration.Type.OriginalDefinition.Construct(typeArguments.ToArray());
var constructor = constructedRegistrationType.InstanceConstructors.First(
x => SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, originalConstructor));
x => x.OriginalDefinition.Equals(originalConstructor));

var updatedRegistration = registration with
{
Expand Down Expand Up @@ -460,7 +460,7 @@ private IEnumerable<FactoryMethod> GetAllRelevantFactoryMethods(ITypeSymbol toCo

private static bool IsRelevant(FactoryOfMethod factoryOfMethod, ITypeSymbol toConstruct)
{
return factoryOfMethod.FactoryOfType.OriginalDefinition.Equals(toConstruct.OriginalDefinition, SymbolEqualityComparer.Default);
return factoryOfMethod.FactoryOfType.OriginalDefinition.Equals(toConstruct.OriginalDefinition);
}

private bool CanConstructFromGenericFactoryMethod(ITypeSymbol toConstruct, FactoryMethod factoryMethod, out FactoryMethod constructedFactoryMethod, out bool constraintsDoNotMatch)
Expand Down
2 changes: 1 addition & 1 deletion StrongInject.Generator/GenericResolutionHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ static bool CanConstructFrom(ITypeSymbol toConstruct, ITypeSymbol toConstructFro
}

if (callingConvention != toConstructCallingConvention
|| !unmanagedCallingConventionTypes.SequenceEqual<INamedTypeSymbol, INamedTypeSymbol>(toConstructUnmanagedCallingConventionTypes, SymbolEqualityComparer.Default)
|| !unmanagedCallingConventionTypes.SequenceEqual(toConstructUnmanagedCallingConventionTypes)
|| refKind != toConstructRefKind
|| parameters.Length != toConstructParameters.Length
|| !CanConstructFrom(toConstructReturnType, returnType, method, ref typeArguments))
Expand Down
2 changes: 1 addition & 1 deletion StrongInject.Generator/InstanceSourcesScope.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public InstanceSourcesScope Enter(InstanceSource instanceSource)
case DelegateSource { Parameters: var parameters }:
var newDepth = Depth + 1;
var delegateParameters = _delegateParameters is null
? new Dictionary<ITypeSymbol, DelegateParameter>(SymbolEqualityComparer.Default)
? new Dictionary<ITypeSymbol, DelegateParameter>()
: new Dictionary<ITypeSymbol, DelegateParameter>(_delegateParameters);
foreach (var param in parameters)
{
Expand Down
Loading

0 comments on commit 8d4fcf6

Please sign in to comment.