Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public async Task<IEnumerable<string>> GetFeaturesThatNeedUpdateAsync()
return CreateUpgradeLookupTable(dataMigration).ContainsKey(record.Version.Value);
}

return GetMethod(dataMigration, "Create") != null;
return GetCreateMethod(dataMigration) != null;
});

return outOfDateMigrations.Select(m => _typeFeatureProvider.GetFeatureForDependency(m.GetType()).Id).ToArray();
Expand All @@ -103,11 +103,22 @@ public async Task Uninstall(string feature)
// get current version for this migration
var dataMigrationRecord = await GetDataMigrationRecordAsync(tempMigration);

var uninstallMethod = GetMethod(migration, "Uninstall");
var uninstallMethod = GetUninstallMethod(migration);

if (uninstallMethod != null)
{
await InvokeMethodAsync(uninstallMethod, migration);
if (uninstallMethod.ReturnType == typeof(Task))
{
await (Task)uninstallMethod.Invoke(migration, []);
}
else if (uninstallMethod.ReturnType == typeof(void))
{
uninstallMethod.Invoke(migration, []);
}
else
{
throw new InvalidOperationException("Invalid return type used in a migration method.");
}
}

if (dataMigrationRecord == null)
Expand Down Expand Up @@ -201,15 +212,15 @@ private async Task UpdateAsync(string featureId)
if (current == 0)
{
// Try to get a Create method.
var createMethod = GetMethod(migration, "Create");
var createMethod = GetCreateMethod(migration);

if (createMethod == null)
{
_logger.LogWarning("The migration '{Name}' for '{FeatureName}' does not contain a proper Create or CreateAsync method.", migration.GetType().FullName, featureId);
continue;
}

current = await InvokeMethodAsync(createMethod, migration);
current = await InvokeCreateOrUpdateMethodAsync(createMethod, migration);
}

var lookupTable = CreateUpgradeLookupTable(migration);
Expand All @@ -218,7 +229,7 @@ private async Task UpdateAsync(string featureId)
{
_logger.LogInformation("Applying migration for '{Migration}' in '{FeatureId}' from version {Version}.", migration.GetType().FullName, featureId, current);

current = await InvokeMethodAsync(methodInfo, migration);
current = await InvokeCreateOrUpdateMethodAsync(methodInfo, migration);
}

// If current is 0, it means no upgrade/create method was found or succeeded.
Expand All @@ -243,7 +254,7 @@ private async Task UpdateAsync(string featureId)
}
}

private static async Task<int> InvokeMethodAsync(MethodInfo method, IDataMigration migration)
private static async Task<int> InvokeCreateOrUpdateMethodAsync(MethodInfo method, IDataMigration migration)
{
if (method.ReturnType == typeof(Task<int>))
{
Expand Down Expand Up @@ -307,21 +318,19 @@ private static Tuple<int, MethodInfo> GetUpdateFromMethod(MethodInfo methodInfo)
return null;
}

/// <summary>
/// Returns the method from a data migration class that matches the given name if found.
/// </summary>
private static MethodInfo GetMethod(IDataMigration dataMigration, string name)
private static MethodInfo GetCreateMethod(IDataMigration dataMigration)
{
var methodName = "Create";
// First try to find a method that match the given name. (Ex. Create())
var methodInfo = dataMigration.GetType().GetMethod(name, BindingFlags.Public | BindingFlags.Instance);
var methodInfo = dataMigration.GetType().GetMethod(methodName, BindingFlags.Public | BindingFlags.Instance);

if (methodInfo != null && (methodInfo.ReturnType == typeof(int) || methodInfo.ReturnType == typeof(Task<int>)))
{
return methodInfo;
}

// At this point, try to find a method that matches the given name and ends with Async. (Ex. CreateAsync())
methodInfo = dataMigration.GetType().GetMethod(name + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);
methodInfo = dataMigration.GetType().GetMethod(methodName + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);

if (methodInfo != null && methodInfo.ReturnType == typeof(Task<int>))
{
Expand All @@ -330,4 +339,24 @@ private static MethodInfo GetMethod(IDataMigration dataMigration, string name)

return null;
}

private static MethodInfo GetUninstallMethod(IDataMigration dataMigration)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could merge both methods instead

{
var methodName = "Uninstall";
var methodInfo = dataMigration.GetType().GetMethod(methodName, BindingFlags.Public | BindingFlags.Instance);

if (methodInfo != null && (methodInfo.ReturnType == typeof(void) || methodInfo.ReturnType == typeof(Task)))
{
return methodInfo;
}

methodInfo = dataMigration.GetType().GetMethod(methodName + _asyncSuffix, BindingFlags.Public | BindingFlags.Instance);

if (methodInfo != null && methodInfo.ReturnType == typeof(Task))
{
return methodInfo;
}

return null;
}
}
221 changes: 221 additions & 0 deletions test/OrchardCore.Tests/Data/Migration/DataMigrationManagerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
using System.Data.Common;
using OrchardCore.Environment.Extensions;
using OrchardCore.Environment.Extensions.Features;
using ISession = YesSql.ISession;

namespace OrchardCore.Data.Migration.Tests;

public class DataMigrationManagerTests
{
[Fact]
public async Task UpdateAsync_ShouldExecuteDataMigration_CreateMethod_OnFreshMigration()
{
// Arrange
var migration1 = new Migration1();
var migration2 = new Migration2();
var migrationManager = GetDataMigrationManager([migration1, migration2]);

// Act
await migrationManager.UpdateAsync("TestFeature");

// Assert
Assert.True(migration1.CreateCalled);
Assert.True(migration2.CreateCalled);
}

[Fact]
public async Task UpdateAsync_ShouldExecuteDataMigration_UpdateFromMethods()
{
// Arrange
var migration1 = new Migration1();
var migration2 = new Migration2();
var migrationManager = GetDataMigrationManager([migration1, migration2]);

// Act
await migrationManager.UpdateAsync("TestFeature");

// Assert
Assert.Equal(2, migration1.UpdateFromCalls);
Assert.Equal(0, migration2.UpdateFromCalls);
}

[Fact]
public async Task Uninstall_ShouldExecuteDataMigration_UninstallMethod()
{
// Arrange
var migration1 = new Migration1();
var migration2 = new Migration2();
var migrationManager = GetDataMigrationManager([migration1, migration2]);

// Act
await migrationManager.Uninstall("TestFeature");

// Assert
Assert.True(migration1.UninstallCalled);
Assert.True(migration2.UninstallCalled);
}

private static DataMigrationManager GetDataMigrationManager(IEnumerable<DataMigration> dataMigrations)
{
var featureInfo = new Mock<IFeatureInfo>();
featureInfo.Setup(f => f.Id).Returns("TestFeature");

var typeFeatureProviderMock = new Mock<ITypeFeatureProvider>();
typeFeatureProviderMock.Setup(m => m.GetFeatureForDependency(It.IsAny<Type>()))
.Returns(featureInfo.Object);

var extensionManagerMock = new Mock<IExtensionManager>();
extensionManagerMock.Setup(m => m.GetFeatureDependencies(It.IsAny<string>()))
.Returns(Enumerable.Empty<IFeatureInfo>());

var sessionMock = new Mock<ISession>();
sessionMock.Setup(s => s.BeginTransactionAsync())
.ReturnsAsync(Mock.Of<DbTransaction>());

sessionMock.Setup(s => s.Query())
.Returns(new FakeQuery());

sessionMock.Setup(s => s.SaveAsync(It.IsAny<object>()))
.Returns(Task.CompletedTask);

var storeMock = new Mock<IStore>();
storeMock.Setup(s => s.Configuration).Returns(new Configuration());

return new DataMigrationManager(
typeFeatureProviderMock.Object,
dataMigrations,
sessionMock.Object,
storeMock.Object,
extensionManagerMock.Object,
NullLogger<DataMigrationManager>.Instance);
}

private sealed class Migration1 : DataMigration
{
public bool CreateCalled { get; private set; }

public bool UninstallCalled { get; private set; }

public int UpdateFromCalls { get; private set; }

public int Create()
{
CreateCalled = true;

return 1;
}

public int UpdateFrom1()
{
++UpdateFromCalls;

return 2;
}

public Task<int> UpdateFrom2Async()
{
++UpdateFromCalls;

return Task.FromResult(3);
}

#pragma warning disable CA1822 // Mark members as static
public int UpdateFromInvalid() => 0;
#pragma warning restore CA1822 // Mark members as static

public void Uninstall() => UninstallCalled = true;
}

private sealed class Migration2 : DataMigration
{
public bool CreateCalled { get; private set; }

public bool UninstallCalled { get; private set; }

public int UpdateFromCalls { get; private set; }

public Task<int> CreateAsync()
{
CreateCalled = true;

return Task.FromResult(1);
}

public Task UninstallAsync()
{
UninstallCalled = true;

return Task.CompletedTask;
}
}

private sealed class FakeQuery : IQuery
{
public IQuery<object> Any()
=> throw new NotImplementedException();

public IQuery<T> For<T>(bool filterType = true) where T : class => new FakeQuery<T>();

IQueryIndex<T> IQuery.ForIndex<T>()
=> throw new NotImplementedException();
}

private sealed class FakeQuery<T> : IQuery<T> where T : class
{
public IQuery<T> All(params Func<IQuery<T>, IQuery<T>>[] predicates)
=> throw new NotImplementedException();

public ValueTask<IQuery<T>> AllAsync(params Func<IQuery<T>, ValueTask<IQuery<T>>>[] predicates)
=> throw new NotImplementedException();

public IQuery<T> Any(params Func<IQuery<T>, IQuery<T>>[] predicates)
=> throw new NotImplementedException();

public ValueTask<IQuery<T>> AnyAsync(params Func<IQuery<T>, ValueTask<IQuery<T>>>[] predicates)
=> throw new NotImplementedException();

public Task<int> CountAsync(CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public Task<int> CountAsync()
=> throw new NotImplementedException();

public Task<T> FirstOrDefaultAsync(CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public Task<T> FirstOrDefaultAsync() => Task.FromResult((T)null);

public string GetTypeAlias(Type t)
=> throw new NotImplementedException();

public Task<IEnumerable<T>> ListAsync(CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public Task<IEnumerable<T>> ListAsync()
=> throw new NotImplementedException();

public IQuery<T> NoDuplicates()
=> throw new NotImplementedException();

public IQuery<T> Skip(int count)
=> throw new NotImplementedException();

public IQuery<T> Take(int count)
=> throw new NotImplementedException();

public IAsyncEnumerable<T> ToAsyncEnumerable(CancellationToken cancellationToken = default)
=> throw new NotImplementedException();

public IAsyncEnumerable<T> ToAsyncEnumerable()
=> throw new NotImplementedException();

public IQuery<T> With(Type indexType)
=> throw new NotImplementedException();

IQuery<T, TIndex> IQuery<T>.With<TIndex>()
=> throw new NotImplementedException();

IQuery<T, TIndex> IQuery<T>.With<TIndex>(System.Linq.Expressions.Expression<Func<TIndex, bool>> predicate)
=> throw new NotImplementedException();
}
}