Saturday, August 21, 2010

Scala: Functional programming & DSL

I first learned about Scala from forge's comment on this post, and the tutorial he mentioned is really great. As an experiment, I extended its Calculator example with variables, functions and exponentiation. The result is easy to extend and took me just 140 lines of code, which really impressed me. So here it is:

AST.scala (Scala doesn't require file names to match class names, and AST (abstract syntax tree) describes the classes in this file)
package net.slightlymagic.calc {
  abstract class Expr
  case class Variable(name: String) extends Expr {
    override def toString = name
  }
  case class Number(value: Double) extends Expr {
    override def toString = "" + value
  }
  case class UnaryOp(name: String, arg: Expr) extends Expr {
    override def toString = name + arg
  }
  case class BinaryOp(name: String, left: Expr, right: Expr) extends Expr {
    override def toString = "(" + left + name + right + ")"
  }
  case class Function(name: String, args: List[Expr]) extends Expr {
    override def toString = name + args.mkString("(", ", ", ")")
  }
}


Calc.scala

package net.slightlymagic.calc {
  object Calc {
    def simplify(e: Expr): Expr = {
      def simpArgs(e: Expr) = e match {
        case BinaryOp(op, left, right) => BinaryOp(op, simplify(left), simplify(right))
        case UnaryOp(op, operand) => UnaryOp(op, simplify(operand))
        case Function(op, operands) => Function(op, operands.map((x) => simplify(x)))
        case x => x
      }
     
      def simpTop(e: Expr) = e match {
        case UnaryOp("-", UnaryOp("-", x)) => x
        case BinaryOp("-", x, Number(0)) => x
        case BinaryOp("-", Number(0), x) => UnaryOp("-", x)
        case BinaryOp("-", x1, x2) if x1 == x2 => Number(0)
       
        case UnaryOp("+", x) => x
        case BinaryOp("+", x, Number(0)) => x
        case BinaryOp("+", Number(0), x) => x
       
        case BinaryOp("*", x, Number(1)) => x
        case BinaryOp("*", Number(1), x) => x
       
        case BinaryOp("*", x, Number(0)) => Number(0)
        case BinaryOp("*", Number(0), x) => Number(0)
       
        case BinaryOp("/", x, Number(1)) => x
        case BinaryOp("/", x1, x2) if x1 == x2 => Number(1)
       
        case BinaryOp("^", x, Number(1)) => x
        case BinaryOp("^", x, Number(0)) if x != Number(0) => Number(1)
        case BinaryOp("^", x1, UnaryOp("-", x2)) => BinaryOp("/", Number(1), BinaryOp("^", x1, x2))
        case e => e
      }
     
      simpTop(simpArgs(e))
    }
   
    def evaluate(expr: Expr, variables: Map[String, Double], functions: Map[String, (List[Double]) => Double]): Double = {
      expr match {
        case Number(x) => x
        case UnaryOp("-", x) => -evaluate(x, variables, functions)
        case BinaryOp("+", x1, x2) => evaluate(x1, variables, functions) + evaluate(x2, variables, functions)
        case BinaryOp("-", x1, x2) => evaluate(x1, variables, functions) - evaluate(x2, variables, functions)
        case BinaryOp("*", x1, x2) => evaluate(x1, variables, functions) * evaluate(x2, variables, functions)
        case BinaryOp("/", x1, x2) => evaluate(x1, variables, functions) / evaluate(x2, variables, functions)
        case BinaryOp("^", x1, x2) => Math.pow(evaluate(x1, variables, functions), evaluate(x2, variables, functions))
        case Variable(x) => variables(x)
        case Function(x, args) => functions(x)(args.map((x) => evaluate(x, variables, functions)))
      }
    }
   
    def parse(text : String) = CalcParser.parse(text).get
   
    val standardVars =
      Map(
        "E"  -> math.E,
        "Pi" -> math.Pi
      )
   
    val standardFuns =
      Map(
        "sin"  -> {x:List[Double] => Math.sin(x(0))},
        "cos"  -> {x:List[Double] => Math.cos(x(0))},
        "tan"  -> {x:List[Double] => Math.tan(x(0))},
        "cot"  -> {x:List[Double] => 1/Math.tan(x(0))},
        "asin" -> {x:List[Double] => Math.asin(x(0))},
        "acos" -> {x:List[Double] => Math.acos(x(0))},
        "atan" -> {x:List[Double] => Math.atan(x(0))},
        "acot" -> {x:List[Double] => Math.atan(1/x(0))},
        "exp"  -> {x:List[Double] => Math.exp(x(0))},
        "log"  -> {x:List[Double] => Math.log(x(0))},
        "min"  -> {x:List[Double] => Math.min(x(0), x(1))},
        "max"  -> {x:List[Double] => Math.max(x(0), x(1))}
      )
   
    def evaluate(expr: String, variables: Map[String, Double], functions: Map[String, (List[Double]) => Double]): Double = {
      evaluate(
        simplify(parse(expr)),
        if(variables == null) standardVars else standardVars ++ variables,
        if(functions == null) standardFuns else standardFuns ++ functions
      )
    }
   
   
    import scala.util.parsing.combinator._
   
    object CalcParser extends JavaTokenParsers {
     
      // Chains of + and - are left-associative: "1 - 2 - 3" is ((1 - 2) - 3)
      def expr: Parser[Expr] =
        term ~ rep(("+" | "-") ~ term) ^^ { case first ~ rest =>
          rest.foldLeft(first) { case (lhs, op ~ rhs) => BinaryOp(op, lhs, rhs) }
        }
 
      def term: Parser[Expr] =
        factor ~ rep(("*" | "/") ~ factor) ^^ { case first ~ rest =>
          rest.foldLeft(first) { case (lhs, op ~ rhs) => BinaryOp(op, lhs, rhs) }
        }
 
      // Exponentiation is right-associative: "2 ^ 3 ^ 2" is 2 ^ (3 ^ 2)
      def factor : Parser[Expr] =
        (exp ~ "^" ~ factor) ^^ { case lhs~pow~rhs => BinaryOp("^", lhs, rhs) } |
        exp
 
      // A sign applies to the following factor only, so "-1 + 2" is ((-1) + 2)
      def exp : Parser[Expr] =
        ("+" ~ factor) ^^ { case plus~rhs => UnaryOp("+", rhs) } |
        ("-" ~ factor) ^^ { case minus~rhs => UnaryOp("-", rhs) } |
        "(" ~> expr <~ ")" |
        ident ~ params ^^ {case name~params => Function(name, params)} |
        ident ^^ {case name => Variable(name)} |
        floatingPointNumber ^^ {x => Number(x.toDouble) }
     
      def params : Parser[List[Expr]] =
        "(" ~> repsep(expr, ",") <~ ")"
     
      def parse(text : String) = parseAll(expr, text)
    }
  }
}


Wow! I don't want to repeat the tutorial, which is really great, but I want to mention a few features that really improved the way this code is written.

First of all, Scala is a mixture of functional and object-oriented programming: Map[String, (List[Double]) => Double] (Scala uses square brackets for generic types) is a Map from Strings to functions, each of which takes a List of Doubles as its parameter and yields a Double as its result. A side note on the Calc object... yes, object. This basically means that all of its members are static. Conversely, normal classes can't have static members, which doesn't really matter because you can have a class and an object of the same name. To get the same result as standardFuns in Java, you'd have to define a common interface and several anonymous classes, which means a lot of duplication. With anonymous functions, and the fact that a function can be passed around as a value (something you don't get from non-functional languages like Java), you avoid most of that redundant code. No interface definition, just a single line per anonymous function.
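
To make that concrete, here is a minimal sketch of a function table, separate from the calculator itself. The names FunctionTable, Fun, builtins, custom and call are made up for illustration, and "sum" and "avg" are not among the standardFuns above.

```scala
object FunctionTable {
  // Same shape as standardFuns: a name mapped to a function value.
  type Fun = List[Double] => Double

  val builtins: Map[String, Fun] = Map(
    "max" -> ((xs: List[Double]) => math.max(xs(0), xs(1))),
    "sum" -> ((xs: List[Double]) => xs.sum)
  )

  // Hypothetical user-supplied additions.
  val custom: Map[String, Fun] = Map(
    "avg" -> ((xs: List[Double]) => xs.sum / xs.size)
  )

  // With ++, entries on the right win if a name occurs in both maps.
  val all: Map[String, Fun] = builtins ++ custom

  // Look a function up by name and apply it to the arguments.
  def call(name: String, args: Double*): Double = all(name)(args.toList)
}
```

Calling FunctionTable.call("avg", 1.0, 2.0, 3.0) looks up the function value and applies it, which is roughly what evaluate does for a Function node.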

The second thing is that you can write method calls in an "unusual way": a.plus(b) is the same as a plus b. Together with the fact that you can give methods symbolic names, e.g. "+", you basically get operator overloading. In practice this means little more than being able to write short, concise code. But that is very important; just look at the preceding paragraph. Scala's parser combinator package uses such "operators" to keep text parsing code short and expressive.
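
A small sketch of that syntax, using a made-up Vec class that has nothing to do with the calculator (Vec, plus and InfixDemo are illustrative names only):

```scala
// Hypothetical 2D vector, used only to illustrate the call syntax.
case class Vec(x: Double, y: Double) {
  // An ordinary method with an alphabetic name...
  def plus(other: Vec): Vec = Vec(x + other.x, y + other.y)
  // ...and methods whose names are operator symbols.
  def +(other: Vec): Vec = plus(other)
  def *(factor: Double): Vec = Vec(x * factor, y * factor)
}

object InfixDemo {
  val a = Vec(1, 2)
  val b = Vec(3, 4)

  val dotted   = a.plus(b) // regular call syntax
  val infix    = a plus b  // the same call in infix notation
  val operator = a + b     // the same again, via the method named "+"
  // Precedence comes from the first character of the method name,
  // so * binds tighter than + and this is a + (b * 2).
  val scaled   = a + b * 2
}
```

All three spellings of the addition produce the same Vec; only the notation differs.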

Lastly, pattern matching. This is a great feature that is really hard to imitate using Java. It uses case classes, as seen in AST.scala, and the match/case construct. For example:

expr match {
  case BinaryOp("+", Number(x1), Number(x2)) => doSomethingElse()
  case BinaryOp(op, x1, x2) => doSomething()
}

doSomething() is only executed if its case matches, and the match at the same time assigns the three variables op, x1 and x2. You can build more complex patterns, as the first case shows, by nesting objects and using constants to narrow down what matches. Cases are tried from top to bottom, so the more specific one has to come first. This works as long as you use case classes.
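
Here is a self-contained sketch of the same idea that you can run on its own. It uses a cut-down, hypothetical AST (Node, Num, Var, Bin) instead of the classes from AST.scala, and the fold function is my own illustration, not part of Calc.scala.

```scala
// Cut-down, hypothetical stand-in for the AST in AST.scala.
sealed abstract class Node
case class Num(value: Double) extends Node
case class Var(name: String) extends Node
case class Bin(op: String, left: Node, right: Node) extends Node

object Fold {
  // Folds constant sub-expressions bottom-up; anything else is left alone.
  def fold(n: Node): Node = n match {
    case Bin(op, l, r) =>
      (op, fold(l), fold(r)) match {
        case ("+", Num(a), Num(b)) => Num(a + b) // constants plus nested patterns
        case ("*", Num(a), Num(b)) => Num(a * b)
        case ("*", x, Num(1.0))    => x          // the literal narrows the match
        case (_, fl, fr)           => Bin(op, fl, fr)
      }
    case other => other
  }
}
```

For instance, Fold.fold(Bin("+", Num(1), Bin("*", Num(2), Num(3)))) first folds the inner product and then the sum, while an expression containing a Var is rebuilt unchanged.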

And now comes another important part: Scala compiles to Java bytecode, which means that you can write just a part of your project in Scala. For example, the simple Calculator DSL was way easier to write in Scala than it would have been in Java. Once it's packed as a jar, you don't even notice that the code was not written in traditional Java, and it integrates nicely with the other parts. (Except that you'd have a hard time calling "operator methods" or using an API that takes functions as parameters. Keep that in mind for your public, plain-old-Java-accessible API.)
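
One way to keep such an API Java-friendly is a thin facade that uses only alphabetic method names and java.util types. This is just a sketch under my own assumptions: JavaFacade, toScalaMap, valueOf and the lookup stand-in are hypothetical and not part of the calculator.

```scala
import java.util.{Map => JMap}

object JavaFacade {
  // Stand-in for the real evaluator: just looks up one variable by name.
  private def lookup(name: String, variables: Map[String, Double]): Double =
    variables(name)

  // Copies a java.util.Map into an immutable Scala Map, unboxing the values.
  def toScalaMap(vars: JMap[String, java.lang.Double]): Map[String, Double] = {
    var result = Map.empty[String, Double]
    val it = vars.entrySet().iterator()
    while (it.hasNext) {
      val entry = it.next()
      result += (entry.getKey() -> entry.getValue().doubleValue())
    }
    result
  }

  // Entry point for Java callers: no operator names, no Scala function types.
  def valueOf(name: String, vars: JMap[String, java.lang.Double]): Double =
    lookup(name, toScalaMap(vars))
}
```

From Java this should be callable as JavaFacade.valueOf("x", vars) with an ordinary java.util.HashMap, since Scala generates static forwarders for the methods of a top-level object.
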
The only downside is that Eclipse integration is not great so far. Maven does a good job of compiling the Scala code anyway, but the Scala plugin is not very stable, so I use a plain text editor. It works for small things like the Calculator, but I miss obvious features like import organizing, content assist, automatic building on every change, syntax & error highlighting, parentheses matching, folding...

2 comments:

Hellfish said...

This post has made me realize how much "standard" (i.e. java/C-family style) syntax is ingrained in my mind. I got a headache trying to make sense of the code!

nantuko84 said...

I remember the same feeling when I studied Haskell. You need a new way of "thinking".