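I am very new to Scala. As an exercise, I have written a Scala solution to the maximum path sum problem.
Given a triangle of integers, I want to find the path from the top to the bottom whose numbers produce the largest sum. At each step the path moves to one of the two adjacent numbers in the row below (number `j` in one row is adjacent to numbers `j` and `j + 1` in the next row). For example: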
5
12 3
2 4 9
1 9 12 7
The program should print:
5 -> 12 -> 4 -> 12
Sum = 33
Could someone please review the code and suggest improvements regarding:
- A better algorithm
- Writing more idiomatic Scala code
Here is the code:
import scala._
/**
 * Solves the "maximum path sum" problem for a triangle of integers.
 *
 * A path starts at the single value of the top row and, at each step, moves to one of the
 * two adjacent values in the row below: value `j` of row `i` is adjacent to values `j` and
 * `j + 1` of row `i + 1`.
 */
class MaxPath {

  /**
   * Accumulates, top-down, the best path sum ending at every position of the triangle.
   *
   * Row `k` of the result holds, for each position, that position's value plus the larger
   * of the accumulated sums of its (one or two) parents in row `k - 1`.
   *
   * @param input        the triangle, top row first
   * @param tempResult   already-accumulated rows for levels `0 until currentlevel`
   *                     (pass `Nil` when starting at level 0)
   * @param currentlevel the first level of `input` that still has to be accumulated
   * @return `tempResult` followed by the accumulated rows of the remaining levels
   */
  def sumTree(input: List[List[Int]], tempResult: List[List[Int]], currentlevel: Int): List[List[Int]] = {
    // Walk the remaining rows once, carrying the previous accumulated row along, instead of
    // re-indexing into Lists (O(n) per access) and appending with `:::` at every level.
    @scala.annotation.tailrec
    def loop(rows: List[List[Int]], previous: Option[List[Int]], acc: List[List[Int]]): List[List[Int]] =
      rows match {
        case Nil => tempResult ::: acc.reverse
        case row :: rest =>
          // The top row has no parents, so its accumulated sums are just its own values.
          val summed = previous.fold(row)(accumulateRow(_, row))
          loop(rest, Some(summed), summed :: acc)
      }

    val remaining = input.drop(currentlevel)
    if (remaining.isEmpty) tempResult
    else {
      val previous: Option[List[Int]] =
        if (currentlevel == 0) None else Some(tempResult(currentlevel - 1))
      loop(remaining, previous, Nil)
    }
  }

  /**
   * Combines one triangle row with the accumulated sums of the row above it.
   *
   * @param previous accumulated sums of the row above (k values)
   * @param row      the raw values of the current row (k + 1 values)
   * @return the accumulated sums of the current row
   */
  private def accumulateRow(previous: List[Int], row: List[Int]): List[Int] = {
    // Interior position j has parents j - 1 and j; the edges have a single parent each.
    val innerParents = previous.zip(previous.tail).map { case (left, right) => math.max(left, right) }
    val bestParents = previous.head :: innerParents ::: List(previous.last)
    row.zip(bestParents).map { case (value, parent) => value + parent }
  }

  /**
   * Reconstructs the maximum path by walking the accumulated sums from the bottom row up.
   *
   * @param input        the triangle rows, bottom row first
   * @param transformed  the `sumTree` output, also bottom row first
   * @param index        position chosen in the previously visited (lower) row, or `-1` to
   *                     start by picking the position with the best accumulated sum
   * @param currentLevel index into `input` / `transformed` of the row being visited
   * @return the values on the path, top row first
   */
  def traceBack(input: List[List[Int]], transformed: List[List[Int]], index: Int, currentLevel: Int): List[Int] =
    if (currentLevel >= input.size) Nil
    else {
      val row = input(currentLevel)
      val rowTransformed = transformed(currentLevel)
      val chosen: Int =
        if (index == -1) rowTransformed.zipWithIndex.maxBy(_._1)._2
        // Leftmost child: its only parent is the first value of this row.
        else if (index == 0) 0
        // Rightmost child: its only parent is the LAST value of this row, so the trace must
        // continue from index row.size - 1 (continuing from 0 would follow the wrong edge).
        else if (index == row.size) row.size - 1
        // Interior child: follow the parent with the larger accumulated sum (left on ties).
        else if (rowTransformed(index - 1) >= rowTransformed(index)) index - 1
        else index
      traceBack(input, transformed, chosen, currentLevel + 1) ::: List(row(chosen))
    }

  /**
   * Finds the top-to-bottom path whose values have the largest sum.
   *
   * @param input the triangle, top row first; row `i` must contain exactly `i + 1` values
   * @return the values on the best path (top row first) and their sum; `(Nil, 0)` for an
   *         empty triangle
   * @throws IllegalArgumentException if `input` is not a triangle
   */
  def calculate(input: List[List[Int]]): (List[Int], Int) = {
    require(
      input.zipWithIndex.forall { case (row, i) => row.size == i + 1 },
      "input must be a triangle: row i must contain exactly i + 1 values"
    )
    val transformed = sumTree(input, Nil, 0)
    val path = traceBack(input.reverse, transformed.reverse, -1, 0)
    (path, path.sum)
  }

  /** Misspelled original name of [[calculate]]; kept so existing callers keep compiling. */
  def calcualte(input: List[List[Int]]): (List[Int], Int) = calculate(input)
}
/** Command-line entry point that solves the sample triangle from the problem statement. */
object MaxPath {

  /** The sample triangle, top row first. */
  val input: List[List[Int]] = List(
    List(5),
    List(12, 3),
    List(2, 4, 9),
    List(1, 9, 12, 7)
  )

  /** Prints the best path as `a -> b -> ...`, then a line with its sum. */
  def main(args: Array[String]): Unit = {
    val (path, total) = new MaxPath().calcualte(input)
    println(path.mkString(" -> "))
    println(s"Sum = $total")
  }
}