Wednesday, September 24, 2008

Structural Compare for a LINQ Expression

Rather than hand-coding checks for whether a LINQ expression matches a certain form, you can walk the tree and see whether it matches an expression template (pattern).

Anyway, here's the structural compare, to save anyone else from having to adapt the standard visitor:

/// <summary>
/// Walks two LINQ expression trees in lock-step and reports whether they have
/// the same structure (the same node types arranged in the same shape).
/// </summary>
/// <remarks>
/// This is adapted from Matt Warren's sample:
/// http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx
///
/// The comparison is deliberately shallow in several places so that one tree
/// can act as a template (pattern) for the other: constant values, parameter
/// names and types, the member being accessed, the method being called and
/// the constructor being invoked are NOT compared. Every Visit* method is
/// virtual, so a subclass can tighten any of these checks.
/// </remarks>
public class ExpressionStructualCompare
{
    /// <summary>
    /// Compares two expression trees node by node.
    /// </summary>
    /// <param name="exp">The first expression (may be null).</param>
    /// <param name="exp2">The second expression (may be null).</param>
    /// <returns>
    /// True when both are null, or when both have the same node type and all
    /// of their children match recursively; otherwise false.
    /// </returns>
    /// <exception cref="NotSupportedException">
    /// The node type is not one this comparer knows how to walk.
    /// </exception>
    public virtual bool Visit(Expression exp, Expression exp2)
    {
        if (exp == null)
        {
            return exp2 == null;
        }

        // Optional children (a static call's Object, a static member's
        // Expression, a binary node's Conversion) can be null on one side only.
        if (exp2 == null)
        {
            return false;
        }

        if (exp.NodeType != exp2.NodeType)
        {
            return false;
        }

        // The node types are equal from here on, so casting both sides to the
        // same expression class is safe.
        switch (exp.NodeType)
        {
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.UnaryPlus:
            case ExpressionType.Not:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.ArrayLength:
            case ExpressionType.Quote:
            case ExpressionType.TypeAs:
                return VisitUnary((UnaryExpression)exp, (UnaryExpression)exp2);
            case ExpressionType.Add:
            case ExpressionType.AddChecked:
            case ExpressionType.Subtract:
            case ExpressionType.SubtractChecked:
            case ExpressionType.Multiply:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.Divide:
            case ExpressionType.Modulo:
            case ExpressionType.Power:
            case ExpressionType.And:
            case ExpressionType.AndAlso:
            case ExpressionType.Or:
            case ExpressionType.OrElse:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
            case ExpressionType.RightShift:
            case ExpressionType.LeftShift:
            case ExpressionType.ExclusiveOr:
                return VisitBinary((BinaryExpression)exp, (BinaryExpression)exp2);
            case ExpressionType.TypeIs:
                return VisitTypeIs((TypeBinaryExpression)exp, (TypeBinaryExpression)exp2);
            case ExpressionType.Conditional:
                // The second argument must be exp2; passing exp twice would
                // make every conditional trivially equal to itself.
                return VisitConditional((ConditionalExpression)exp, (ConditionalExpression)exp2);
            case ExpressionType.Constant:
                return VisitConstant((ConstantExpression)exp, (ConstantExpression)exp2);
            case ExpressionType.Parameter:
                return VisitParameter((ParameterExpression)exp, (ParameterExpression)exp2);
            case ExpressionType.MemberAccess:
                return VisitMemberAccess((MemberExpression)exp, (MemberExpression)exp2);
            case ExpressionType.Call:
                return VisitMethodCall((MethodCallExpression)exp, (MethodCallExpression)exp2);
            case ExpressionType.Lambda:
                return VisitLambda((LambdaExpression)exp, (LambdaExpression)exp2);
            case ExpressionType.New:
                return VisitNew((NewExpression)exp, (NewExpression)exp2);
            case ExpressionType.NewArrayInit:
            case ExpressionType.NewArrayBounds:
                return VisitNewArray((NewArrayExpression)exp, (NewArrayExpression)exp2);
            case ExpressionType.Invoke:
                return VisitInvocation((InvocationExpression)exp, (InvocationExpression)exp2);
            case ExpressionType.MemberInit:
                return VisitMemberInit((MemberInitExpression)exp, (MemberInitExpression)exp2);
            case ExpressionType.ListInit:
                return VisitListInit((ListInitExpression)exp, (ListInitExpression)exp2);
            default:
                throw new NotSupportedException(string.Format("Unhandled expression type: '{0}'", exp.NodeType));
        }
    }

    /// <summary>
    /// Compares two member bindings: they must be of the same binding type and
    /// their contents must match.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// The binding type is not one this comparer knows how to walk.
    /// </exception>
    protected virtual bool VisitBinding(MemberBinding binding, MemberBinding binding2)
    {
        if (binding.BindingType != binding2.BindingType)
        {
            return false;
        }

        switch (binding.BindingType)
        {
            case MemberBindingType.Assignment:
                return VisitMemberAssignment((MemberAssignment)binding, (MemberAssignment)binding2);
            case MemberBindingType.MemberBinding:
                return VisitMemberMemberBinding((MemberMemberBinding)binding, (MemberMemberBinding)binding2);
            case MemberBindingType.ListBinding:
                return VisitMemberListBinding((MemberListBinding)binding, (MemberListBinding)binding2);
            default:
                throw new NotSupportedException(string.Format("Unhandled binding type '{0}'", binding.BindingType));
        }
    }

    /// <summary>
    /// Compares two element initializers by their argument lists only; the
    /// add method itself is not compared.
    /// </summary>
    protected virtual bool VisitElementInitializer(ElementInit initializer, ElementInit initializer2)
    {
        return VisitExpressionList(initializer.Arguments, initializer2.Arguments);
    }

    /// <summary>
    /// Compares two unary expressions by their operands.
    /// </summary>
    protected virtual bool VisitUnary(UnaryExpression u, UnaryExpression u2)
    {
        return Visit(u.Operand, u2.Operand);
    }

    /// <summary>
    /// Compares two binary expressions by left operand, right operand and
    /// conversion lambda, in that order, stopping at the first mismatch.
    /// </summary>
    protected virtual bool VisitBinary(BinaryExpression b, BinaryExpression b2)
    {
        if (!Visit(b.Left, b2.Left))
        {
            return false;
        }

        if (!Visit(b.Right, b2.Right))
        {
            return false;
        }

        return Visit(b.Conversion, b2.Conversion);
    }

    /// <summary>
    /// Compares two type-test expressions by the expression being tested; the
    /// type operand is not compared.
    /// </summary>
    protected virtual bool VisitTypeIs(TypeBinaryExpression b, TypeBinaryExpression b2)
    {
        return Visit(b.Expression, b2.Expression);
    }

    /// <summary>
    /// Treats any two constants as matching; values and types are ignored.
    /// Override to compare them.
    /// </summary>
    protected virtual bool VisitConstant(ConstantExpression c, ConstantExpression c2)
    {
        return true;
    }

    /// <summary>
    /// Compares two conditional expressions by test, true branch and false
    /// branch, in that order, stopping at the first mismatch.
    /// </summary>
    protected virtual bool VisitConditional(ConditionalExpression c, ConditionalExpression c2)
    {
        if (!Visit(c.Test, c2.Test))
        {
            return false;
        }

        if (!Visit(c.IfTrue, c2.IfTrue))
        {
            return false;
        }

        return Visit(c.IfFalse, c2.IfFalse);
    }

    /// <summary>
    /// Treats any two parameters as matching; names and types are ignored.
    /// Override to compare them.
    /// </summary>
    protected virtual bool VisitParameter(ParameterExpression p, ParameterExpression p2)
    {
        return true;
    }

    /// <summary>
    /// Compares two member accesses by the expression the member is read
    /// from; the member itself is not compared.
    /// </summary>
    protected virtual bool VisitMemberAccess(MemberExpression m, MemberExpression m2)
    {
        return Visit(m.Expression, m2.Expression);
    }

    /// <summary>
    /// Compares two method calls by target object and argument list; the
    /// method itself is not compared.
    /// </summary>
    protected virtual bool VisitMethodCall(MethodCallExpression m, MethodCallExpression m2)
    {
        if (!Visit(m.Object, m2.Object))
        {
            return false;
        }

        return VisitExpressionList(m.Arguments, m2.Arguments);
    }

    /// <summary>
    /// Compares two expression lists: they must have the same length and the
    /// expressions at each index must match.
    /// </summary>
    protected virtual bool VisitExpressionList(ReadOnlyCollection<Expression> original, ReadOnlyCollection<Expression> original2)
    {
        if (original.Count != original2.Count)
        {
            return false;
        }

        for (int i = 0, n = original.Count; i < n; i++)
        {
            if (!Visit(original[i], original2[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Compares two member assignments by the assigned expression; the target
    /// member is not compared.
    /// </summary>
    protected virtual bool VisitMemberAssignment(MemberAssignment assignment, MemberAssignment assignment2)
    {
        return Visit(assignment.Expression, assignment2.Expression);
    }

    /// <summary>
    /// Compares two nested member bindings by their child binding lists.
    /// </summary>
    protected virtual bool VisitMemberMemberBinding(MemberMemberBinding binding, MemberMemberBinding binding2)
    {
        return VisitBindingList(binding.Bindings, binding2.Bindings);
    }

    /// <summary>
    /// Compares two list bindings by their element initializer lists.
    /// </summary>
    protected virtual bool VisitMemberListBinding(MemberListBinding binding, MemberListBinding binding2)
    {
        return VisitElementInitializerList(binding.Initializers, binding2.Initializers);
    }

    /// <summary>
    /// Compares two binding lists: they must have the same length and the
    /// bindings at each index must match.
    /// </summary>
    protected virtual bool VisitBindingList(ReadOnlyCollection<MemberBinding> original, ReadOnlyCollection<MemberBinding> original2)
    {
        if (original.Count != original2.Count)
        {
            return false;
        }

        for (int i = 0, n = original.Count; i < n; i++)
        {
            if (!VisitBinding(original[i], original2[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Compares two element initializer lists: they must have the same length
    /// and the initializers at each index must match.
    /// </summary>
    protected virtual bool VisitElementInitializerList(ReadOnlyCollection<ElementInit> original, ReadOnlyCollection<ElementInit> original2)
    {
        if (original.Count != original2.Count)
        {
            return false;
        }

        for (int i = 0, n = original.Count; i < n; i++)
        {
            if (!VisitElementInitializer(original[i], original2[i]))
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Compares two lambdas by their bodies; the parameter lists are not
    /// compared.
    /// </summary>
    protected virtual bool VisitLambda(LambdaExpression lambda, LambdaExpression lambda2)
    {
        return Visit(lambda.Body, lambda2.Body);
    }

    /// <summary>
    /// Compares two constructor calls by their argument lists; the
    /// constructor itself is not compared.
    /// </summary>
    protected virtual bool VisitNew(NewExpression nex, NewExpression nex2)
    {
        return VisitExpressionList(nex.Arguments, nex2.Arguments);
    }

    /// <summary>
    /// Compares two member-init expressions by constructor call and then by
    /// binding list.
    /// </summary>
    protected virtual bool VisitMemberInit(MemberInitExpression init, MemberInitExpression init2)
    {
        if (!VisitNew(init.NewExpression, init2.NewExpression))
        {
            return false;
        }

        return VisitBindingList(init.Bindings, init2.Bindings);
    }

    /// <summary>
    /// Compares two list-init expressions by constructor call and then by
    /// element initializer list.
    /// </summary>
    protected virtual bool VisitListInit(ListInitExpression init, ListInitExpression init2)
    {
        if (!VisitNew(init.NewExpression, init2.NewExpression))
        {
            return false;
        }

        return VisitElementInitializerList(init.Initializers, init2.Initializers);
    }

    /// <summary>
    /// Compares two array-creation expressions by their expression lists.
    /// </summary>
    protected virtual bool VisitNewArray(NewArrayExpression na, NewArrayExpression na2)
    {
        return VisitExpressionList(na.Expressions, na2.Expressions);
    }

    /// <summary>
    /// Compares two invocations by argument list and then by the expression
    /// being invoked.
    /// </summary>
    protected virtual bool VisitInvocation(InvocationExpression iv, InvocationExpression iv2)
    {
        if (!VisitExpressionList(iv.Arguments, iv2.Arguments))
        {
            return false;
        }

        return Visit(iv.Expression, iv2.Expression);
    }
}

No comments: