Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MongoFramework/IMongoDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public interface IMongoDbSet<TEntity> : IMongoDbSet, IQueryable<TEntity> where T
{
IMongoDbContext Context { get; }
TEntity Find(object id);
ValueTask<TEntity> FindAsync(object id);
TEntity Create();
void Add(TEntity entity);
void AddRange(IEnumerable<TEntity> entities);
Expand Down
110 changes: 73 additions & 37 deletions src/MongoFramework/MongoDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,79 @@ public virtual TEntity Create()
var entity = Activator.CreateInstance<TEntity>();
Add(entity);
return entity;
}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public virtual TEntity Find(object id)
{
Check.NotNull(id, nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
return tracked.Entity as TEntity;
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = collection.Find(filter);
var entity = cursor.FirstOrDefault();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}

}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public virtual TEntity Find(object id)
{
Check.NotNull(id, nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
return tracked.Entity as TEntity;
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = collection.Find(filter);
var entity = cursor.FirstOrDefault();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public virtual async ValueTask<TEntity> FindAsync(object id)
{
Check.NotNull(id, nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
return tracked.Entity as TEntity;
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = await collection.FindAsync(filter);
var entity = await cursor.FirstOrDefaultAsync();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}


/// <summary>
/// Marks the entity for insertion into the database.
/// </summary>
Expand Down
78 changes: 78 additions & 0 deletions src/MongoFramework/MongoDbTenantSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using MongoDB.Driver;
using MongoFramework.Infrastructure;
using MongoFramework.Infrastructure.Commands;
using MongoFramework.Infrastructure.Mapping;
using MongoFramework.Utilities;

namespace MongoFramework
Expand Down Expand Up @@ -43,6 +48,79 @@ protected virtual void CheckEntities(IEnumerable<TEntity> entities)
}
}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public override TEntity Find(object id)
{
Check.NotNull(id, nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
if ((tracked.Entity as IHaveTenantId)?.TenantId == Context.TenantId)
{
return tracked.Entity as TEntity;
}
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id, Context.TenantId);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = collection.Find(filter);
var entity = cursor.FirstOrDefault();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
/// and this entity, if found, is attached to the context and returned. If no entity is found, then
/// null is returned.
/// </summary>
/// <param name="id">The value of the primary key for the entity to be found.</param>
/// <returns>The entity found, or null.</returns>
public override async ValueTask<TEntity> FindAsync(object id)
{
Check.NotNull(id, nameof(id));

var tracked = Context.ChangeTracker.GetEntryById<TEntity>(id);

if (tracked != null)
{
return tracked.Entity as TEntity;
}

var entityDefinition = EntityMapping.GetOrCreateDefinition(typeof(TEntity));
var filter = entityDefinition.CreateIdFilter<TEntity>(id, Context.TenantId);

var collection = Context.Connection.GetDatabase().GetCollection<TEntity>(entityDefinition.CollectionName);
var cursor = await collection.FindAsync(filter);
var entity = await cursor.FirstOrDefaultAsync();

if (entity != null)
{
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

return entity;
}

public override void Add(TEntity entity)
{
Check.NotNull(entity, nameof(entity));
Expand Down
73 changes: 72 additions & 1 deletion tests/MongoFramework.Tests/MongoDbSetTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.ComponentModel.DataAnnotations;
using System.Linq;
Expand Down Expand Up @@ -122,6 +122,77 @@ public void FindRequiresId()
Assert.ThrowsException<ArgumentNullException>(() => dbSet.Find(null));
}

[TestMethod]
public async Task SuccessfulInsertAndFindAsync()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Description = "SuccessfulInsertAndFind"
};

dbSet.Add(model);

context.SaveChanges();

context = new MongoDbContext(connection);
dbSet = new MongoDbSet<TestModel>(context);
Assert.AreEqual("SuccessfulInsertAndFind", (await dbSet.FindAsync(model.Id)).Description);
Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(model).State);
}

[TestMethod]
public async Task SuccessfulNullFindAsync()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Description = "SuccessfulNullFind"
};

dbSet.Add(model);

context.SaveChanges();

Assert.IsNull(await dbSet.FindAsync("abcd"));
}

[TestMethod]
public async Task SuccessfullyFindAsyncTracked()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

var model = new TestModel
{
Id = "abcd",
Description = "SuccessfullyFindTracked"
};

dbSet.Add(model);

//Note: not saving, but still should be found as tracked
Assert.AreEqual("SuccessfullyFindTracked", (await dbSet.FindAsync(model.Id)).Description);
Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Added, context.ChangeTracker.GetEntry(model).State);
}

[TestMethod]
public async Task FindAsyncRequiresId()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<TestModel>(context);

await Assert.ThrowsExceptionAsync<ArgumentNullException>(async () => await dbSet.FindAsync(null));
}

[TestMethod]
public void SuccessfullyUpdateEntity()
{
Expand Down
Loading