diff --git a/src/MongoFramework/IMongoDbSet.cs b/src/MongoFramework/IMongoDbSet.cs index 46da466c..e087b728 100644 --- a/src/MongoFramework/IMongoDbSet.cs +++ b/src/MongoFramework/IMongoDbSet.cs @@ -18,6 +18,7 @@ public interface IMongoDbSet : IMongoDbSet, IQueryable where T { IMongoDbContext Context { get; } TEntity Find(object id); + ValueTask FindAsync(object id); TEntity Create(); void Add(TEntity entity); void AddRange(IEnumerable entities); diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index 45561f1f..b2019853 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -34,43 +34,79 @@ public virtual TEntity Create() var entity = Activator.CreateInstance(); Add(entity); return entity; - } - - /// - /// Finds an entity with the given primary key value. If an entity with the given primary key value - /// is being tracked by the context, then it is returned immediately without making a request to the - /// database. Otherwise, a query is made to the database for an entity with the given primary key value - /// and this entity, if found, is attached to the context and returned. If no entity is found, then - /// null is returned. - /// - /// The value of the primary key for the entity to be found. - /// The entity found, or null. - public virtual TEntity Find(object id) - { - Check.NotNull(id, nameof(id)); - - var tracked = Context.ChangeTracker.GetEntryById(id); - - if (tracked != null) - { - return tracked.Entity as TEntity; - } - - var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); - var filter = entityDefinition.CreateIdFilter(id); - - var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); - var cursor = collection.Find(filter); - var entity = cursor.FirstOrDefault(); - - if (entity != null) - { - Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); - } - - return entity; - } - + } + + /// + /// Finds an entity with the given primary key value. If an entity with the given primary key value + /// is being tracked by the context, then it is returned immediately without making a request to the + /// database. Otherwise, a query is made to the database for an entity with the given primary key value + /// and this entity, if found, is attached to the context and returned. If no entity is found, then + /// null is returned. + /// + /// The value of the primary key for the entity to be found. + /// The entity found, or null. + public virtual TEntity Find(object id) + { + Check.NotNull(id, nameof(id)); + + var tracked = Context.ChangeTracker.GetEntryById(id); + + if (tracked != null) + { + return tracked.Entity as TEntity; + } + + var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); + var filter = entityDefinition.CreateIdFilter(id); + + var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); + var cursor = collection.Find(filter); + var entity = cursor.FirstOrDefault(); + + if (entity != null) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + + return entity; + } + + /// + /// Finds an entity with the given primary key value. If an entity with the given primary key value + /// is being tracked by the context, then it is returned immediately without making a request to the + /// database. Otherwise, a query is made to the database for an entity with the given primary key value + /// and this entity, if found, is attached to the context and returned. If no entity is found, then + /// null is returned. + /// + /// The value of the primary key for the entity to be found. + /// The entity found, or null. + public virtual async ValueTask FindAsync(object id) + { + Check.NotNull(id, nameof(id)); + + var tracked = Context.ChangeTracker.GetEntryById(id); + + if (tracked != null) + { + return tracked.Entity as TEntity; + } + + var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); + var filter = entityDefinition.CreateIdFilter(id); + + var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); + var cursor = await collection.FindAsync(filter); + var entity = await cursor.FirstOrDefaultAsync(); + + if (entity != null) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + + return entity; + } + + /// /// Marks the entity for insertion into the database. /// diff --git a/src/MongoFramework/MongoDbTenantSet.cs b/src/MongoFramework/MongoDbTenantSet.cs index dbd5a845..42872142 100644 --- a/src/MongoFramework/MongoDbTenantSet.cs +++ b/src/MongoFramework/MongoDbTenantSet.cs @@ -5,6 +5,11 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Threading.Tasks; +using MongoDB.Driver; +using MongoFramework.Infrastructure; +using MongoFramework.Infrastructure.Commands; +using MongoFramework.Infrastructure.Mapping; using MongoFramework.Utilities; namespace MongoFramework @@ -43,6 +48,79 @@ protected virtual void CheckEntities(IEnumerable entities) } } + /// + /// Finds an entity with the given primary key value. If an entity with the given primary key value + /// is being tracked by the context, then it is returned immediately without making a request to the + /// database. Otherwise, a query is made to the database for an entity with the given primary key value + /// and this entity, if found, is attached to the context and returned. If no entity is found, then + /// null is returned. + /// + /// The value of the primary key for the entity to be found. + /// The entity found, or null. + public override TEntity Find(object id) + { + Check.NotNull(id, nameof(id)); + + var tracked = Context.ChangeTracker.GetEntryById(id); + + if (tracked != null) + { + if ((tracked.Entity as IHaveTenantId)?.TenantId == Context.TenantId) + { + return tracked.Entity as TEntity; + } + } + + var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); + var filter = entityDefinition.CreateIdFilter(id, Context.TenantId); + + var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); + var cursor = collection.Find(filter); + var entity = cursor.FirstOrDefault(); + + if (entity != null) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + + return entity; + } + + /// + /// Finds an entity with the given primary key value. If an entity with the given primary key value + /// is being tracked by the context, then it is returned immediately without making a request to the + /// database. Otherwise, a query is made to the database for an entity with the given primary key value + /// and this entity, if found, is attached to the context and returned. If no entity is found, then + /// null is returned. + /// + /// The value of the primary key for the entity to be found. + /// The entity found, or null. + public override async ValueTask FindAsync(object id) + { + Check.NotNull(id, nameof(id)); + + var tracked = Context.ChangeTracker.GetEntryById(id); + + if (tracked != null) + { + return tracked.Entity as TEntity; + } + + var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); + var filter = entityDefinition.CreateIdFilter(id, Context.TenantId); + + var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); + var cursor = await collection.FindAsync(filter); + var entity = await cursor.FirstOrDefaultAsync(); + + if (entity != null) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + + return entity; + } + public override void Add(TEntity entity) { Check.NotNull(entity, nameof(entity)); diff --git a/tests/MongoFramework.Tests/MongoDbSetTests.cs b/tests/MongoFramework.Tests/MongoDbSetTests.cs index b76fe5d7..c7b2b727 100644 --- a/tests/MongoFramework.Tests/MongoDbSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbSetTests.cs @@ -1,4 +1,4 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.ComponentModel.DataAnnotations; using System.Linq; @@ -122,6 +122,77 @@ public void FindRequiresId() Assert.ThrowsException(() => dbSet.Find(null)); } + [TestMethod] + public async Task SuccessfulInsertAndFindAsync() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Description = "SuccessfulInsertAndFind" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + Assert.AreEqual("SuccessfulInsertAndFind", (await dbSet.FindAsync(model.Id)).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(model).State); + } + + [TestMethod] + public async Task SuccessfulNullFindAsync() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Description = "SuccessfulNullFind" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + Assert.IsNull(await dbSet.FindAsync("abcd")); + } + + [TestMethod] + public async Task SuccessfullyFindAsyncTracked() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked" + }; + + dbSet.Add(model); + + //Note: not saving, but still should be found as tracked + Assert.AreEqual("SuccessfullyFindTracked", (await dbSet.FindAsync(model.Id)).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model).State); + } + + [TestMethod] + public async Task FindAsyncRequiresId() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + await Assert.ThrowsExceptionAsync(async () => await dbSet.FindAsync(null)); + } + [TestMethod] public void SuccessfullyUpdateEntity() { diff --git a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs index b5be474d..35c26643 100644 --- a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs @@ -128,6 +128,238 @@ public async Task SuccessfulInsertAndQueryBackAsync() Assert.IsTrue(dbSet.Any(m => m.TenantId == tenantId)); } + [TestMethod] + public void SuccessfulInsertAndFind() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.2"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + Assert.AreEqual("SuccessfulInsertAndFind.1", dbSet.Find(entity1.Id).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(entity1).State); + } + + [TestMethod] + public void SuccessfulNullFind() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.2"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + Assert.IsNull(dbSet.Find("abcd")); + } + + [TestMethod] + public void BlocksWrongTenantFind() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.2"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + Assert.IsNull(dbSet.Find(entity2.Id)); + } + + [TestMethod] + public void SuccessfullyFindTracked() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + var model2 = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.2" + }; + + dbSet.Add(model); + dbSet2.Add(model2); + + //Note: not saving, but still should be found as tracked + Assert.AreEqual("SuccessfullyFindTracked.1", dbSet.Find(model.Id).Description); + Assert.AreEqual("SuccessfullyFindTracked.2", dbSet2.Find(model2.Id).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model2).State); + } + + [TestMethod] + public void FindRequiresId() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + Assert.ThrowsException(() => dbSet.Find(null)); + } + + [TestMethod] + public async Task SuccessfulInsertAndFindAsync() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.2"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + Assert.AreEqual("SuccessfulInsertAndFind.1", (await dbSet.FindAsync(entity1.Id)).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(entity1).State); + } + + [TestMethod] + public async Task SuccessfulNullFindAsync() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + Assert.IsNull(await dbSet.FindAsync("abcd")); + } + + [TestMethod] + public async Task BlocksWrongTenantFindAsync() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var entity1 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + var entity2 = new TestModel {Description = "SuccessfulInsertAndFind.1"}; + + dbSet.Add(entity1); + dbSet2.Add(entity2); + + context.SaveChanges(); + context2.SaveChanges(); + + Assert.IsNull(await dbSet.FindAsync(entity2.Id)); + } + + [TestMethod] + public async Task SuccessfullyFindAsyncTracked() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var context2 = new MongoDbTenantContext(connection, tenantId + "-alt"); + var dbSet2 = new MongoDbTenantSet(context2); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + var model2 = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.2" + }; + + dbSet.Add(model); + dbSet2.Add(model2); + + //Note: not saving, but still should be found as tracked + Assert.AreEqual("SuccessfullyFindTracked.1", (await dbSet.FindAsync(model.Id)).Description); + Assert.AreEqual("SuccessfullyFindTracked.2", (await dbSet2.FindAsync(model2.Id)).Description); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model2).State); + } + + [TestMethod] + public async Task FindAsyncRequiresId() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + await Assert.ThrowsExceptionAsync(async () => await dbSet.FindAsync(null)); + } + [TestMethod] public void SuccessfullyUpdateEntity() {