As a way to teach myself Swift and to get familiar with machine learning algorithms, I've been implementing common algorithms, starting with a Random Forest. For the moment this is just a single tree, but I have been implementing it purely from the theory, without looking at pseudo-code, in order to really understand the process.
It was harder than I expected, due to the lack of the convenient statistical and data-handling functions that are common in R or Python. This code seems to work and builds correct trees, although some of the techniques I use (lots of mapping...) seem a bit convoluted.
First, a node to store the splits:
import Foundation
/// A node of a binary classification tree.
///
/// A split node (created with `init(value:variable:)`) compares feature `variable` against
/// the threshold `value`. Samples whose feature value is `<=` the threshold belong to
/// `leftChild`; the others belong to `rightChild`. This is the same rule `buildTree` uses
/// when partitioning the training data.
/// A terminal node (created with `init(result:)`) holds the predicted class in `result`.
class Node: CustomStringConvertible
{
    /// True for terminal (leaf) nodes.
    let isTerminal:Bool
    /// Split threshold; nil for terminal nodes.
    var value:Double? = nil
    /// Subtree for samples whose feature value is <= `value`.
    var leftChild:Node? = nil
    /// Subtree for samples whose feature value is > `value`.
    var rightChild:Node? = nil
    /// Index of the feature column tested by a split node; nil for terminal nodes.
    var variable:Int? = nil
    /// Text dump of this node, extended each time a child is attached.
    var description: String
    /// Predicted class label of a terminal node; nil for split nodes.
    var result:Int? = nil

    /// Creates a split node testing feature `variable` against threshold `value`.
    init(value:Double, variable:Int)
    {
        self.value = value
        self.isTerminal = false
        self.variable = variable
        self.description = "\(variable): \(value)"
    }

    /// Creates a terminal node that predicts `result`.
    init(result: Int)
    {
        self.result = result
        self.isTerminal = true
        self.description = "Terminal node: \(result)\n"
    }

    /// Attaches the subtree for samples with feature value <= `value`.
    func addLeftChild(child:Node)
    {
        self.leftChild = child
        self.description += " L -> \(child)\n"
    }

    /// Attaches the subtree for samples with feature value > `value`.
    func addRightChild(child:Node)
    {
        self.rightChild = child
        self.description += " R -> \(child)\n"
    }

    /// Returns the child a sample follows during prediction.
    ///
    /// `x` is the sample's value for feature `variable`. Ties (`x == value`) go left,
    /// matching how training partitions the data.
    /// - Precondition: this is a split node with both children attached.
    func getChild(x:Double) -> Node
    {
        precondition(!isTerminal, "getChild called on a terminal node")
        guard let threshold = value, let left = leftChild, let right = rightChild else
        {
            preconditionFailure("getChild called on a split node without both children")
        }
        return x <= threshold ? left : right
    }
}
Some data taken from the famous Iris dataset:
// Feature matrix stored feature-major: x[feature][sample], 4 features x 15 samples.
// NOTE(review): presumably the standard Iris measurements in the usual order
// (sepal length, sepal width, petal length, petal width) — confirm against the source data.
let x: [[Double]] = [
    [6.8, 6.2, 5.9, 5.9, 5.7, 7.7, 4.5, 5.8, 5.0, 6.3, 5.1, 4.3, 5.7, 4.9, 7.0],
    [3.0, 3.4, 3.2, 3.0, 2.6, 3.0, 2.3, 2.7, 2.3, 2.5, 3.8, 3.0, 3.8, 3.0, 3.2],
    [5.5, 5.4, 4.8, 5.1, 3.5, 6.1, 1.3, 4.1, 3.3, 4.9, 1.9, 1.1, 1.7, 1.4, 4.7],
    [2.1, 2.3, 1.8, 1.8, 1.0, 2.3, 0.3, 1.0, 1.0, 1.5, 0.4, 0.1, 0.3, 0.2, 1.4],
]
// Class label of each sample, in the same order as the rows above.
// NOTE(review): presumably 1 = setosa, 2 = versicolor, 3 = virginica — confirm.
let y: [Int] = [3, 3, 2, 3, 2, 3, 1, 2, 2, 2, 1, 1, 1, 1, 2]
The impurity criterion (Gini impurity) function:
/// Gini impurity of a set of class labels: `1 - Σ p_k²`, where `p_k` is the fraction
/// of labels equal to class `k`.
///
/// - Parameter y: Class labels.
/// - Returns: 0 when all labels are identical, approaching 1 as classes become evenly mixed.
///   An empty array is treated as pure and returns 0. The empty sum would otherwise
///   yield 1, wrongly marking an empty set as maximally impure.
func giniImpurity(y:[Int]) -> Double
{
    if y.isEmpty
    {
        return 0
    }

    // Count each label with a native dictionary, avoiding NSCountedSet's AnyObject bridging.
    var counts = [Int: Int]()
    for label in y
    {
        counts[label] = (counts[label] ?? 0) + 1
    }

    let total = Double(y.count)
    var sumOfSquares = 0.0
    for count in counts.values
    {
        let p = Double(count) / total
        sumOfSquares += p * p
    }
    return 1 - sumOfSquares
}
Iterating through x to find the best split. I am not sure about the way I sort and iterate through the values...
/// Finds the threshold on one feature that gives the largest decrease in Gini impurity.
///
/// Every candidate split has the form `{x <= t}` vs `{x > t}` for an observed value `t`.
/// This is exactly how `buildTree` partitions the samples, so the reported decrease is
/// the one the tree actually obtains.
///
/// - Parameters:
///   - x: Values of a single feature, one per sample.
///   - y: Class labels, parallel to `x` (same length).
/// - Returns:
///   - `bestVal`: the threshold `t`.
///   - `maxDelta`: the parent impurity minus the size-weighted impurity of the two sides.
///
///   `maxDelta` is 0 when no threshold splits the samples into two non-empty, less impure
///   groups (this includes empty input). `bestVal` is then not a meaningful split.
func findBestSplit(x:[Double], y:[Int]) -> (bestVal: Double, maxDelta: Double)
{
    precondition(x.count == y.count, "findBestSplit: x and y must have the same length")
    if x.isEmpty
    {
        return (bestVal: 0, maxDelta: 0)
    }

    // Sample indices ordered by decreasing feature value; reorder x and y to match.
    let order = x.indices.sort { x[$0] > x[$1] }
    let xSorted = order.map { x[$0] }
    let ySorted = order.map { y[$0] }

    let parentGini = giniImpurity(y)
    let total = Double(y.count)
    var bestDelta = 0.0
    var bestSplit = 0

    // Split point i separates the i largest values (ySorted[0..<i]) from the rest.
    // i == 0 would leave one side empty; round-off could make it look like a gain.
    // A split between two equal values cannot be expressed as a threshold.
    // Both kinds of split point are therefore skipped.
    for i in 1..<ySorted.count where xSorted[i - 1] != xSorted[i]
    {
        let upper = Array(ySorted[0..<i])
        let lower = Array(ySorted[i..<ySorted.count])
        let childGini = (giniImpurity(upper) * Double(upper.count)
            + giniImpurity(lower) * Double(lower.count)) / total
        let delta = parentGini - childGini
        if delta > bestDelta
        {
            bestDelta = delta
            bestSplit = i
        }
    }

    // xSorted[bestSplit] is the largest value on the lower side, so `x <= bestVal`
    // reproduces the chosen partition.
    return (bestVal: xSorted[bestSplit], maxDelta: bestDelta)
}
And finally the tree building:
/// Returns the most frequent label in `y`.
///
/// Ties are broken toward the smaller label, so the result does not depend on
/// dictionary enumeration order.
/// - Precondition: `y` is not empty.
private func majorityLabel(y: [Int]) -> Int
{
    var counts = [Int: Int]()
    for label in y
    {
        counts[label] = (counts[label] ?? 0) + 1
    }
    guard let best = counts.maxElement({ a, b in a.1 < b.1 || (a.1 == b.1 && a.0 > b.0) }) else
    {
        preconditionFailure("majorityLabel requires at least one label")
    }
    return best.0
}

/// Recursively builds a binary classification tree using greedy Gini-impurity splits.
///
/// - Parameters:
///   - x: Feature matrix stored feature-major: `x[col][row]` is feature `col` of sample `row`.
///     Every column must have the same length as `y`.
///   - y: Class label of each sample.
///   - impurityThreshold: A node whose labels have a Gini impurity below this value becomes
///     a terminal node predicting the majority label. Defaults to 0.1, the value that was
///     previously hard-coded.
/// - Returns: The root node of the (sub)tree. Split nodes send `x[variable] <= value` left.
/// - Precondition: `y` is not empty.
func buildTree(x: [[Double]], y:[Int], impurityThreshold: Double = 0.1) -> Node
{
    precondition(!y.isEmpty, "buildTree requires at least one sample")

    // Pure enough already (checked for every subtree, including the root).
    if giniImpurity(y) < impurityThreshold
    {
        return Node(result: majorityLabel(y))
    }

    // Choose the column whose best threshold gives the largest impurity decrease.
    var bestVar = 0
    var bestGini = 0.0
    var bestVal = 0.0
    for col in 0..<x.count
    {
        let res = findBestSplit(x[col], y: y)
        if res.maxDelta > bestGini
        {
            bestVar = col
            bestGini = res.maxDelta
            bestVal = res.bestVal
        }
    }

    // No feature reduces impurity, e.g. identical feature vectors with different labels.
    // Splitting cannot help, so stop instead of recursing on the same data forever.
    if bestGini <= 0
    {
        return Node(result: majorityLabel(y))
    }

    // Partition the rows: `<= bestVal` goes left, `> bestVal` goes right.
    let splitColumn = x[bestVar]
    let leftIndices = splitColumn.indices.filter { splitColumn[$0] <= bestVal }
    let rightIndices = splitColumn.indices.filter { splitColumn[$0] > bestVal }

    // A one-sided partition would recurse on identical data; treat it as a leaf.
    if leftIndices.isEmpty || rightIndices.isEmpty
    {
        return Node(result: majorityLabel(y))
    }

    let leftX = x.map { (col: [Double]) -> [Double] in leftIndices.map { col[$0] } }
    let rightX = x.map { (col: [Double]) -> [Double] in rightIndices.map { col[$0] } }
    let leftY = leftIndices.map { y[$0] }
    let rightY = rightIndices.map { y[$0] }

    let node = Node(value: bestVal, variable: bestVar)
    node.addLeftChild(buildTree(leftX, y: leftY, impurityThreshold: impurityThreshold))
    node.addRightChild(buildTree(rightX, y: rightY, impurityThreshold: impurityThreshold))
    return node
}
// Train a single tree on the sample data above; `root` is the tree's root node.
let root: Node = buildTree(x, y: y)
I would love to have feedback on this, either on the Swift style or on the correctness of the algorithm.