AddOrUpdate() method in ASP.NET Core 2 - Entity Framework

How can I access the AddOrUpdate() method when I'm using ASP.NET Core 2?
Entity Framework 6 has AddOrUpdate() in the System.Data.Entity.Migrations namespace, but when I want to use this method in ASP.NET Core 2, I cannot find it.

This may be what you want
/// <summary>
/// EF Core counterparts of EF6's <c>AddOrUpdate</c>: look an entity up by its primary key and
/// either add it or mark it as updated (optionally restricted to selected properties).
/// </summary>
public static class DbSetExtension
{
    /// <summary>
    /// Finds the stored entity whose primary key values match those of <paramref name="entity"/>.
    /// </summary>
    /// <param name="dbSet">Set to search.</param>
    /// <param name="entity">Entity whose key values drive the lookup.</param>
    /// <param name="noTracking">When <c>true</c>, a found instance is detached before it is returned.</param>
    /// <returns>The matching entity, or <c>null</c> when none exists.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <remarks>
    /// Key values are read through the change-tracker entry rather than CLR reflection, because
    /// <c>IProperty.PropertyInfo</c> is <c>null</c> for shadow properties.
    /// NOTE(review): <c>Find</c> returns an instance the context already tracks when one has the same
    /// key; with <paramref name="noTracking"/> that instance is detached as well.
    /// </remarks>
    public static TEntity FindEntity<TEntity>(this DbSet<TEntity> dbSet, TEntity entity, bool noTracking = false) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }
        var dbContext = dbSet.GetService<ICurrentDbContext>().Context;
        var entityEntry = dbContext.Entry(entity);
        var primaryKey = entityEntry.Metadata.FindPrimaryKey();
        if (primaryKey == null)
        {
            // No primary key: fall back to an equality query.
            return (noTracking ? dbSet.AsNoTracking() : dbSet).FirstOrDefault(item => item.Equals(entity));
        }
        var ids = primaryKey.Properties.Select(item => entityEntry.Property(item.Name).CurrentValue).ToArray();
        var result = dbSet.Find(ids);
        if (noTracking && result != null)
        {
            dbContext.Entry(result).State = EntityState.Detached;
        }
        return result;
    }

    /// <summary>
    /// Asynchronous version of <see cref="FindEntity{TEntity}"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public static async ValueTask<TEntity> FindEntityAsync<TEntity>(this DbSet<TEntity> dbSet, TEntity entity, bool noTracking = false)
        where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }
        var dbContext = dbSet.GetService<ICurrentDbContext>().Context;
        var entityEntry = dbContext.Entry(entity);
        var primaryKey = entityEntry.Metadata.FindPrimaryKey();
        if (primaryKey == null)
        {
            return await (noTracking ? dbSet.AsNoTracking() : dbSet).FirstOrDefaultAsync(item => item.Equals(entity));
        }
        // Shadow-safe key extraction (see FindEntity).
        var ids = primaryKey.Properties.Select(item => entityEntry.Property(item.Name).CurrentValue).ToArray();
        var result = await dbSet.FindAsync(ids);
        if (noTracking && result != null)
        {
            dbContext.Entry(result).State = EntityState.Detached;
        }
        return result;
    }

    /// <summary>
    /// Marks <paramref name="entity"/> as Modified, optionally limiting which properties are written.
    /// </summary>
    /// <param name="includeOrExclude">
    /// <c>null</c>: update every property. <c>true</c>: update only <paramref name="propertyNames"/>.
    /// <c>false</c>: update everything except <paramref name="propertyNames"/>.
    /// </param>
    /// <param name="propertyNames">Property names the include/exclude mode applies to.</param>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public static EntityEntry<TEntity> Update<TEntity>(this DbSet<TEntity> dbSet, TEntity entity, bool? includeOrExclude = null,
        params string[] propertyNames) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }
        // Instance DbSet.Update(TEntity) wins overload resolution, so this is not recursive.
        var entityEntry = dbSet.Update(entity);
        if (includeOrExclude != null)
        {
            foreach (var property in entityEntry.Properties)
            {
                // Metadata.Name also works for shadow properties, where Metadata.PropertyInfo is null.
                var listed = propertyNames?.Contains(property.Metadata.Name) == true;
                // Include mode unmarks unlisted properties; exclude mode unmarks listed ones.
                if (includeOrExclude.Value != listed)
                {
                    property.IsModified = false;
                }
            }
        }
        return entityEntry;
    }

    /// <summary>
    /// Adds <paramref name="entity"/> when no row with its key exists; otherwise marks it as updated
    /// via <see cref="Update{TEntity}"/> with the same include/exclude options.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public static EntityEntry<TEntity> AddOrUpdate<TEntity>(this DbSet<TEntity> dbSet, TEntity entity, bool? includeOrExclude = null,
        params string[] propertyNames) where TEntity : class
    {
        if (dbSet.FindEntity(entity, true) == null)
        {
            return dbSet.Add(entity);
        }
        return dbSet.Update(entity, includeOrExclude, propertyNames);
    }

    /// <summary>
    /// Asynchronous version of <see cref="AddOrUpdate{TEntity}"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    public static async ValueTask<EntityEntry<TEntity>> AddOrUpdateAsync<TEntity>(this DbSet<TEntity> dbSet, TEntity entity, bool? includeOrExclude = null,
        params string[] propertyNames) where TEntity : class
    {
        if (await dbSet.FindEntityAsync(entity, true) == null)
        {
            return await dbSet.AddAsync(entity);
        }
        return dbSet.Update(entity, includeOrExclude, propertyNames);
    }
}
I've omitted the comments for brevity.
Use it like this:
var people = new People
{
    Id = 1,
    Name = "Tom",
    ChangeTime = DateTimeOffset.Now,
    AddTime = DateTimeOffset.Now,
};
// AddOrUpdateAsync extends DbSet<T>, not DbContext, so call it on the entity set.
// Exclude mode (false): on update, every property except AddTime is written.
await dbContext.Set<People>().AddOrUpdateAsync(people, false, nameof(people.AddTime));
await dbContext.SaveChangesAsync();
or
// Include mode (true): on update, only Name and ChangeTime are written.
// The extension lives on DbSet<T>, so it is called on the entity set.
await dbContext.Set<People>().AddOrUpdateAsync(people, true, nameof(people.Name), nameof(people.ChangeTime));
await dbContext.SaveChangesAsync();
You can also use its synchronous counterpart, AddOrUpdate.
I tested it with "Microsoft.EntityFrameworkCore 2.1.1" and "Microsoft.EntityFrameworkCore 3.1.7", and it works normally.
If the entity uses auto-generated key values, you may also need: Saving an explicit value during add.
See also: Saving single entities.
If there is any omission, please let me know. Thank you.

Related

How to use 2 expressions in EF Repository Pattern

I have this method in my Generic EF Repository class
public class GenericRepository<T> : IGenericRepository<T> where T : class
{
    /// <summary>
    /// Returns every entity matching <paramref name="expression"/>, eager-loading the given
    /// navigation properties first.
    /// </summary>
    /// <param name="expression">Filter to apply; must not be <c>null</c>.</param>
    /// <param name="includes">
    /// Navigation properties to eager-load, e.g. <c>x =&gt; x.PropertyChargeType</c>. Declared
    /// <c>params</c> so callers can list them inline; <c>null</c> or empty means "no includes".
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is <c>null</c>.</exception>
    /// <remarks>
    /// NOTE(review): the method-level &lt;T&gt; shadows the class-level T (warning CS0693); it is kept
    /// so existing calls with an explicit type argument still compile. <c>context</c> is not declared in
    /// this snippet — presumably a field of the full class.
    /// </remarks>
    public async Task<IEnumerable<T>> GetAllByExpression<T>(Expression<Func<T, bool>> expression,
        params Expression<Func<T, object>>[] includes) where T : class
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression), "Invalid expression.");
        }
        IQueryable<T> query = context.Set<T>();
        if (includes != null)
        {
            foreach (var include in includes)
            {
                query = query.Include(include);
            }
        }
        return await query.Where(expression).ToListAsync();
    }
}
I don't know how to pass the "Include" to it. Basically, what I want is something like this:
// Eager-load each rate's charge type while restricting to the requested building.
return await context.PropertyChargesRates
    .Include(x => x.PropertyChargeType)
    .Where(x => x.BuildingId == buildingId)
    .ToListAsync();
This is how I'm trying to use this method:
/// <summary>
/// Loads the charge rates of <paramref name="buildingId"/> together with their charge types.
/// </summary>
/// <returns>The rates, or <c>null</c> if the query failed.</returns>
public async Task<IEnumerable<PropertyChargesRate>> GetPropertyChargeRates(int buildingId)
{
    try
    {
        var g = new GenericRepository<PropertyChargesRate>(context);
        // The repository call returns a Task, so it must be awaited; navigations to eager-load are
        // passed as an array of include expressions.
        return await g.GetAllByExpression<PropertyChargesRate>(
            x => x.BuildingId == buildingId,
            new System.Linq.Expressions.Expression<Func<PropertyChargesRate, object>>[] { x => x.PropertyChargeType });
    }
    catch (Exception)
    {
        // NOTE(review): failures are swallowed and reported as null; consider logging or rethrowing.
        return null;
    }
}
The implementation you're looking for could look like this:
/// <summary>
/// Returns every <typeparamref name="T"/> in <paramref name="context"/> matching
/// <paramref name="expression"/>, after letting <paramref name="includes"/> add eager loading.
/// </summary>
/// <param name="context">Context to query.</param>
/// <param name="expression">Filter to apply; must not be <c>null</c>.</param>
/// <param name="includes">
/// Optional query shaper, e.g. <c>q =&gt; q.Include(i =&gt; i.Nav1).Include(i =&gt; i.Nav2)</c>;
/// <c>null</c> means no includes.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="expression"/> is <c>null</c>.</exception>
public IEnumerable<T> GetAllByExpression<T>(
    YourDBContext context,
    Expression<Func<T, bool>> expression,
    Func<IQueryable<T>, IQueryable<T>> includes = null) where T : class
{
    if (expression == null)
    {
        throw new ArgumentNullException(nameof(expression), "Invalid expression.");
    }
    IQueryable<T> query = context.Set<T>();
    if (includes != null)
    {
        query = includes(query);
    }
    return query.Where(expression).ToList();
}
And usage would look like this:
// The third argument shapes the query before filtering: chain one Include per navigation.
repository.GetAllByExpression<YourEntity>(
    YourDbContext,
    YourWhereExpression,
    query => query
        .Include(entity => entity.YourInclude1)
        .Include(entity => entity.YourInclude2));

Questions about repository pattern with Entity Framework Core

I have created an API that is using EF Core with a repository pattern and I have few questions:
The POST method receives an email address and verifies whether the user exists or not.
If the email address does not exist in the User table, get the guest access details from the AccessManagement table, save them in the Entitlement table, and return the details.
If the entry exists, get the user access details and return them
IGeneralRepository:
/// <summary>
/// Generic data-access contract for one entity type.
/// NOTE(review): the name is misspelled ("Genreal"); it is kept because implementers reference it.
/// </summary>
/// <typeparam name="TEntity">Entity type; a class with a parameterless constructor.</typeparam>
public interface IGenrealRepository<TEntity>
    where TEntity : class, new()
{
    /// <summary>Queryable over all entities of this type.</summary>
    IQueryable<TEntity> GetAll();

    /// <summary>Adds one entity and returns it.</summary>
    Task<TEntity> AddAsync(TEntity entity);

    /// <summary>Adds several entities and returns them.</summary>
    Task<TEntity[]> AddRangeAsync(TEntity[] entity);

    /// <summary>Marks an entity as updated and returns it.</summary>
    TEntity Update(TEntity entity);

    /// <summary>Persists pending changes; presumably returns the affected row count.</summary>
    Task<int> CompleteAsync();
}
General repository:
/// <summary>
/// EF Core implementation of <c>IGenrealRepository</c> over <c>MyDbContext</c>. Add/Update only stage
/// changes in the context; <see cref="CompleteAsync"/> writes them with SaveChangesAsync.
/// </summary>
public class GeneralRepository<TEntity> : IGenrealRepository<TEntity> where TEntity : class, new()
{
    private readonly MyDbContext _myDbContext;

    /// <exception cref="ArgumentNullException"><paramref name="myDbContext"/> is <c>null</c>.</exception>
    public GeneralRepository(MyDbContext myDbContext)
    {
        _myDbContext = myDbContext ?? throw new ArgumentNullException(nameof(myDbContext));
    }

    /// <summary>Stages <paramref name="entity"/> for insertion and returns it.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The context rejected the entity (inner exception holds the cause).</exception>
    public async Task<TEntity> AddAsync(TEntity entity)
    {
        if (entity == null)
        {
            // First argument is the parameter name; the message goes second.
            throw new ArgumentNullException(nameof(entity), $"{nameof(AddAsync)} entity must not be null");
        }
        try
        {
            await _myDbContext.AddAsync(entity);
            return entity;
        }
        catch (Exception ex)
        {
            // Keep the original as InnerException so its type and stack trace are not lost.
            throw new InvalidOperationException($"{typeof(TEntity).Name} could not be saved: {ex.Message}", ex);
        }
    }

    /// <summary>Stages all <paramref name="entity"/> items for insertion and returns the array.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The context rejected the entities.</exception>
    public async Task<TEntity[]> AddRangeAsync(TEntity[] entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity), $"{nameof(AddRangeAsync)} entity must not be null");
        }
        try
        {
            await _myDbContext.AddRangeAsync(entity);
            return entity;
        }
        catch (Exception ex)
        {
            throw new InvalidOperationException($"{typeof(TEntity).Name} could not be saved: {ex.Message}", ex);
        }
    }

    /// <summary>Saves all staged changes and returns SaveChangesAsync's result.</summary>
    public async Task<int> CompleteAsync()
    {
        return await _myDbContext.SaveChangesAsync();
    }

    /// <summary>Returns the entity set as a composable query.</summary>
    /// <exception cref="InvalidOperationException">The set could not be obtained.</exception>
    public IQueryable<TEntity> GetAll()
    {
        try
        {
            return _myDbContext.Set<TEntity>();
        }
        catch (Exception ex)
        {
            throw new InvalidOperationException($"Couldn't retrieve entities: {ex.Message}", ex);
        }
    }

    /// <summary>Marks <paramref name="entity"/> as modified and returns it.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">The context rejected the update.</exception>
    public TEntity Update(TEntity entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity), $"{nameof(Update)} entity must not be null");
        }
        try
        {
            _myDbContext.Update<TEntity>(entity);
            return entity;
        }
        catch (Exception ex)
        {
            throw new InvalidOperationException($"{typeof(TEntity).Name} could not be updated: {ex.Message}", ex);
        }
    }
}
IUserService:
/// <summary>
/// User operations used by the CreateUser API action.
/// </summary>
public interface IUserService
{
    /// <summary>Creates or loads the user identified by <paramref name="emailId"/>.</summary>
    Task<User> CreateUser(string emailId);

    /// <summary>Commits pending changes; presumably returns the affected row count.</summary>
    Task<int> Complete();
}
UserService implementation:
/// <summary>
/// Creates users on first sight of an email address (seeding them with the default entitlements)
/// and otherwise returns the existing user with that user's entitlements.
/// </summary>
public class UserService : IUserService
{
    private readonly IUserRepository _userRepository;
    private readonly IAccessManagementRepository _accessManagementRepository;

    public UserService(IUserRepository userRepository, IAccessManagementRepository accessManagementRepository)
    {
        _userRepository = userRepository;
        _accessManagementRepository = accessManagementRepository;
    }

    /// <summary>Commits pending changes through the user repository.</summary>
    public async Task<int> Complete()
    {
        return await _userRepository.CompleteAsync();
    }

    /// <summary>
    /// Returns the user whose email matches <paramref name="emailId"/> case-insensitively, creating
    /// one with the default entitlements when none exists. Changes are saved by <see cref="Complete"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="emailId"/> is <c>null</c>.</exception>
    public async Task<User> CreateUser(string emailId)
    {
        if (emailId == null)
        {
            throw new ArgumentNullException(nameof(emailId));
        }
        // Upper-cased once here; the query's parameter value is the same as before.
        var normalizedEmail = emailId.ToUpper();
        var user = await _userRepository.GetAll()
            .Where(x => x.EmailId.ToUpper() == normalizedEmail)
            .FirstOrDefaultAsync();
        if (user == null)
        {
            // Default access rights come from the access-management repository.
            // NOTE(review): copying AccessManagement.Id into UserEntitlement.Id may clash with the
            // entitlement's own key — confirm against the model.
            var entitlements = await _accessManagementRepository.GetAll()
                .Where(x => x.Default == true)
                .Select(x => new UserEntitlement()
                {
                    Id = x.Id,
                    AccessName = x.AccessName
                }).ToListAsync();
            //saving User and Entitlement
            user = new User()
            {
                EmailId = emailId,
                UserEntitlements = entitlements
            };
            user = await _userRepository.AddAsync(user);
        }
        else
        {
            // Keep only this user's entitlements; previously the full list for every user was assigned.
            var allEntitlements = await _userRepository.GetAllUserEntitilements();
            user.UserEntitlements = allEntitlements.FindAll(x => x.UserId == user.UserId);
        }
        return user;
    }
}
API call:
/// <summary>
/// Creates (or loads) the user for the posted email address and returns 201 pointing at GetUser.
/// Returns 400 when the body or its email address is missing.
/// </summary>
[HttpPost]
public async Task<IActionResult> CreateUser([FromBody] User user)
{
    // Guard against a missing body or email instead of failing later with a NullReferenceException.
    if (user == null || string.IsNullOrWhiteSpace(user.EmailId))
    {
        return BadRequest();
    }
    var result = await _userService.CreateUser(user.EmailId);
    await _userService.Complete();
    return CreatedAtAction(nameof(GetUser), new { emailId = result.EmailId }, result);
}
Questions:
Is my method UserService.CreateUser() implementation correct? Any better approach?
Is the code below the best approach to filtering?
// Case-insensitive email lookup; the predicate overload of FirstOrDefault is equivalent to Where + FirstOrDefault.
var user = _userRepository
    .GetAll()
    .FirstOrDefault(x => x.EmailId.ToUpper() == emailId.ToUpper());
How can I get data from the User and Entitlement tables in one go? Something like the Include below, but I cannot use Include because of an error:
// Include takes a lambda selecting the navigation property, not a type argument.
var user = _userRepository.GetAll()
    .Include(x => x.UserEntitlements)
    .Where(x => x.EmailId.ToUpper() == emailId.ToUpper())
    .FirstOrDefault();
How can I insert into one table and update another table in a single transaction?
Leo,
I prefer validating the email outside the CreateUser method.
That calls for another method on IUserService, GetUserByEmail, which gets the user by email.
That way, you can return a proper error or validation message before invoking CreateUser in the API call.
For example
/// <summary>
/// Sketch: look the user up first so an existing registration can be handled before CreateUser runs.
/// </summary>
[HttpPost]
public async Task<IActionResult> CreateUser([FromBody] User user)
{
    // A distinct local name: redeclaring "user" would clash with the parameter (CS0136).
    var existingUser = await _userService.GetUserByEmail(user.EmailId);
    // or var userRegistered = await _userService.UserExistsByEmail(user.EmailId) returning a bool
    // user registered?
    if (existingUser != null)
    {
        // The user already exists, return an error or
        // You could update the UserEntitlements here or you could
        // make an HttpPut where the user is updated do nothing here
    }
    // ... otherwise create the user as in the original action
}
An example of loading the user together with its entitlements via Include:
// Eager-load the entitlements and fetch the first user whose email matches case-insensitively.
var user = _userRepository
    .GetAll()
    .Include(x => x.UserEntitlements)
    .FirstOrDefault(x => x.EmailId.ToUpper() == emailId.ToUpper());
For inserting into one table and updating another in a single transaction, you can use the Unit of Work pattern:
Repository Pattern and Unit of Work

Preventing default values being used for keys in Entity Framework Core?

I would like to prevent a type's default value from being used as a key in Entity Framework Core — for example, 00000000-0000-0000-0000-000000000000 for Guids, 0 for ints, etc.
Use this helper class:
/// <summary>
/// Rejects Added/Modified entities whose key properties still hold their type's default value
/// (e.g. <c>0</c>, <c>Guid.Empty</c>, <c>null</c>).
/// </summary>
static class KeyValidator
{
    /// <summary>
    /// Throws if any Added or Modified entity tracked by <paramref name="context"/> has a property of
    /// any of its keys (primary or alternate) set to the default value of that property's type.
    /// </summary>
    /// <exception cref="InvalidOperationException">A key property holds a default value.</exception>
    public static void ValidateKeys(this DbContext context)
    {
        foreach (var entity in context.AddedOrModified())
        {
            foreach (var key in entity.Metadata.GetKeys())
            {
                foreach (var property in key.Properties)
                {
                    var propertyEntry = entity.Property(property.Name);
                    if (!IsDefaultValue(property.ClrType, propertyEntry.CurrentValue))
                    {
                        continue;
                    }
                    // $@ = interpolated verbatim string, so the message can span several lines.
                    throw new InvalidOperationException($@"Invalid empty key.
EntityType: {entity.Metadata.ClrType.FullName}
PropertyName: {property.Name}
PropertyType: {property.ClrType.FullName}.");
                }
            }
        }
    }

    /// <summary>
    /// True when <paramref name="currentValue"/> is <c>null</c> or equals the default of the
    /// (nullable-unwrapped) <paramref name="clrType"/>.
    /// </summary>
    static bool IsDefaultValue(Type clrType, object currentValue)
    {
        if (currentValue == null)
        {
            return true;
        }
        // Activator.CreateInstance(typeof(int?)) yields null, so compare against the underlying type's default.
        var underlyingType = Nullable.GetUnderlyingType(clrType) ?? clrType;
        return underlyingType.IsValueType && object.Equals(Activator.CreateInstance(underlyingType), currentValue);
    }

    /// <summary>Tracked entries whose state is Added or Modified.</summary>
    static IEnumerable<EntityEntry> AddedOrModified(this DbContext context)
    {
        return context.ChangeTracker.Entries()
            .Where(e => e.State == EntityState.Added ||
                        e.State == EntityState.Modified);
    }
}
Then, in the DbContext, add:
/// <summary>
/// Validates keys, then saves. This overrides the <c>bool</c> overload because EF Core's parameterless
/// <c>SaveChanges()</c> delegates to it, so both synchronous entry points are validated.
/// </summary>
public override int SaveChanges(bool acceptAllChangesOnSuccess)
{
    this.ValidateKeys();
    return base.SaveChanges(acceptAllChangesOnSuccess);
}
/// <summary>
/// Rejects default-valued keys before delegating to the base asynchronous save.
/// </summary>
public override Task<int> SaveChangesAsync(bool acceptAllChanges, CancellationToken cancellation = default)
{
    KeyValidator.ValidateKeys(this);
    var saveTask = base.SaveChangesAsync(acceptAllChanges, cancellation);
    return saveTask;
}

EntityFramework throws 'Can not start another operation while there is an asynchronous operation pending'

IRepository.cs
/// <summary>
/// Query operations shared by repositories.
/// </summary>
/// <typeparam name="T">Entity type.</typeparam>
public interface ICommonRepository<T>
{
    /// <summary>Counts the entities selected by the optional filter, ordering and includes.</summary>
    Task<int> CountAsync(
        Expression<Func<T, bool>> filter = null,
        Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null,
        List<Expression<Func<T, object>>> includes = null);
}
Repository.cs:
/// <summary>
/// Generic EF Core repository that composes filter/include/order/paging options into one query.
/// Like every operation on a single DbContext, its async methods must not run concurrently
/// (e.g. via Task.WhenAll) on the same instance.
/// </summary>
public class Repository<T> : IRepository<T> where T : class, new()
{
    protected readonly MyDbContext _context;
    protected readonly ILogger<Repository<T>> _logger;
    protected readonly DbSet<T> _dbSet;

    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
    public Repository(MyDbContext context, ILogger<Repository<T>> logger)
    {
        // Without a context the repository cannot do anything, so fail fast here.
        _context = context ?? throw new ArgumentNullException(nameof(context));
        _logger = logger;
        _dbSet = _context.Set<T>();
    }

    /// <summary>
    /// Builds the query: eager-loads <paramref name="includes"/>, filters with <paramref name="filter"/>,
    /// orders with <paramref name="orderBy"/>, then applies 1-based paging when both
    /// <paramref name="pageIndex"/> and <paramref name="pageSize"/> are given.
    /// </summary>
    /// <remarks>
    /// Filtering before ordering selects the same rows in the same order as ordering first.
    /// NOTE(review): paging without <paramref name="orderBy"/> has no guaranteed row order.
    /// </remarks>
    internal IQueryable<T> _Select(Expression<Func<T, bool>> filter = null
        , Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null
        , List<Expression<Func<T, object>>> includes = null
        , int? pageIndex = null
        , int? pageSize = null)
    {
        IQueryable<T> query = _dbSet;
        if (includes != null)
        {
            query = includes.Aggregate(query, (current, include) => current.Include(include));
        }
        if (filter != null)
        {
            query = query.Where(filter);
        }
        if (orderBy != null)
        {
            query = orderBy(query);
        }
        if (pageIndex != null && pageSize != null)
        {
            query = query.Skip((pageIndex.Value - 1) * pageSize.Value).Take(pageSize.Value);
        }
        return query;
    }

    /// <summary>Counts the rows selected by <see cref="_Select"/> with the given options.</summary>
    public async Task<int> CountAsync(Expression<Func<T, bool>> filter = null
        , Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null
        , List<Expression<Func<T, object>>> includes = null)
    {
        var query = _Select(filter, orderBy, includes);
        return await query.CountAsync();
    }
}
Usage (controller):
// A single DbContext cannot run operations concurrently, so the counts are awaited one at a time
// instead of with Task.WhenAll. The task variables are kept so later code reading them
// (e.g. .Result) sees already-completed tasks.
var singleCheckTask = _Repo.CountAsync(x => x.id == item.id);
await singleCheckTask;
var nameCheckTask = _Repo.CountAsync(x => x.name == item.name);
await nameCheckTask;
var ipCheckTask = _Repo.CountAsync(x => x.ip == item.ip);
await ipCheckTask;
And this exception is thrown:
Microsoft.EntityFrameworkCore.Query.Internal.SqlServerQueryCompilationContextFactory|ERROR|An exception occurred in the database while iterating the results of a query.
System.InvalidOperationException: Can not start another operation while there is an asynchronous operation pending.
I've tested that if I don't use Task.WhenAll and instead write var testSingleCheck = _Repo.CountAsync(x=> x.id== item.id).Result;, it works fine.
It's simple: you can't run queries in parallel on the same context with EF (neither EF6 nor EF Core).
One reason for this is that EF isn't thread-safe.
EF 6 on Task-based pattern
Thread Safety
While thread safety would make async more useful it is an orthogonal feature. It is unclear that we could ever implement support for it in the most general case, given that EF interacts with a graph composed of user code to maintain state and there aren't easy ways to ensure that this code is also thread safe.
For the moment, EF will detect if the developer attempts to execute two async operations at one time and throw.

Using DbContext in MS unit test

I am not sure how to fit EF into my business logic tests. Let me give an example of how it works at runtime (no testing, regular application run):
// Set<T>() is a generic method and must be invoked with parentheses.
Context.Set<T>().Add(instance);
When I add the entity using the above generic method, an instance is added to context, and EF fixes all the navigation properties behind the scenes. For example, if exists [instance.Parent] property, and [parent.Instances] collection property (1-to-many relationship), EF will automatically add the instance to parent.Instances collection behind the scenes.
My code depends on the [parent.Instances] collection, and if it is empty, it will fail. When I am writing unit tests using the MS testing framework, how can I reuse the power of EF so it can still do its behind-the-scenes job, but using memory as the data storage instead of the actual database? I am not really interested in whether EF successfully added, modified or deleted something in the database; I am just interested in getting the EF magic on the in-memory sets.
I've been doing this with a mock DbContext and mock DbSet that I've created. They store test data in memory and allow you to do most of the standard things you can do on a DbSet.
Your code that acquires the DbContext initially will have to be changed so that it acquires a MockDbContext when it is running under unit test. You can determine if you are running under MSTest with the following code:
/// <summary>
/// <c>true</c> when an MSTest framework assembly is loaded into the current AppDomain, which is
/// taken as a sign that the code is running under a unit test.
/// </summary>
/// <remarks>
/// Recognizes the MSTest v1 assembly (<c>Microsoft.VisualStudio.QualityTools.UnitTestFramework</c>)
/// and the MSTest v2 assembly (<c>Microsoft.VisualStudio.TestPlatform.TestFramework</c>). Names are
/// compared ordinally so the result does not depend on the current culture.
/// </remarks>
public static bool IsInUnitTest
{
    get
    {
        string[] testFrameworkPrefixes =
        {
            "Microsoft.VisualStudio.QualityTools.UnitTestFramework",
            "Microsoft.VisualStudio.TestPlatform.TestFramework"
        };
        return AppDomain.CurrentDomain.GetAssemblies()
            .Select(assembly => assembly.FullName)
            .Any(name => name != null
                && testFrameworkPrefixes.Any(prefix => name.StartsWith(prefix, StringComparison.Ordinal)));
    }
}
Here is the code for MockDbContext:
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Data.Entity;
using System.Data.Entity.Core.Objects;
using System.Data.Entity.Infrastructure;
using System.Threading;
using System.Threading.Tasks;
namespace ConsoleApplication5
{
// ProductionDbContext would be DbContext class
// generated by Entity Framework
/// <summary>
/// Test double for <c>ProductionDbContext</c>. The overridden DbSet properties are backed by
/// in-memory <c>MockDbSet</c> instances seeded by <see cref="LoadFakeData"/>, and the SaveChanges
/// overrides return 0 without persisting anything.
/// </summary>
public class MockDbContext: ProductionDbContext
{
    /// <summary>Creates the mock and seeds its in-memory sets.</summary>
    public MockDbContext()
    {
        LoadFakeData();
    }

    // Entities (for which we'll provide MockDbSet implementation
    // and test data)
    public override DbSet<Account> Accounts { get; set; }
    public override DbSet<AccountGenLink> AccountGenLinks { get; set; }
    public override DbSet<AccountPermit> AccountPermits { get; set; }
    public override DbSet<AcctDocGenLink> AcctDocGenLinks { get; set; }

    // DbContext method overrides
    private int InternalSaveChanges()
    {
        // Just return 0 in the mock
        return 0;
    }

    public override int SaveChanges()
    {
        return InternalSaveChanges();
    }

    public override Task<int> SaveChangesAsync()
    {
        return Task.FromResult(InternalSaveChanges());
    }

    public override Task<int> SaveChangesAsync(CancellationToken cancellationToken)
    {
        // Just ignore the cancellation token in the mock
        return SaveChangesAsync();
    }

    /// <summary>Replaces each DbSet with a MockDbSet and fills it with test rows.</summary>
    private void LoadFakeData()
    {
        // Tables
        Accounts = new MockDbSet<Account>(this);
        Accounts.AddRange(new List<Account>
        {
            new Account
            {
                SSN_EIN = "123456789", CODE = "A", accttype = "CD",
                acctnumber = "1", pending = false, BankOfficer1 = string.Empty,
                BankOfficer2 = null, Branch = 0, type = "18", drm_rate_code = "18",
                officer_code = string.Empty, open_date = new DateTime(2010, 6, 8),
                maturity_date = new DateTime(2010, 11, 8), HostAcctActive = true,
                EffectiveAcctStatus = "A"
            },
            new Account
            {
                SSN_EIN = "123456789", CODE = "A", accttype = "DD",
                acctnumber = "00001234", pending = false, BankOfficer1 = "BCK",
                BankOfficer2 = string.Empty, Branch = 0, type = "05", drm_rate_code = "00",
                officer_code = "DJT", open_date = new DateTime(1998, 9, 14),
                maturity_date = null, HostAcctActive = true,
                EffectiveAcctStatus = "A"
            },
            new Account
            {
                SSN_EIN = "123456789", CODE = "A", accttype = "LN", acctnumber = "1",
                pending = false, BankOfficer1 = "LMP", BankOfficer2 = string.Empty,
                Branch = 0, type = "7", drm_rate_code = null, officer_code = string.Empty,
                open_date = new DateTime(2001, 10, 24),
                maturity_date = new DateTime(2008, 5, 2), HostAcctActive = true,
                EffectiveAcctStatus = "A"
            }
        });
        AccountGenLinks = new MockDbSet<AccountGenLink>(this);
        AccountGenLinks.AddRange(new List<AccountGenLink>
        {
            // Add your test data here if needed
        });
        AccountPermits = new MockDbSet<AccountPermit>(this);
        AccountPermits.AddRange(new List<AccountPermit>
        {
            // Add your test data here if needed
        });
        // AcctDocGenLinks is overridden above but was never given a MockDbSet here, so it was not backed
        // by in-memory data (it stayed null or kept whatever the base constructor assigned).
        AcctDocGenLinks = new MockDbSet<AcctDocGenLink>(this);
        AcctDocGenLinks.AddRange(new List<AcctDocGenLink>
        {
            // Add your test data here if needed
        });
        // NOTE(review): AcctDocLinks is not among the overrides declared in this class; presumably it is
        // declared on ProductionDbContext — confirm it is the intended set.
        AcctDocLinks = new MockDbSet<AcctDocLink>(this);
        AcctDocLinks.AddRange(new List<AcctDocLink>
        {
            new AcctDocLink { ID = 1, SSN_EIN = "123456789", CODE = "A", accttype = "DD",
                acctnumber = "00001234", DocID = 50, DocType = 5 },
            new AcctDocLink { ID = 25, SSN_EIN = "123456789", CODE = "6", accttype = "CD",
                acctnumber = "1", DocID = 6750, DocType = 5 }
        });
    }
}
}
And here is the code for MockDbSet:
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Data.Entity.Core.Metadata.Edm;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
namespace ConsoleApplication5
{
/// <summary>
/// In-memory DbSet for unit tests. Rows live in an <see cref="ObservableCollection{T}"/>, LINQ runs
/// against that collection, and the EF6 async query interfaces complete synchronously. On <c>Add</c>,
/// integer IDENTITY columns left at zero receive (current maximum + 1).
/// </summary>
public sealed class MockDbSet<TEntity> : DbSet<TEntity>, IQueryable,
    IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity> where TEntity : class
{
    public MockDbSet(MockDbContext context)
    {
        // Get entity set for entity
        // Used when we figure out whether to generate
        // IDENTITY values
        // NOTE(review): store (SSpace) sets are matched by CLR type name; when the table name differs,
        // EntitySet is null and identity generation is skipped.
        EntitySet = ((IObjectContextAdapter) context).ObjectContext
            .MetadataWorkspace
            .GetItems<EntityContainer>(DataSpace.SSpace).First()
            .BaseEntitySets
            .FirstOrDefault(item => item.Name == typeof(TEntity).Name);
        Data = new ObservableCollection<TEntity>();
        Query = Data.AsQueryable();
    }

    private ObservableCollection<TEntity> Data { get; set; }

    Type IQueryable.ElementType
    {
        get { return Query.ElementType; }
    }

    private EntitySetBase EntitySet { get; set; }

    Expression IQueryable.Expression
    {
        get { return Query.Expression; }
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return Data.GetEnumerator();
    }

    public override ObservableCollection<TEntity> Local
    {
        get { return Data; }
    }

    IQueryProvider IQueryable.Provider
    {
        get { return new MockDbAsyncQueryProvider<TEntity>(Query.Provider); }
    }

    private IQueryable Query { get; set; }

    public override TEntity Add(TEntity entity)
    {
        GenerateIdentityColumnValues(entity);
        Data.Add(entity);
        return entity;
    }

    public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
    {
        foreach (var entity in entities)
            Add(entity);
        return entities;
    }

    public override TEntity Attach(TEntity entity)
    {
        return Add(entity);
    }

    public override TEntity Create()
    {
        return Activator.CreateInstance<TEntity>();
    }

    public override TDerivedEntity Create<TDerivedEntity>()
    {
        return Activator.CreateInstance<TDerivedEntity>();
    }

    public override TEntity Find(params object[] keyValues)
    {
        throw new NotSupportedException();
    }

    public override Task<TEntity> FindAsync(params object[] keyValues)
    {
        return FindAsync(CancellationToken.None, keyValues);
    }

    public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        throw new NotSupportedException();
    }

    /// <summary>
    /// Emulates SQL Server IDENTITY assignment before an add, because no provider will do it for the
    /// mock. For each store-generated identity column whose current value is zero, sets
    /// (largest value already in the set, or 0) + 1. Nonzero values are left alone because the test
    /// data supplies identity values that other test rows use as foreign keys.
    /// </summary>
    /// <exception cref="OverflowException">The next value does not fit the column's integer type.</exception>
    private void GenerateIdentityColumnValues(TEntity entity)
    {
        if (EntitySet == null)
            return;
        foreach (var member in EntitySet.ElementType.Members.ToList())
        {
            if (!member.IsStoreGeneratedIdentity)
                continue;
            var columnDataType = GetIdentityColumnClrType(member);
            if (columnDataType == null)
                continue;
            // Column name and CLR property name may differ; skip columns with no matching property.
            var identityProperty = entity.GetType().GetProperty(member.Name);
            if (identityProperty == null)
                continue;
            if (Convert.ToInt64(identityProperty.GetValue(entity, null)) != 0)
                continue;
            Int64 maxExistingColumnValue = 0;
            foreach (var item in Local.ToList())
                maxExistingColumnValue = Math.Max(maxExistingColumnValue, Convert.ToInt64(identityProperty.GetValue(item, null)));
            identityProperty.SetValue(entity, Convert.ChangeType(maxExistingColumnValue + 1, columnDataType), null);
        }
    }

    /// <summary>
    /// Maps an identity column's EDM primitive kind to its CLR integer type, or returns null when the
    /// member's type carries no <c>PrimitiveTypeKind</c> metadata.
    /// </summary>
    /// <exception cref="InvalidOperationException">The identity column is not an integer type.</exception>
    private static Type GetIdentityColumnClrType(EdmMember member)
    {
        foreach (var metadataProperty in member.TypeUsage.EdmType.MetadataProperties.ToList())
        {
            if (metadataProperty.Name != "PrimitiveTypeKind")
                continue;
            switch ((PrimitiveTypeKind)metadataProperty.Value)
            {
                case PrimitiveTypeKind.Byte:
                    // SQL Server tinyint is unsigned and is reported as Byte, not SByte.
                    return typeof(Byte);
                case PrimitiveTypeKind.SByte:
                    return typeof(SByte);
                case PrimitiveTypeKind.Int16:
                    return typeof(Int16);
                case PrimitiveTypeKind.Int32:
                    return typeof(Int32);
                case PrimitiveTypeKind.Int64:
                    return typeof(Int64);
                default:
                    // decimal/numeric identity columns are not emulated.
                    throw new InvalidOperationException(
                        $"Unsupported IDENTITY column type {metadataProperty.Value} for '{member.Name}'.");
            }
        }
        return null;
    }

    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
    {
        return new MockDbAsyncEnumerator<TEntity>(Data.GetEnumerator());
    }

    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        return Data.GetEnumerator();
    }

    public override TEntity Remove(TEntity entity)
    {
        Data.Remove(entity);
        return entity;
    }

    public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
    {
        foreach (var entity in entities)
            Remove(entity);
        return entities;
    }

    public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
    {
        throw new NotSupportedException();
    }
}
/// <summary>
/// EF6 async query provider for the mocks: query construction yields <c>MockDbAsyncEnumerable</c>
/// instances, execution is forwarded to the wrapped LINQ provider, and the async variants return
/// already-completed tasks (the cancellation token is ignored).
/// </summary>
internal class MockDbAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    internal MockDbAsyncQueryProvider(IQueryProvider queryProvider)
    {
        Inner = queryProvider;
    }

    private IQueryProvider Inner { get; }

    public IQueryable CreateQuery(Expression expression) => new MockDbAsyncEnumerable<TEntity>(expression);

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) => new MockDbAsyncEnumerable<TElement>(expression);

    public object Execute(Expression expression) => Inner.Execute(expression);

    public TResult Execute<TResult>(Expression expression) => Inner.Execute<TResult>(expression);

    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        object value = Execute(expression);
        return Task.FromResult(value);
    }

    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        TResult value = Execute<TResult>(expression);
        return Task.FromResult(value);
    }
}
/// <summary>
/// In-memory query whose provider is the async-capable mock provider and which can be enumerated
/// through EF6's <c>IDbAsyncEnumerable</c> interfaces.
/// </summary>
internal class MockDbAsyncEnumerable<T> : EnumerableQuery<T>, IDbAsyncEnumerable<T>, IQueryable<T>
{
    public MockDbAsyncEnumerable(IEnumerable<T> enumerable) : base(enumerable) { }

    public MockDbAsyncEnumerable(Expression expression) : base(expression) { }

    IQueryProvider IQueryable.Provider => new MockDbAsyncQueryProvider<T>(this);

    public IDbAsyncEnumerator<T> GetAsyncEnumerator()
    {
        IEnumerable<T> sequence = this;
        return new MockDbAsyncEnumerator<T>(sequence.GetEnumerator());
    }

    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator() => GetAsyncEnumerator();
}
/// <summary>
/// Adapts a synchronous enumerator to EF6's <c>IDbAsyncEnumerator</c>; each MoveNextAsync
/// completes synchronously and ignores the cancellation token.
/// </summary>
internal class MockDbAsyncEnumerator<T> : IDbAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    public MockDbAsyncEnumerator(IEnumerator<T> enumerator)
    {
        _inner = enumerator;
    }

    public T Current => _inner.Current;

    object IDbAsyncEnumerator.Current => Current;

    public Task<bool> MoveNextAsync(CancellationToken cancellationToken) => Task.FromResult(_inner.MoveNext());

    public void Dispose() => _inner.Dispose();
}
}
If you are using the EntityFramework-Reverse-POCO-Code-First-Generator from Simon Hughes, with its generated FakeContext, Jeff Prince's approach is still possible with some tweaking. The little difference here is we are using the partial class support and implementing the InitializePartial() methods in the FakeContext and FakeDbSet. The bigger difference is that the Reverse POCO FakeContext does not inherit from a DbContext, so we can't easily get the MetadataWorkspace to know which columns are identities. The answer is to create a 'real' context with a bogus connection string and use that to get the EntitySetBase for the FakeDbSet. This should be pasted inside the proper namespace of a new source file, renaming the context, and you shouldn't need to do anything further in the rest of your project.
/// <summary>
/// This code will set Identity columns to be unique. It behaves differently from the real context in that the
/// identities are generated on add, not save. This is inspired by https://stackoverflow.com/a/31795273/1185620 and
/// modified for use with the FakeDbSet and FakeContext that can be generated by EntityFramework-Reverse-POCO-Code-
/// First-Generator from Simon Hughes.
///
/// Aside from changing the name of the FakeContext and the type used in its InitializePartial() as
/// the 'realContext', this file can be pasted into another namespace for a completely unrelated context. If you
/// have additional implementation for the InitializePartial methods in the FakeContext or FakeDbSet, change the
/// name to InitializePartial2 and they will be called after InitializePartial is called here. Please don't add
/// code unrelated to the above purpose to this file - make another file to further extend the partial class.
/// </summary>
partial class FakeFooBarBazContext
{
    /// <summary> Initialization of FakeContext to handle setting an identity for columns marked as
    /// <c>IsStoreGeneratedIdentity</c> when an item is Added to the DbSet. If this signature
    /// conflicts with another partial class, change that signature to implement
    /// <see cref="InitializePartial2"/>, as that will be called when this is complete. </summary>
    partial void InitializePartial()
    {
        // Here we need to get a 'real' ObjectContext so we can get the metadata for determining
        // identity columns. Since FakeContext doesn't inherit from DbContext, create
        // the real one with a bogus connection string.
        using (var realContext = new FooBarBazContext("Server=."))
        {
            var objectContext = (realContext as IObjectContextAdapter).ObjectContext;
            // Reflect over the public properties that return DbSet<> and get it. If it is
            // of type FakeDbSet<>, call InitializeWithContext() on it.
            var fakeDbSetGenericType = typeof(FakeDbSet<>);
            var dbSetGenericType = typeof(DbSet<>);
            var properties = this.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
            foreach (var prop in properties)
            {
                if (!prop.PropertyType.IsGenericType || prop.GetMethod == null)
                    continue;
                if (prop.PropertyType.GetGenericTypeDefinition() != dbSetGenericType)
                    continue;
                var dbSetObj = prop.GetMethod.Invoke(this, null);
                var dbSetObjType = dbSetObj?.GetType();
                // The runtime object may be null or a closed, non-generic subclass of DbSet<T>;
                // GetGenericTypeDefinition() throws InvalidOperationException on non-generic types.
                if (dbSetObjType == null || !dbSetObjType.IsGenericType
                    || dbSetObjType.GetGenericTypeDefinition() != fakeDbSetGenericType)
                    continue;
                var initMethod = dbSetObjType.GetMethod(nameof(FakeDbSet<object>.InitializeWithContext),
                    BindingFlags.NonPublic | BindingFlags.Instance,
                    null, new[] {typeof(ObjectContext)}, new ParameterModifier[] { });
                initMethod.Invoke(dbSetObj, new object[] {objectContext});
            }
        }
        InitializePartial2();
    }

    partial void InitializePartial2();
}
partial class FakeDbSet<TEntity>
{
    /// <summary> Store-model (SSpace) entity set for <typeparamref name="TEntity"/>, assigned by
    /// <see cref="InitializeWithContext"/>. Remains <c>null</c> if this set was never initialized from a
    /// context, or if no store entity set named after the entity type was found; identity generation is
    /// skipped in that case. </summary>
    private EntitySetBase EntitySet { get; set; }

    /// <summary> Initialization of FakeDbSet to handle setting an identity for columns marked as
    /// <c>IsStoreGeneratedIdentity</c> when an item is Added to the DbSet. If this signature
    /// conflicts with another partial class, change that signature to implement
    /// <see cref="InitializePartial2"/>, as that will be called when this is complete. </summary>
    partial void InitializePartial()
    {
        // The only way we know something was added to the DbSet from this partial class
        // is to hook the CollectionChanged event.
        _data.CollectionChanged += DataOnCollectionChanged;
        InitializePartial2();
    }

    /// <summary> Captures the store metadata for <typeparamref name="TEntity"/> so identity columns can be
    /// detected when items are added. </summary>
    /// <param name="objectContext">ObjectContext of a real context whose model contains this entity.</param>
    internal void InitializeWithContext(ObjectContext objectContext)
    {
        // Look the entity set up by the CLR type name in the store model.
        // NOTE(review): store-model set names may differ from the class name (e.g. renamed tables);
        // in that case this yields null and identity generation is skipped for this set.
        EntitySet = objectContext
            .MetadataWorkspace
            .GetItems<EntityContainer>(DataSpace.SSpace).First()
            .BaseEntitySets
            .FirstOrDefault(item => item.Name == typeof(TEntity).Name);
    }

    /// <summary> Generates identity values for every entity added to the backing collection. </summary>
    private void DataOnCollectionChanged(object sender, NotifyCollectionChangedEventArgs e)
    {
        if (e.Action != NotifyCollectionChangedAction.Add)
            return;
        foreach (TEntity entity in e.NewItems)
            GenerateIdentityColumnValues(entity);
    }

    /// <summary> The purpose of this method, which is called after a row is added, is to ensure that Identity column values are
    /// properly initialized. If this was a real Entity Framework, this task would be handled by the data provider
    /// when SaveChanges[Async]() is called and the value(s) would then be propagated back into the entity.
    /// In the case of FakeDbSet, there is nothing that will do that, so we have to make this at-least token effort
    /// to ensure the columns are properly initialized, even if it is done at the incorrect time.
    /// The new value is one more than the largest value currently in <c>Local</c> (or 1 if there is none).
    /// </summary>
    /// <exception cref="InvalidOperationException">The identity column's primitive kind is not one SQL Server
    /// allows for identities, or the entity has no public property named after the identity column.</exception>
    /// <exception cref="OverflowException">The next identity value does not fit in the property's type.</exception>
    private void GenerateIdentityColumnValues(TEntity entity)
    {
        // Without store metadata (set created outside the fake context, or no matching entity set)
        // we cannot tell which column is the identity, so leave the entity untouched.
        if (EntitySet == null)
            return;

        foreach (var member in EntitySet.ElementType.Members)
        {
            if (!member.IsStoreGeneratedIdentity)
                continue;

            var kindProperty = member.TypeUsage.EdmType.MetadataProperties
                .FirstOrDefault(metadataProperty => metadataProperty.Name == "PrimitiveTypeKind");
            if (kindProperty == null)
                continue;

            var columnType = (PrimitiveTypeKind)kindProperty.Value;
            if (!IsSupportedIdentityKind(columnType))
                // 'If we don't know, we throw'
                throw new InvalidOperationException($"Unsupported Identity Column Type {columnType}");

            var entityProperty = entity.GetType().GetProperty(member.Name);
            if (entityProperty == null)
                throw new InvalidOperationException(
                    $"Identity column {member.Name} has no matching public property on {entity.GetType()}.");

            // From this point on, we can return from the method, as only one identity column is
            // possible per table and we found it.
            //
            // A nonzero current value is left alone. We do this because the test data in our mock
            // DbContext provides values for the Identity columns and many of those values are foreign
            // keys in other entities (where we also provide test data). We don't want to disturb any
            // existing relationships defined in the test data.
            // Every supported kind (including Int64) converts to decimal exactly; null counts as 0.
            if (Convert.ToDecimal(entityProperty.GetValue(entity, null)) != 0m)
                return;

            lock (Local)
            {
                var maxExistingColumnValue = 0m;
                foreach (var item in Local.ToList())
                    maxExistingColumnValue = Math.Max(maxExistingColumnValue,
                        Convert.ToDecimal(entityProperty.GetValue(item, null)));

                // Convert back to the property's own numeric type (unwrapping Nullable<T>);
                // Convert.ChangeType throws OverflowException instead of silently wrapping.
                var targetType = Nullable.GetUnderlyingType(entityProperty.PropertyType) ?? entityProperty.PropertyType;
                entityProperty.SetValue(entity, Convert.ChangeType(maxExistingColumnValue + 1, targetType), null);
            }
            return;
        }
    }

    /// <summary> True for the primitive kinds SQL Server allows for Identity columns: tinyint (Byte),
    /// smallint (Int16), int (Int32), bigint (Int64), and decimal/numeric with a scale of 0 (Decimal).
    /// SByte is also accepted because earlier versions of this file handled it. </summary>
    private static bool IsSupportedIdentityKind(PrimitiveTypeKind kind)
    {
        switch (kind)
        {
            case PrimitiveTypeKind.Byte:
            case PrimitiveTypeKind.SByte:
            case PrimitiveTypeKind.Int16:
            case PrimitiveTypeKind.Int32:
            case PrimitiveTypeKind.Int64:
            case PrimitiveTypeKind.Decimal:
                return true;
            default:
                return false;
        }
    }

    /// <summary> Optional hook for other partial-class files; called at the end of
    /// <see cref="InitializePartial"/>. </summary>
    partial void InitializePartial2();
}