diff --git a/src/MongoFramework/IMongoDbSet.cs b/src/MongoFramework/IMongoDbSet.cs index 291cccad..46da466c 100644 --- a/src/MongoFramework/IMongoDbSet.cs +++ b/src/MongoFramework/IMongoDbSet.cs @@ -17,6 +17,7 @@ public interface IMongoDbSet public interface IMongoDbSet : IMongoDbSet, IQueryable where TEntity : class { IMongoDbContext Context { get; } + TEntity Find(object id); TEntity Create(); void Add(TEntity entity); void AddRange(IEnumerable entities); diff --git a/src/MongoFramework/Infrastructure/EntityEntryContainer.cs b/src/MongoFramework/Infrastructure/EntityEntryContainer.cs index daa07112..32f7f42b 100644 --- a/src/MongoFramework/Infrastructure/EntityEntryContainer.cs +++ b/src/MongoFramework/Infrastructure/EntityEntryContainer.cs @@ -66,8 +66,28 @@ public EntityEntry GetEntry(TCollectionBase entity) } return null; - } + } + + public EntityEntry GetEntryById(object id) + { + var collectionType = typeof(TCollectionBase); + + if (EntryLookupByType.TryGetValue(collectionType, out var entries)) + { + var entityDefinition = EntityMapping.GetOrCreateDefinition(collectionType); + + foreach (var entry in entries) + { + var entryEntityId = entityDefinition.GetIdValue(entry.Entity); + if (entryEntityId.Equals(id)) + { + return entry; + } + } + } + return null; + } public EntityEntry SetEntityState(TCollectionBase entity, EntityEntryState state) where TCollectionBase : class { if (entity is null) diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index bde83ecb..74cff0f4 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -1,7 +1,9 @@ -using MongoFramework.Infrastructure; +using MongoDB.Driver; +using MongoFramework.Infrastructure; using MongoFramework.Infrastructure.Commands; using MongoFramework.Infrastructure.Linq; using MongoFramework.Infrastructure.Linq.Processors; +using MongoFramework.Infrastructure.Mapping; using System; using System.Collections; using System.Collections.Generic; @@ -30,8 +32,44 @@ public virtual TEntity Create() var entity = Activator.CreateInstance(); Add(entity); return entity; - } - + } + + /// + /// Finds an entity with the given primary key value. If an entity with the given primary key value + /// is being tracked by the context, then it is returned immediately without making a request to the + /// database. Otherwise, a query is made to the database for an entity with the given primary key value + /// and this entity, if found, is attached to the context and returned. If no entity is found, then + /// null is returned. + /// + /// The value of the primary key for the entity to be found. + /// The entity found, or null. + public virtual TEntity Find(object id) + { + if (id == null) + throw new ArgumentNullException(nameof(id)); + + var tracked = Context.ChangeTracker.GetEntryById(id); + + if (tracked != null) + { + return tracked.Entity as TEntity; + } + + var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity)); + var filter = entityDefinition.CreateIdFilter(id); + + var collection = Context.Connection.GetDatabase().GetCollection(entityDefinition.CollectionName); + var cursor = collection.Find(filter); + var entity = cursor.FirstOrDefault(); + + if (entity != null) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + + return entity; + } + /// /// Marks the entity for insertion into the database. /// diff --git a/tests/MongoFramework.Tests/Infrastructure/EntityEntryContainerTests.cs b/tests/MongoFramework.Tests/Infrastructure/EntityEntryContainerTests.cs index f85ef3a0..ca2a1369 100644 --- a/tests/MongoFramework.Tests/Infrastructure/EntityEntryContainerTests.cs +++ b/tests/MongoFramework.Tests/Infrastructure/EntityEntryContainerTests.cs @@ -119,6 +119,28 @@ public void EntryDoesntMatchOnEqualityOverride() Assert.AreEqual(2, entryContainer.Entries().Count()); } + [TestMethod] + public void GetExistingEntryIdMatch() + { + var entryContainer = new EntityEntryContainer(); + var entity = new EntityEntryContainerModel + { + Id = "123", + Title = "EntityEntryContainerTests.UpdateExistingEntryWithoutId" + }; + entryContainer.SetEntityState(entity, EntityEntryState.Added); + + var entry = entryContainer.GetEntryById("123"); + Assert.AreEqual(entity, entry.Entity); + Assert.AreEqual(EntityEntryState.Added, entry.State); + + entryContainer.SetEntityState(entity, EntityEntryState.NoChanges); + + entry = entryContainer.GetEntryById("123"); + Assert.AreEqual(entity, entry.Entity); + Assert.AreEqual(EntityEntryState.NoChanges, entry.State); + } + [TestMethod] public void RemoveRange() { diff --git a/tests/MongoFramework.Tests/MongoDbSetTests.cs b/tests/MongoFramework.Tests/MongoDbSetTests.cs index e8acb86b..9864939c 100644 --- a/tests/MongoFramework.Tests/MongoDbSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbSetTests.cs @@ -50,6 +50,67 @@ public async Task SuccessfulInsertAndQueryBackAsync() Assert.IsTrue(dbSet.Any(m => m.Description == "ValueAsync")); } + [TestMethod] + public void SuccessfulInsertAndFind() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Description = "ValueSync" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + Assert.IsTrue(dbSet.Find(model.Id).Description == "ValueSync"); + Assert.IsTrue(context.ChangeTracker.GetEntry(model).State == MongoFramework.Infrastructure.EntityEntryState.NoChanges); + } + + [TestMethod] + public void SuccessfulNullFind() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Description = "ValueSync" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + Assert.IsNull(dbSet.Find("abcd")); + } + + [TestMethod] + public void SuccessfullyFindTracked() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "ValueSync" + }; + + dbSet.Add(model); + + //Note: not saving, but still should be found as tracked + Assert.IsTrue(dbSet.Find(model.Id).Description == "ValueSync"); + Assert.IsTrue(context.ChangeTracker.GetEntry(model).State == MongoFramework.Infrastructure.EntityEntryState.Added); + } + [TestMethod] public void SuccessfullyUpdateEntity() {