Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MongoFramework/IMongoDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public interface IMongoDbSet
public interface IMongoDbSet<TEntity> : IMongoDbSet, IQueryable<TEntity> where TEntity : class
{
IMongoDbContext Context { get; }
TEntity Find(object id);
TEntity Create();
void Add(TEntity entity);
void AddRange(IEnumerable<TEntity> entities);
Expand Down
22 changes: 21 additions & 1 deletion src/MongoFramework/Infrastructure/EntityEntryContainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,28 @@ public EntityEntry GetEntry<TCollectionBase>(TCollectionBase entity)
}

return null;
}
}

public EntityEntry GetEntryById<TCollectionBase>(object id)
{
var collectionType = typeof(TCollectionBase);

if (EntryLookupByType.TryGetValue(collectionType, out var entries))
{
var entityDefinition = EntityMapping.GetOrCreateDefinition(collectionType);

foreach (var entry in entries)
{
var entryEntityId = entityDefinition.GetIdValue(entry.Entity);
if (entryEntityId.Equals(id))
{
return entry;
}
}
}

return null;
}
public EntityEntry SetEntityState<TCollectionBase>(TCollectionBase entity, EntityEntryState state) where TCollectionBase : class
{
if (entity is null)
Expand Down
44 changes: 41 additions & 3 deletions src/MongoFramework/MongoDbSet.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using MongoFramework.Infrastructure;
using MongoDB.Driver;
using MongoFramework.Infrastructure;
using MongoFramework.Infrastructure.Commands;
using MongoFramework.Infrastructure.Linq;
using MongoFramework.Infrastructure.Linq.Processors;
using MongoFramework.Infrastructure.Mapping;
using System;
using System.Collections;
using System.Collections.Generic;
Expand Down Expand Up @@ -30,8 +32,44 @@ public virtual TEntity Create()
var entity = Activator.CreateInstance<TEntity>();
Add(entity);
return entity;
}

}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public virtual TEntity Find(object id)
{
if (id == null)
throw new ArgumentNullException(nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
return tracked.Entity as TEntity;
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = collection.Find(filter);
var entity = cursor.FirstOrDefault();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}

/// <summary>
/// Marks the entity for insertion into the database.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,28 @@ public void EntryDoesntMatchOnEqualityOverride()
Assert.AreEqual(2, entryContainer.Entries().Count());
}

[TestMethod]
public void GetExistingEntryIdMatch()
{
var entryContainer = new EntityEntryContainer();
var entity = new EntityEntryContainerModel
{
Id = "123",
Title = "EntityEntryContainerTests.UpdateExistingEntryWithoutId"
};
entryContainer.SetEntityState(entity, EntityEntryState.Added);

var entry = entryContainer.GetEntryById<EntityEntryContainerModel>("123");
Assert.AreEqual(entity, entry.Entity);
Assert.AreEqual(EntityEntryState.Added, entry.State);

entryContainer.SetEntityState(entity, EntityEntryState.NoChanges);

entry = entryContainer.GetEntryById<EntityEntryContainerModel>("123");
Assert.AreEqual(entity, entry.Entity);
Assert.AreEqual(EntityEntryState.NoChanges, entry.State);
}

[TestMethod]
public void RemoveRange()
{
Expand Down
61 changes: 61 additions & 0 deletions tests/MongoFramework.Tests/MongoDbSetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,67 @@ public async Task SuccessfulInsertAndQueryBackAsync()
Assert.IsTrue(dbSet.Any(m => m.Description == "ValueAsync"));
}

[TestMethod]
public void SuccessfulInsertAndFind()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Description = "ValueSync"
};

dbSet.Add(model);

context.SaveChanges();

context = new MongoDbContext(connection);
dbSet = new MongoDbSet<TestModel>(context);
Assert.IsTrue(dbSet.Find(model.Id).Description == "ValueSync");
Assert.IsTrue(context.ChangeTracker.GetEntry(model).State == MongoFramework.Infrastructure.EntityEntryState.NoChanges);
}

[TestMethod]
public void SuccessfulNullFind()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Description = "ValueSync"
};

dbSet.Add(model);

context.SaveChanges();

Assert.IsNull(dbSet.Find("abcd"));
}

[TestMethod]
public void SuccessfullyFindTracked()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Id = "abcd",
Description = "ValueSync"
};

dbSet.Add(model);

//Note: not saving, but still should be found as tracked
Assert.IsTrue(dbSet.Find(model.Id).Description == "ValueSync");
Assert.IsTrue(context.ChangeTracker.GetEntry(model).State == MongoFramework.Infrastructure.EntityEntryState.Added);
}

[TestMethod]
public void SuccessfullyUpdateEntity()
{
Expand Down