Apex test class for Page Reference and Static methods - apex

I am new to write Apex test class. Can anyone help me or suggest how to start writing apex text. Below is my Apex Class
public class RadioBtn_SmileyTable_Apex
{
public RB_Rating__c ratobj {get;set;}
public user lgflag {get;set;}
public string Selected {get;set;}
public string EnteredText {get;set;}
public string Redirect {get;set;}
public string QAppname {get;set;}
public string Qimprove {get;set;}
public RadioBtn_SmileyTable_Apex()
{
QAppname = 'How satisified you are with '+getAppName();
Qimprove = 'How could we improve?';
}
Public void SelectedAnswer()
{
Selected = Apexpages.currentPage().getParameters().get('conid');
}
public pagereference savefeedback()
{
system.debug('Submit Started : '+ Redirect);
RB_Rating__c ratobj = new RB_Rating__c();
ratobj.Application_Name__c = getAppName();
if(string.isBlank(Redirect))
{
ratobj.Application_Rating__c = Selected;
}
else
{
ratobj.Application_Rating__c = Redirect;
}
ratobj.Description__c = EnteredText;
insert ratobj;
updateflgusertb();
return Auth.SessionManagement.FinishLoginflow('/'+ratobj.Id);
}
public static void updateflgusertb()
{
List<user> userupdate = new List<user>();
for(user usr : [SELECT id,Login_Flows_YN__c FROM USER WHERE ID = :UserInfo.getUserId() LIMIT 1])
{
usr.Login_Flows_YN__c = true;
userupdate.add(usr);
}
update userupdate;
}
public pagereference feedbackexit()
{
updateexitcount();
return Auth.SessionManagement.FinishLoginflow('/');
}
public void updateexitcount()
{
integer exitcnt;
List<user> updatecnt = new List<user>();
for(user cnt : [SELECT Rating_Exit_Count__c FROM USER WHERE ID = :UserInfo.getUserId() LIMIT 1])
{
if (cnt.Rating_Exit_Count__c == null)
{
cnt.Rating_Exit_Count__c = 0;
}
exitcnt = Integer.valueOf(cnt.Rating_Exit_Count__c) + 1;
cnt.Rating_Exit_Count__c = exitcnt;
updatecnt.add(cnt);
}
update updatecnt;
if (exitcnt ==3)
{
insertzerorating();
}
}
public pagereference loginflowsexornot()
{
lgflag =[SELECT Login_Flows_YN__c FROM USER WHERE ID = :UserInfo.getUserId() LIMIT 1];
if (lgflag.Login_Flows_YN__c == true)
{
return Auth.SessionManagement.FinishLoginflow('/');
}
else
{
return null;
}
}
public pagereference redirect()
{
PageReference pg = new PageReference('/apex/Slider_Rating_Redir_VF?retURL='+Selected);
Redirect = ApexPages.currentPage().getparameters().get('retURL');
pg.setRedirect(false);
return pg;
}
public static String getAppName()
{
UserAppInfo userAppInfo = [SELECT Id, AppDefinitionId FROM UserAppInfo WHERE UserId = :UserInfo.getUserId() LIMIT 1];
AppDefinition appDefinition = [SELECT DurableId, Label FROM AppDefinition Where DurableId = :userAppInfo.AppDefinitionId LIMIT 1];
return appDefinition.Label;
}
public pagereference insertzerorating()
{
RB_Rating__c ratobjZero = new RB_Rating__c();
ratobjZero.Application_Name__c = getAppName();
ratobjZero.Application_Rating__c = '0';
ratobjZero.Description__c = EnteredText;
insert ratobjZero;
updateflgusertb();
return Auth.SessionManagement.FinishLoginflow('/'+ratobjZero.Id);
}
}
Wrote below test class. Please suggest to further:
#isTest (SeeAllData = FALSE)
public class Slider_VF_TestClass
{
public static testmethod void SliderTestClass()
{
RB_Rating__c testRBRating = new RB_Rating__c();
testRBRating.Application_Name__c = 'FoxiPedia';
testRBRating.Application_Rating__c = '9';
testRBRating.Description__c='Testing the testclass for the Rating 9';
insert testRBRating;
Test.startTest();
RadioBtn_SmileyTable_Apex smileyapex = new RadioBtn_SmileyTable_Apex();
PageReference page1 = page.Slider_Rating_VF;
Test.setCurrentPage(page1);
smileyapex.ratobj.Application_Name__c= 'FoxiPedia';
smileyapex.ratobj.Application_Rating__c = '8';
smileyapex.ratobj.Description__c = 'Testing the testclass for the Rating 9';
PageReference pageRef = smileyapex.savefeedback();
ApexPages.currentPage().getParameters().put('id',smileyapex.ratobj.Id);
Test.setCurrentPage(pageRef);
Test.stopTest();
}
}
Try to write Apex Test class. Need some sample code

Related

can't serialize class org.springframework.data.mongodb.core.geo.GeoJsonPoint

I am using MongoDB Near to find nearest points, getMills is method to find nearest points based on LatLng. It runs fine till fetching data and mapping. while mapping it returns following error. My question is that how can i serialize class to GeoJsonPoint???
It's returning this error : can't serialize class org.springframework.data.mongodb.core.geo.GeoJsonPoint.
private List<Mill> getMills(GetNearByMillsRequest getNearByMillsRequest, BasicResponse basicResponse) {
// get farmerPosition
LatLng fieldPosition = getNearByMillsRequest.getFieldPosition();
double Longitude = fieldPosition.getLongitude();
double Latitude = fieldPosition.getLatitude();
System.out.println("Longitude : " + Longitude + "Latitude : "+ Latitude);
// creating GeoJson Object for farmer field location...
Point point = new Point(Latitude,Longitude);
GeoJsonPoint farmerFieldPosition = new GeoJsonPoint(point);
Float distance = getNearByMillsRequest.getDistance();
// find crop by crop Name
Query query = new Query();
query.addCriteria(Criteria.where(Mill.Constants.MILL_CROP_REQUIRED).elemMatch(Criteria.where(MillCropRequired.Constants.CROP).is(getNearByMillsRequest.getCrop())));
query.addCriteria(Criteria.where(Mill.Constants.MILL_LOCATION).near(farmerFieldPosition));
// TODO Get radius from UserPreferences
NearQuery nearQuery = NearQuery.near(farmerFieldPosition).maxDistance(new Distance(distance, Metrics.KILOMETERS));
nearQuery.query(query);
nearQuery.spherical(true); // if using 2dsphere index, otherwise delete or set false
nearQuery.num(10);
// Problem occurs here, in mapping part...
GeoResults<Mill> results = millDAO.getMongoOperations().geoNear(nearQuery, Mill.class);
System.out.println("results : " + results.getContent().size());
if (results.getContent().isEmpty()) {
basicResponse.setErrorCode(ErrorCode.FORBIDDEN_ERROR);
basicResponse.setResponse("No Crop Found for this region.");
return null;
}
List<GeoResult<Mill>> geoResults = results.getContent();
// add nearBy Mills and return this list...
List<Mill> nearByMills = new ArrayList<>();
for (GeoResult<Mill> result : geoResults) {
Mill mill = result.getContent();
nearByMills.add(mill);
}
return nearByMills;
}
This is my Mill Entity.
public class Mill extends AbstractEntity {
private String vendorId;
private String name;
private String emailAddress;
private String countryCode;
private String phoneNumber;
private String postalAddress;
#GeoSpatialIndexed(type = GeoSpatialIndexType.GEO_2DSPHERE)
private GeoJsonPoint millLocation;
private String ginningCapacity;
private String storageCapacity;
private ArrayList<MillCropRequired> millCropRequired;
public String getVendorId() {
return vendorId;
}
public void setVendorId(String vendorId) {
this.vendorId = vendorId;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getEmailAddress() {
return emailAddress;
}
public void setEmailAddress(String emailAddress) {
this.emailAddress = emailAddress;
}
public String getCountryCode() {
return countryCode;
}
public void setCountryCode(String countryCode) {
this.countryCode = countryCode;
}
public String getPhoneNumber() {
return phoneNumber;
}
public void setPhoneNumber(String phoneNumber) {
this.phoneNumber = phoneNumber;
}
public String getPostalAddress() {
return postalAddress;
}
public void setPostalAddress(String postalAddress) {
this.postalAddress = postalAddress;
}
public GeoJsonPoint getMillLocation() {
return millLocation;
}
public void setMillLocation(GeoJsonPoint millLocation) {
this.millLocation = millLocation;
}
public String getGinningCapacity() {
return ginningCapacity;
}
public void setGinningCapacity(String ginningCapacity) {
this.ginningCapacity = ginningCapacity;
}
public ArrayList<MillCropRequired> getCropRequirnment() {
return millCropRequired;
}
public void setCropRequirnment(ArrayList<MillCropRequired> millCropRequired) {
this.millCropRequired = millCropRequired;
}
public String getStorageCapacity() {
return storageCapacity;
}
public void setStorageCapacity(String storageCapacity) {
this.storageCapacity = storageCapacity;
}
public static class Constants extends AbstractEntity.Constants {
public static final String PHONE_NUMBER = "phoneNumber";
public static final String VENDOR_ID = "vendorId";
public static final String MILL_CROP_REQUIRED = "millCropRequired";
public static final String MILL_LOCATION = "millLocation";
}
}

Serial communication(Serial UART) in raspberry-pi with windows 10 iot core and sim 900a

We were building and an app that could send and receive sms from sim900a module interfacing with raspberry-pi with windows 10 iot core.
In windows form application:-
We are going to read the serial data by making a call to ReadExisting method of SerialPort instance; which returns a partial response, so we have to loop through and append the data until the serial data that we received contain a substring “OK” or “\r\n>” means we have completely read the AT command response.
do
{
if (receiveNow.WaitOne(timeout, false))
{
string data = port.ReadExisting();
serialPortData += data;
}
}while (!serialPortData.EndsWith("\r\nOK\r\n") &&
!serialPortData.EndsWith("\r\n> ") &&
!serialPortData.EndsWith("\r\nERROR\r\n"));
How to do the same thing in universal windows platform(uwp)
I have used these commands but it reads partially(till \r \n), remaining part is not being read.
Task<UInt32> loadAsyncTask;
uint ReadBufferLength = 1024;
// Set InputStreamOptions to complete the asynchronous read operation when one or more bytes is available
dataReaderObject.InputStreamOptions = InputStreamOptions.Partial;
// Create a task object to wait for data on the serialPort.InputStream
loadAsyncTask = dataReaderObject.LoadAsync(ReadBufferLength).AsTask();
// Launch the task and wait
UInt32 bytesRead = await loadAsyncTask;
if (bytesRead > 0)
{
rcvdText.Text = dataReaderObject.ReadString(bytesRead);
status.Text = "bytes read successfully!";
}
Have a look at SerialDevice class. It can be created from USB vendorID/productID or by name (e.g. COM1). It provides stream based access to the serial port.
I´m not experienced with it (work in progress) but it seems like its the way to go for UWP.
This is my current attempt to serial communication in UWP (still untested!):
using SerialCommunication.Contracts.DataTypes;
using SerialCommunication.Contracts.Interfaces;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Windows.Devices.Enumeration;
using Windows.Devices.SerialCommunication;
using Windows.Storage.Streams;
namespace SerialCommunication
{
public class SerialPort : ISerialPort, IDisposable
{
public SerialPort(string portName)
{
PortName = portName;
AdvancedQuery = SerialDevice.GetDeviceSelector(portName);
}
public SerialPort(ushort vendorId, ushort productId)
{
PortName = $"{nameof(vendorId)}={vendorId} {nameof(productId)}={productId}";
AdvancedQuery = SerialDevice.GetDeviceSelectorFromUsbVidPid(vendorId, productId);
}
public void Dispose() => Disconnect();
public string PortName { get; }
public string AdvancedQuery { get; }
public bool IsConnected => serialPort != null;
public int WriteTimeoutMilliseconds { get; set; }
public int ReadTimeoutMilliseconds { get; set; }
public uint BaudRate { get; set; }
public SerialParity Parity { get; set; } = SerialParity.None;
public SerialStopBitCount StopBits { get; set; } = SerialStopBitCount.One;
public ushort DataBits { get; set; } = 8;
public SerialHandshake Handshake { get; set; } = SerialHandshake.None;
public InputStreamOptions InputStreamOptions { get; set; } = InputStreamOptions.ReadAhead;
public UnicodeEncoding Encoding { get; set; } = UnicodeEncoding.Utf8;
public ByteOrder ByteOrder { get; set; } = ByteOrder.LittleEndian;
public bool Connect()
{
lock (serialPortLock)
{
CreateSerialDevice().Wait();
// serial port
serialPort.WriteTimeout = TimeSpan.FromMilliseconds(WriteTimeoutMilliseconds);
serialPort.ReadTimeout = TimeSpan.FromMilliseconds(ReadTimeoutMilliseconds);
serialPort.BaudRate = BaudRate;
serialPort.Parity = Parity;
serialPort.StopBits = StopBits;
serialPort.DataBits = DataBits;
serialPort.Handshake = Handshake;
// output stream
dataWriter = new DataWriter(serialPort.OutputStream);
dataWriter.UnicodeEncoding = Encoding;
dataWriter.ByteOrder = ByteOrder;
// input stream
dataReader = new DataReader(serialPort.InputStream);
dataReader.InputStreamOptions = InputStreamOptions;
dataReader.UnicodeEncoding = Encoding;
dataReader.ByteOrder = ByteOrder;
// start reader
ReadWorker();
return IsConnected;
}
}
public void Disconnect()
{
lock (serialPortLock)
{
if (serialPort == null) return;
continuousReadData = false;
serialPort?.Dispose();
serialPort = null;
}
}
private async Task CreateSerialDevice()
{
var foundDevices = await DeviceInformation.FindAllAsync(AdvancedQuery);
if (foundDevices.Count == 0) throw new SerialPortException($"No device found: {nameof(PortName)}={PortName} {nameof(AdvancedQuery)}={AdvancedQuery}");
var deviceId = foundDevices[0].Id;
serialPort = await SerialDevice.FromIdAsync(deviceId);
if (serialPort == null) throw new SerialPortException($"Error creating device: {nameof(PortName)}={PortName} {nameof(AdvancedQuery)}={AdvancedQuery} {nameof(deviceId)}={deviceId}");
}
public void Write(byte[] bytes, int index = 0, int count = -1)
{
if (count < 0) count = bytes.Length - index;
byte[] tmp = new byte[count];
Array.Copy(bytes, index, tmp, 0, count);
WriteBytes(tmp);
}
public void InjectAndDispatch(byte[] receivedBytes)
{
lock (listeners)
{
foreach (var listener in listeners)
if (listener.IsActive)
listener.AddBytes(receivedBytes);
}
}
private async void ReadWorker()
{
continuousReadData = true;
while (continuousReadData)
{
await dataReader.LoadAsync(1);
byte[] receivedBytes = new byte[dataReader.UnconsumedBufferLength];
dataReader.ReadBytes(receivedBytes);
lock (listeners)
foreach (var listener in listeners)
if (listener.IsActive)
listener.AddBytes(receivedBytes);
}
dataReader.Dispose();
}
private async void WriteBytes(params byte[] bytes)
{
dataWriter.WriteBytes(bytes);
await dataWriter.StoreAsync();
await dataWriter.FlushAsync();
}
private readonly object serialPortLock = new object();
private SerialDevice serialPort;
private DataWriter dataWriter;
private DataReader dataReader;
private volatile bool continuousReadData = true;
#region listeners
public IListener AddListener()
{
lock (listenersLock)
{
var listener = new Listener();
listeners.Add(listener);
return listener;
}
}
public void RemoveListener(IListener listener)
{
lock (listenersLock)
listeners.Remove(listener as IListenerInternal);
}
private readonly object listenersLock = new object();
private readonly List<IListenerInternal> listeners = new List<IListenerInternal>();
class Listener : IListenerInternal
{
private bool _IsActive;
public bool IsActive
{
get { return _IsActive; }
set
{
if (_IsActive != value)
{
_IsActive = value;
Clear();
}
}
}
public void AddBytes(byte[] bytes)
{
lock (receivedBytesLock)
receivedBytes.AddRange(bytes);
BytesReceived?.Invoke(this, EventArgs.Empty);
}
public event EventHandler BytesReceived;
public byte[] GetReceivedBytesAndClear()
{
lock (receivedBytesLock)
{
var bytes = receivedBytes.ToArray();
receivedBytes.Clear();
return bytes;
}
}
public byte[] GetReceivedBytes()
{
lock (receivedBytesLock)
return receivedBytes.ToArray();
}
public void Clear()
{
lock (receivedBytesLock)
receivedBytes.Clear();
}
public void Trim(int length)
{
lock (receivedBytesLock)
{
var count = receivedBytes.Count;
if (count > length)
receivedBytes.RemoveRange(0, count - length);
}
}
private readonly object receivedBytesLock = new object();
private readonly List<byte> receivedBytes = new List<byte>();
}
#endregion
}
}
Furthermore I had trouble using the System.IO.SerialPort (of .NET Framework) with strings (ReadExisting etc). It seems like sometimes the encoding makes it hard to create good results. I always use it with byte arrays - less trouble, more fun!

Autofac WithKey Attribute not working as expected (multiple implementations)

I tried to reconstruct the problem in LinqPad:
/*
“Named and Keyed Services”
http://autofac.readthedocs.org/en/latest/advanced/keyed-services.html
*/
const string A = "a";
const string B = "b";
const string MyApp = "MyApp";
void Main()
{
var builder = new ContainerBuilder();
builder
.RegisterType<MyClassA>()
.As<IMyInterface>()
.InstancePerLifetimeScope()
.Keyed<IMyInterface>(A);
builder
.RegisterType<MyClassB>()
.As<IMyInterface>()
.InstancePerLifetimeScope()
.Keyed<IMyInterface>(B);
builder
.RegisterType<MyAppDomain>()
.Named<MyAppDomain>(MyApp);
var container = builder.Build();
var instance = container.ResolveKeyed<IMyInterface>(A);
instance.AddTheNumbers().Dump();
var myApp = container.ResolveNamed<MyAppDomain>(MyApp);
myApp.Dump();
}
interface IMyInterface
{
int AddTheNumbers();
}
class MyClassA : IMyInterface
{
public int AddTheNumbers() { return 1 + 2; }
}
class MyClassB : IMyInterface
{
public int AddTheNumbers() { return 3 + 4; }
}
class MyAppDomain
{
public MyAppDomain([WithKey(A)]IMyInterface aInstance, [WithKey(B)]IMyInterface bInstance)
{
this.ANumber = aInstance.AddTheNumbers();
this.BNumber = bInstance.AddTheNumbers();
}
public int ANumber { get; private set; }
public int BNumber { get; private set; }
public override string ToString()
{
var sb = new StringBuilder();
sb.AppendFormat("ANumber: {0}", this.ANumber);
sb.AppendFormat(", BNumber: {0}", this.BNumber);
return sb.ToString();
}
}
when MyApp is “dumped” I am seeing ANumber: 7, BNumber: 7 which tells me that WithKey(A) is not returning the expected instance. What am I doing wrong here?
Looks like you forgot to register the consumers using WithAttributeFilter which is what allows this to work. Like:
builder.RegisterType<ArtDisplay>().As<IDisplay>().WithAttributeFilter();

Unit testing generic repository

I'm pretty new to unit testing and I'm having some problems with regards, to unit testing a generic repository in my application. I've implemented the unit of work pattern in my ASP.NET MVC application. My classes look like this:
public class UnitOfWork : IUnitOfWork
{
private bool disposed = false;
private IGenericRepository<Shop> _shopRespository;
public UnitOfWork(PosContext context)
{
this.Context = context;
}
public PosContext Context { get; private set; }
public IGenericRepository<Shop> ShopRepository
{
get
{
return this._shopRespository ?? (this._shopRespository = new GenericRepository<Shop>(this.Context));
}
}
public void SaveChanges()
{
this.Context.SaveChanges();
}
public void Dispose()
{
this.Dispose(true);
}
protected virtual void Dispose(bool disposing)
{
if (!this.disposed)
{
if (disposing)
{
this.Context.Dispose();
}
this.disposed = true;
}
}
}
public class PosContext : DbContext, IPosContext
{
public DbSet<Shop> Shops { get; private set; }
}
public class GenericRepository<T> : IGenericRepository<T>
where T : class
{
private readonly PosContext context;
private readonly DbSet<T> dbSet;
public GenericRepository(PosContext context)
{
this.context = context;
this.dbSet = context.Set<T>();
}
public virtual IEnumerable<T> Get(
Expression<Func<T, bool>> filter = null,
Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null,
string includeProperties = "")
{
IQueryable<T> query = this.dbSet;
if (filter != null)
{
query = query.Where(filter);
}
foreach (var includeProperty in includeProperties.Split
(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries))
{
query = query.Include(includeProperty);
}
if (orderBy != null)
{
return orderBy(query).ToList();
}
else
{
return query.ToList();
}
}
public virtual T Find(object id)
{
return this.dbSet.Find(id);
}
public virtual void Add(T entity)
{
this.dbSet.Add(entity);
}
public virtual void Remove(object id)
{
T entityToDelete = this.dbSet.Find(id);
this.Remove(entityToDelete);
}
public virtual void Remove(T entityToDelete)
{
if (this.context.Entry(entityToDelete).State == EntityState.Detached)
{
this.dbSet.Attach(entityToDelete);
}
this.dbSet.Remove(entityToDelete);
}
public virtual void Update(T entityToUpdate)
{
this.dbSet.Attach(entityToUpdate);
this.context.Entry(entityToUpdate).State = EntityState.Modified;
}
I'm using NUnit and FakeItEasy to write my unit tests. In my set up function, I create a UnitIfWork object with a fake PosContext object. I then populate the context with a few Shop objects.
[SetUp]
public void SetUp()
{
this.unitOfWork = new UnitOfWork(A.Fake<PosContext>());
this.unitOfWork.ShopRepository.Add(new Shop() { Id = 1, Name = "Test name1" });
this.unitOfWork.ShopRepository.Add(new Shop() { Id = 2, Name = "Test name2" });
this.unitOfWork.ShopRepository.Add(new Shop() { Id = 3, Name = "Test name3" });
this.unitOfWork.ShopRepository.Add(new Shop() { Id = 4, Name = "Test name4" });
this.unitOfWork.ShopRepository.Add(new Shop() { Id = 5, Name = "Test name5" });
this.Controller = new ShopController(this.unitOfWork);
}
It works fine when I test the Find-method of the GenericRepository. The correct Shop object is returned and I can assert that it works fine:
[TestCase]
public void DetailsReturnsCorrectShop()
{
// Arrange
int testId = 1;
// Act
Shop shop = this.unitOfWork.ShopRepository.Find(testId);
ViewResult result = this.Controller.Details(testId) as ViewResult;
// Assert
Shop returnedShop = (Shop)result.Model;
Assert.AreEqual(testId, returnedShop.Id);
}
But when I want to test that the Get-method returns all shops from the repository, if I do not give any filter params, I get an empty list back. I can't figure out why?
[TestCase]
public void IndexReturnsListOfShops()
{
// Arrange
// Act
ViewResult result = this.Controller.Index() as ViewResult;
// Assert
List<Shop> returnedShops = (List<Shop>)result.Model;
Assert.AreEqual(5, returnedShops.Count);
}
The ShopController looks like this:
public class ShopController : Controller
{
private readonly IUnitOfWork unitOfWork;
public ShopController(IUnitOfWork unitOfWork)
{
this.unitOfWork = unitOfWork;
}
// GET: /Shop/
public ActionResult Index()
{
return View(this.unitOfWork.ShopRepository.Get());
}
// GET: /Shop/Details/5
public ActionResult Details(int? id)
{
if (id == null)
{
return new HttpStatusCodeResult(HttpStatusCode.BadRequest);
}
Shop shop = this.unitOfWork.ShopRepository.Find(id);
if (shop == null)
{
return HttpNotFound();
}
return View(shop);
}
}
Can you help me figure out why I get an empty list back from the Get-method?

ASP.NET MVC 2 Authorization with Gateway Page

I've got an MVC 2 application which won't be doing its own authentication, but will retrieve a user ID from the HTTP request header, since users must pass through a gateway before reaching the application.
Once in the app, we need to match up the user ID to information in a "users" table, which contains some security details the application makes use of.
I'm familiar with setting up custom membership and roles providers in ASP.NET, but this feels so different, since the user never should see a login page once past the gateway application.
Questions:
How do I persist the user ID, if at all? It starts out in the request header, but do I have to put it in a cookie? How about SessionState?
Where/when do I get this information? The master page shows the user's name, so it should be available everywhere.
I'd like to still use the [Authorize(Roles="...")] tag in my controller if possible.
We have a very similar setup where I work. As #Mystere Man mentioned, there are risks with this setup, but if the whole infrastructure is setup and running correctly, we have found it to be a secure setup (we do care about security). One thing to ensure, is that the SiteMinder agent is running on the IIS node you're trying to secure, as it will validate an encrypted SMSESSION key also passed in the headers, which will make the requests secure (it would be extremely difficult to spoof the value of the SMSESSION header).
We are using ASP.NET MVC3, which has global action filters, which is what we're using. But with MVC2, you could create a normal, controller level action filter that could be applied to a base controller class so that all of your controllers/actions will be secured.
We have created a custom configuration section that allows us to turn this security filter on and off via web.config. If it's turned off, our configuration section has properties that will allow you to "impersonate" a given user with given roles for testing and debugging purposes. This configuration section also allows us to store the values of the header keys we're looking for in config as well, in case the vendor ever changes the header key names on us.
public class SiteMinderConfiguration : ConfigurationSection
{
[ConfigurationProperty("enabled", IsRequired = true)]
public bool Enabled
{
get { return (bool)this["enabled"]; }
set { this["enabled"] = value; }
}
[ConfigurationProperty("redirectTo", IsRequired = true)]
public RedirectToElement RedirectTo
{
get { return (RedirectToElement)this["redirectTo"]; }
set { this["redirectTo"] = value; }
}
[ConfigurationProperty("sessionCookieName", IsRequired = true)]
public SiteMinderSessionCookieNameElement SessionCookieName
{
get { return (SiteMinderSessionCookieNameElement)this["sessionCookieName"]; }
set { this["sessionCookieName"] = value; }
}
[ConfigurationProperty("userKey", IsRequired = true)]
public UserKeyElement UserKey
{
get { return (UserKeyElement)this["userKey"]; }
set { this["userKey"] = value; }
}
[ConfigurationProperty("rolesKey", IsRequired = true)]
public RolesKeyElement RolesKey
{
get { return (RolesKeyElement)this["rolesKey"]; }
set { this["rolesKey"] = value; }
}
[ConfigurationProperty("firstNameKey", IsRequired = true)]
public FirstNameKeyElement FirstNameKey
{
get { return (FirstNameKeyElement)this["firstNameKey"]; }
set { this["firstNameKey"] = value; }
}
[ConfigurationProperty("lastNameKey", IsRequired = true)]
public LastNameKeyElement LastNameKey
{
get { return (LastNameKeyElement)this["lastNameKey"]; }
set { this["lastNameKey"] = value; }
}
[ConfigurationProperty("impersonate", IsRequired = false)]
public ImpersonateElement Impersonate
{
get { return (ImpersonateElement)this["impersonate"]; }
set { this["impersonate"] = value; }
}
}
public class SiteMinderSessionCookieNameElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class RedirectToElement : ConfigurationElement
{
[ConfigurationProperty("loginUrl", IsRequired = false)]
public string LoginUrl
{
get { return (string)this["loginUrl"]; }
set { this["loginUrl"] = value; }
}
}
public class UserKeyElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class RolesKeyElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class FirstNameKeyElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class LastNameKeyElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class ImpersonateElement : ConfigurationElement
{
[ConfigurationProperty("username", IsRequired = false)]
public UsernameElement Username
{
get { return (UsernameElement)this["username"]; }
set { this["username"] = value; }
}
[ConfigurationProperty("roles", IsRequired = false)]
public RolesElement Roles
{
get { return (RolesElement)this["roles"]; }
set { this["roles"] = value; }
}
}
public class UsernameElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
public class RolesElement : ConfigurationElement
{
[ConfigurationProperty("value", IsRequired = true)]
public string Value
{
get { return (string)this["value"]; }
set { this["value"] = value; }
}
}
So our web.config looks something like this
<configuration>
<configSections>
<section name="siteMinderSecurity" type="MyApp.Web.Security.SiteMinderConfiguration, MyApp.Web" />
...
</configSections>
...
<siteMinderSecurity enabled="false">
<redirectTo loginUrl="https://example.com/login/?ReturnURL={0}"/>
<sessionCookieName value="SMSESSION"/>
<userKey value="SM_USER"/>
<rolesKey value="SN-AD-GROUPS"/>
<firstNameKey value="SN-AD-FIRST-NAME"/>
<lastNameKey value="SN-AD-LAST-NAME"/>
<impersonate>
<username value="ImpersonateMe" />
<roles value="Role1, Role2, Role3" />
</impersonate>
</siteMinderSecurity>
...
</configuration>
We have a custom SiteMinderIdentity...
public class SiteMinderIdentity : GenericIdentity, IIdentity
{
public SiteMinderIdentity(string name, string type) : base(name, type) { }
public IList<string> Roles { get; set; }
}
And a custom SiteMinderPrincipal...
public class SiteMinderPrincipal : GenericPrincipal, IPrincipal
{
public SiteMinderPrincipal(IIdentity identity) : base(identity, null) { }
public SiteMinderPrincipal(IIdentity identity, string[] roles) : base(identity, roles) { }
}
And we populate HttpContext.Current.User and Thread.CurrentPrincipal with an instance of SiteMinderPrincipal that we build up based on information that we pull from the request headers in our action filter...
public class SiteMinderSecurity : ActionFilterAttribute
{
public override void OnActionExecuting(ActionExecutingContext filterContext)
{
base.OnActionExecuting(filterContext);
var request = filterContext.HttpContext.Request;
var response = filterContext.HttpContext.Response;
if (MyApp.SiteMinderConfig.Enabled)
{
string[] userRoles = null; // default to null
userRoles = Array.ConvertAll(request.Headers[MyApp.SiteMinderConfig.RolesKey.Value].Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries), r => r.Trim());
var identity = new SiteMinderIdentity(request.Headers[MyApp.SiteMinderConfig.UserKey.Value];, "SiteMinder");
if (userRoles != null)
identity.Roles = userRoles.ToList();
var principal = new SiteMinderPrincipal(identity, userRoles);
HttpContext.Current.User = principal;
Thread.CurrentPrincipal = principal;
}
else
{
var roles = Array.ConvertAll(MyApp.SiteMinderConfig.Impersonate.Roles.Value.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries), r => r.Trim());
var identity = new SiteMinderIdentity(MyApp.SiteMinderConfig.Impersonate.Username.Value, "SiteMinder") { Roles = roles.ToList() };
var principal = new SiteMinderPrincipal(identity, roles);
HttpContext.Current.User = principal;
Thread.CurrentPrincipal = principal;
}
}
}
MyApp is a static class that gets initialized at application startup that caches the configuration information so we're not reading it from web.config on every request...
public static class MyApp
{
private static bool _isInitialized;
private static object _lock;
static MyApp()
{
_lock = new object();
}
private static void Initialize()
{
if (!_isInitialized)
{
lock (_lock)
{
if (!_isInitialized)
{
// Initialize application version number
_version = FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location).FileVersion;
_siteMinderConfig = (SiteMinderConfiguration)ConfigurationManager.GetSection("siteMinderSecurity");
_isInitialized = true;
}
}
}
}
private static string _version;
public static string Version
{
get
{
Initialize();
return _version;
}
}
private static SiteMinderConfiguration _siteMinderConfig;
public static SiteMinderConfiguration SiteMinderConfig
{
get
{
Initialize();
return _siteMinderConfig;
}
}
}
From what I gather of your situation, you have information in a database that you'll need to lookup based on the information in the headers to get everything you need, so this won't be exactly what you need, but it seems like it should at least get you started.
Hope this helps.