Kohonen Map in ActionScript 3

Kohonen Map in ActionScript 3

The image above is the output of a Kohonen map (also called a self-organizing map), a simple type of neural network used mostly for sorting and grouping large data sets. I am fairly happy with the results.

This is something I have wanted to do for about four years, ever since a project I worked on back in 2008. Click the image to launch the app. Once it has launched, click “PLAY/PAUSE” to begin, and click “RESET” to restart the sorting process. To change the dimensions of the color grid, change the value in the “grid size” input field and then click “RESET”; the new size takes effect on reset. Be careful: anything above 100 (i.e., 10,000 squares) will begin to run slowly.

To create the app, I started with Processing source code I found at jjguy.com. Translating the code to ActionScript 3 took about four hours, and tweaking, testing, and modifying it to accommodate Flash-specific functionality took another four. All in all, it was surprisingly easy.

The code follows. There are three parts: Main.as, which contains the UI, global variables, and initialization code; SOM.as, which contains the code for the map; and Node.as, which contains the code for the individual blocks of color.

All of the source code, including the compiled .swf, can be downloaded here.

 

Main.as

package {
	import flash.display.Sprite;
	import flash.display.StageAlign;
	import flash.display.StageScaleMode;
	import flash.events.Event;
	import flash.events.MouseEvent;
	import flash.events.TimerEvent;
	import flash.text.TextField;
	import flash.text.TextFieldType;
	import flash.text.TextFormat;
	import flash.utils.Timer;

	/**
	 * Document class for the Kohonen map demo.
	 *
	 * Builds the control panel (iteration counter, grid-size input,
	 * PLAY/PAUSE and RESET buttons), owns the SOM instance, and, while
	 * playing, feeds it one randomly chosen colour from <code>rgb</code>
	 * on every timer tick until <code>maxIters</code> is reached.
	 */
	[SWF(width=720,height=480,frameRate=32,backgroundColor=0x000000)]
	public class Main extends Sprite {
		// NOTE(review): absolute Windows font path; building on a machine
		// without Arial at this location will fail. Adjust as needed.
		[Embed(
			source='C:/WINDOWS/Fonts/ARIAL.TTF', 
			fontName='ArialEmbed', 
			unicodeRange='U+0020-U+002F,U+0030-U+0039,U+003A-U+0040,U+0041-U+005A,U+005B-U+0060,U+0061-U+007A,U+007B-U+007E', 
			mimeType="application/x-font-truetype", embedAsCFF="false"
		)]
		private static var _arialEmbed:Class;

		/** Training iterations allotted per unit of grid size (maxIters = gridSize * this). */
		private static const ITERS_PER_GRID_UNIT:int = 100;

		internal var _timer:Timer;

		private var isPlaying:Boolean = false;

		private var btnPlayPause:Sprite;
		private var btnReset:Sprite;
		private var txtIterations:TextField;
		private var txtGridSize:TextField;
		private var colorTextBG:int = 0xcccccc;
		private var btnTextFormat:TextFormat = new TextFormat("ArialEmbed",12,0x333333,null,null,null,null,null,"center");
		private var labelTextFormat:TextFormat = new TextFormat("ArialEmbed",12,0xededed,null,null,null,null,null,"right");
		private var inputTextFormat:TextFormat = new TextFormat("ArialEmbed",12,0x000000,null,null,null,null,null,"left");

		private var som:SOM;
		/** Current training iteration (starts at 1). */
		private var iter:int;
		/** Training stops once iter reaches this value. */
		private var maxIters:int = 4000;
		/** Pixel size of the SOM canvas. */
		public var screenW:int=480;
		public var screenH:int=480;
		/** Map dimension in nodes (the map is always gridSize x gridSize). */
		private var gridSize:int = 40;

		/** Training samples: RGB colours with channels in [0, 1]. */
		private var rgb:Array = [
			[1,1,1],
			[0,0,0],
			[1,0,1],
			[1,0,0],
			[0,1,0],
			[0,0,1],
			[1,1,0],
			[0,1,1],
			[1,.4,.4],
			[.25,.25,.25]
		];

		public function Main():void {
			addEventListener(Event.ADDED_TO_STAGE,onAddedToStage);
		}

		/* stage is only available once we are on the display list */
		private function onAddedToStage(e:Event):void {
			removeEventListener(Event.ADDED_TO_STAGE,onAddedToStage);
			stage.scaleMode = StageScaleMode.NO_SCALE;
			stage.align = StageAlign.TOP_LEFT;
			init();
		}

		private function init():void {
			_timer = new Timer(10);
			_timer.addEventListener(TimerEvent.TIMER,onTimer);
			initInterface();
			initMap();
			_timer.start();
		}

		/*	create and populate the UI elements	*/
		private function initInterface():void {
			var txtIterationsLabel:TextField = getTextField("Iterations",10,160,80,20,labelTextFormat);
			addChild(txtIterationsLabel);

			txtIterations = getTextField("0",100,160,50,20,labelTextFormat);
			txtIterations.selectable = false;
			addChild(txtIterations);

			var txtGridSizeLabel:TextField = getTextField("GRID SIZE",10,200,80,20,labelTextFormat);
			addChild(txtGridSizeLabel);

			txtGridSize = getTextField(gridSize.toString(),100,200,80,20,inputTextFormat);
			txtGridSize.background = true;
			txtGridSize.backgroundColor = 0xffffff;
			txtGridSize.border = true;
			txtGridSize.borderColor=0xcccccc;
			txtGridSize.type = TextFieldType.INPUT;
			txtGridSize.restrict = "0-9";
			addChild(txtGridSize);

			var playPauseButton:Sprite = getTextButton("PLAY/PAUSE",20,300,80,20);
			playPauseButton.addEventListener(MouseEvent.CLICK,onPlayPauseClick);
			addChild(playPauseButton);

			var resetButton:Sprite = getTextButton("RESET",120,300,80,20);
			resetButton.addEventListener(MouseEvent.CLICK,onResetClick);
			addChild(resetButton);
		}

		/*	toggle whether timer ticks advance the training	*/
		private function onPlayPauseClick(e:MouseEvent):void {
			isPlaying = !isPlaying;
		}

		/*	rebuild the map at the size in the input field and restart training	*/
		private function onResetClick(e:MouseEvent):void {
			iter = 1;
			gridSize = readGridSize();
			maxIters = gridSize * ITERS_PER_GRID_UNIT;
			som.init(maxIters,gridSize,gridSize);
			updateMap();
		}

		/*
		 * Parse the grid-size input field.
		 * Empty or zero input would build an empty map, so in that case the
		 * current size is kept (and written back into the field). Values are
		 * capped at the canvas size so every node covers at least one pixel;
		 * this also keeps huge inputs from overflowing int.
		 */
		private function readGridSize():int {
			var parsed:Number = parseInt(txtGridSize.text, 10);
			var result:int;
			if (isNaN(parsed) || parsed < 1) {
				result = gridSize;
			} else {
				result = int(Math.min(parsed, Math.min(screenW, screenH)));
			}
			txtGridSize.text = result.toString();
			return result;
		}

		/*	create and initialize an instance of the map 	*/
		private function initMap():void {
			som = new SOM(gridSize,gridSize, 3,screenW,screenH);
			som.x = 240;
			som.y = 0;
			addChild(som);
			iter = 1;
			maxIters = gridSize * ITERS_PER_GRID_UNIT;
			som.init(maxIters,gridSize,gridSize);
			updateMap();
		}

		/*	called on every tick of the timer	*/
		private function onTimer(e:TimerEvent):void {
			if(isPlaying) updateMap();
			e.updateAfterEvent();
		}

		/*	train on one random sample, then render; no-op once maxIters is reached	*/
		private function updateMap():void {
			if (iter >= maxIters) return;
			var t:int = Math.floor(Math.random()*rgb.length);
			som.train(iter, rgb[t]);
			som.render();
			txtIterations.text = iter.toString();
			iter++;
		}

		/*	functions for building interface elements	*/
		private function getTextField(txt:String,x:int,y:int,w:int,h:int,format:TextFormat):TextField {
			var tf:TextField = new TextField();
			tf.x = x;
			tf.y = y;
			tf.width = w;
			tf.height = h;
			tf.embedFonts = true;
			tf.text = txt;
			tf.setTextFormat(format);
			tf.defaultTextFormat = format;
			return tf;
		}

		/*	a grey rectangle with a centred, non-selectable label that acts as a button	*/
		private function getTextButton(txt:String,x:int,y:int,w:int,h:int):Sprite {
			var s:Sprite = new Sprite();
			s.x = x;
			s.y = y;
			s.graphics.lineStyle(1,0x808080,1,true);
			s.graphics.beginFill(colorTextBG,1);
			s.graphics.drawRect(0,0,w,h);
			s.graphics.endFill();
			s.buttonMode=true;
			s.mouseChildren = false;
			s.useHandCursor=true;
			var t:TextField = new TextField();
			t.width=w;
			t.height=h;
			t.selectable = false;
			t.embedFonts = true;
			t.text = txt;
			t.setTextFormat(btnTextFormat);
			t.defaultTextFormat = btnTextFormat;
			t.wordWrap = false;
			t.multiline=false;
			s.addChild(t);
			return s;
		}
	}
}

 

SOM.as

package {
	import flash.display.Bitmap;
	import flash.display.BitmapData;
	import flash.display.Sprite;
	import flash.geom.Point;
	import flash.geom.Rectangle;

	/**
	 * Kohonen self-organizing map, drawn into a fixed-size bitmap.
	 *
	 * Nodes are stored as nodes[row][col], with row in [0, mapHeight) and
	 * col in [0, mapWidth). Each node's x/y hold its (row, col) grid position.
	 * The constructor only creates the canvas; call init() before train()
	 * or render().
	 */
	public class SOM extends Sprite {
		public var mapWidth:int;
		public var mapHeight:int;
		/** Grid of Node objects, indexed nodes[row][col]. */
		public var nodes:Array;
		/** Initial neighbourhood radius in grid cells: (w + h) / 2. */
		public var radius:Number;
		/** Time constant of the exponential decay applied to radius and learning rate. */
		public var timeConstant:Number;
		/** Initial learning rate. */
		public var learnRate:Number = 0.05;
		/** Length of each node's weight vector and of each training sample. */
		public var inputDimension:int;
		private var pixPerNodeW:Number;
		private var pixPerNodeH:Number;

		private var canvasWidth:int;
		private var canvasHeight:int;
		private var canvasData:BitmapData;
		private var canvas:Bitmap;
		/** Scratch rectangle reused by render() instead of allocating one per cell per frame. */
		private var cellRect:Rectangle = new Rectangle();

		/** Learning rate used by the most recent train() call. */
		public var learnDecay:Number;
		/** Neighbourhood radius used by the most recent train() call. */
		public var radiusDecay:Number;

		/**
		 * @param w     map width in nodes
		 * @param h     map height in nodes
		 * @param n     weight-vector length
		 * @param mapW  canvas width in pixels
		 * @param mapH  canvas height in pixels
		 */
		public function SOM(w:int,h:int,n:int,mapW:int,mapH:int):void {
			mapWidth = w;
			mapHeight = h;
			radius = (h + w) / 2;
			inputDimension = n;
			canvasWidth = mapW;
			canvasHeight = mapH;
			canvasData = new BitmapData(canvasWidth,canvasHeight,false,0x000000);
			canvas = new Bitmap(canvasData);
			addChild(canvas);
		}

		/**
		 * (Re)build the map as h rows of w freshly randomised nodes.
		 * @param iterations  total training iterations planned; sets the decay time constant
		 * @param w           map width in nodes
		 * @param h           map height in nodes
		 */
		public function init(iterations:int,w:int,h:int):void {
			mapWidth = w;
			mapHeight = h;
			radius = (h + w) / 2;
			pixPerNodeW = canvasWidth/mapWidth;
			pixPerNodeH = canvasHeight/mapHeight;
			nodes = [];
			for (var row:int = 0; row < mapHeight; row++) {
				var rowNodes:Array = [];
				for (var col:int = 0; col < mapWidth; col++) {
					rowNodes[col] = new Node(inputDimension, row, col);
				}
				nodes[row] = rowNodes;
			}
			// ln(radius) is 0 for a 1x1 map; treat that as "no decay" instead of
			// dividing by zero (which yields NaN when iterations is 0).
			timeConstant = (radius > 1) ? iterations / Math.log(radius) : Number.POSITIVE_INFINITY;
			learnDecay = learnRate;
			radiusDecay = radius;
		}

		/**
		 * Run one training step: find the best-matching node for sample w and
		 * pull every node within the current radius towards w.
		 * @param i  current iteration; drives the exponential decay
		 * @param w  training sample of length inputDimension
		 */
		public function train(i:int,w:Array):void {
			var decay:Number = Math.exp(-(i / timeConstant));
			radiusDecay = radius * decay;
			learnDecay = learnRate * decay;
			if (mapWidth <= 0 || mapHeight <= 0) return; // empty map: nothing to train

			var ndxComposite:int = bestMatch(w);
			var bmu:Node = nodes[ndxComposite >> 16][ndxComposite & 0x0000FFFF];
			// Standard Kohonen neighbourhood: a Gaussian whose width is the
			// decaying radius, h = exp(-d^2 / (2 * sigma(t)^2)).
			var twoSigmaSq:Number = 2 * radiusDecay * radiusDecay;
			for (var a:int = 0; a < mapHeight; a++) {
				for (var b:int = 0; b < mapWidth; b++) {
					var node:Node = nodes[a][b];
					var d:Number = dist(bmu.x, bmu.y, node.x, node.y);
					if (d < radiusDecay) {
						var rate:Number = Math.exp(-(d * d) / twoSigmaSq) * learnDecay;
						for (var k:int = 0; k < inputDimension; k++) {
							node.w[k] += rate * (w[k] - node.w[k]);
						} // for k
					} // if d
				} // for b
			} // for a
		} // train()

		/*	functions used by training method, for calculating node weights and distances	*/

		/** Euclidean distance between grid points (x1, y1) and (x2, y2). */
		public function dist(x1:Number,y1:Number,x2:Number,y2:Number):Number {
			var dx:Number = x2 - x1;
			var dy:Number = y2 - y1;
			return Math.sqrt(dx * dx + dy * dy);
		}

		/** Euclidean distance between the grid positions of two nodes. */
		public function distance(node1:Node, node2:Node):Number {
			return dist(node1.x, node1.y, node2.x, node2.y);
		}

		/**
		 * Find the node whose weights are closest to w.
		 * @return (row << 16) + col of the best match; assumes col < 65536.
		 *         Returns 0 only if the map is empty or every distance is NaN.
		 */
		public function bestMatch(w:Array):int {
			// Start at +Infinity so that the true nearest node is always chosen,
			// even for inputs/weights outside the unit hypercube.
			var minDist:Number = Number.POSITIVE_INFINITY;
			var minIndex:int = 0;
			for (var i:int = 0; i < mapHeight; i++) {
				for (var j:int = 0; j < mapWidth; j++) {
					var tmp:Number = weight_distance(nodes[i][j].w, w);
					if (tmp < minDist) {
						minDist = tmp;
						minIndex = (i << 16) + j;
					}
				} // for j
			} // for i
			return minIndex;
		}

		/**
		 * Euclidean distance between two weight vectors.
		 * NOTE: assumes x and y have equal length; if y is shorter the result is NaN.
		 */
		public function weight_distance(x:Array, y:Array):Number {
			var sum:Number = 0.0;
			for (var i:int = 0; i < x.length; i++) {
				var diff:Number = x[i] - y[i];
				sum += diff * diff;
			}
			return Math.sqrt(sum);
		}

		/**
		 * Draw every node as a solid cell whose colour is its first three
		 * weights scaled to 0..255. Row maps to y and column maps to x. Cell
		 * edges are snapped to whole pixels so fractional cell sizes leave no gaps.
		 */
		public function render():void {
			if (nodes == null) return; // init() has not been called yet
			canvasData.lock(); // defer display updates until all cells are drawn
			for (var row:int = 0; row < mapHeight; row++) {
				var top:int = int(row * pixPerNodeH);
				var bottom:int = int((row + 1) * pixPerNodeH);
				for (var col:int = 0; col < mapWidth; col++) {
					var left:int = int(col * pixPerNodeW);
					var right:int = int((col + 1) * pixPerNodeW);
					var wt:Array = nodes[row][col].w;
					var color:uint = (toChannel(wt[0]) << 16) | (toChannel(wt[1]) << 8) | toChannel(wt[2]);
					cellRect.x = left;
					cellRect.y = top;
					cellRect.width = right - left;
					cellRect.height = bottom - top;
					canvasData.fillRect(cellRect, color);
				} // for col
			} // for row
			canvasData.unlock();
		} // render()

		/** Map a weight in [0, 1] to a 0..255 colour channel (clamped; NaN -> 0). */
		private static function toChannel(v:Number):uint {
			if (!(v > 0)) return 0;
			if (v >= 1) return 255;
			return uint(v * 255);
		}
	}
}

 

Node.as

package {
	/**
	 * One cell of the map: a grid position plus a weight vector.
	 * Every weight starts as a uniform random value in [0.25, 0.75).
	 */
	public class Node {
		/** Grid coordinates as supplied to the constructor. */
		public var x:int;
		public var y:int;
		/** Number of entries in w. */
		public var weightCount:int;
		/** Weight vector of length weightCount. */
		public var w:Array;

		/**
		 * @param n  number of weights to create
		 * @param X  value stored in x
		 * @param Y  value stored in y
		 */
		public function Node(n:int,X:int,Y:int):void {
			weightCount = n;
			x = X;
			y = Y;
			var weights:Array = [];
			var remaining:int = weightCount;
			while (remaining-- > 0) {
				weights[weights.length] = 0.25 + Math.random() * 0.5;
			}
			w = weights;
		}
	}
}