Done the nonlinear classifer master
authorNeil Smith <neil.github@njae.me.uk>
Fri, 30 Sep 2011 15:14:37 +0000 (16:14 +0100)
committerNeil Smith <neil.github@njae.me.uk>
Fri, 30 Sep 2011 15:14:37 +0000 (16:14 +0100)
Gemfile.lock [new file with mode: 0644]
lib/svm/svm.rb

diff --git a/Gemfile.lock b/Gemfile.lock
new file mode 100644 (file)
index 0000000..5c6621e
--- /dev/null
@@ -0,0 +1,25 @@
+GEM
+  remote: http://rubygems.org/
+  specs:
+    diff-lcs (1.1.3)
+    gnuplot (2.3.6)
+    rake (0.9.2)
+    rdoc (3.9.4)
+    rspec (2.6.0)
+      rspec-core (~> 2.6.0)
+      rspec-expectations (~> 2.6.0)
+      rspec-mocks (~> 2.6.0)
+    rspec-core (2.6.4)
+    rspec-expectations (2.6.0)
+      diff-lcs (~> 1.1.2)
+    rspec-mocks (2.6.0)
+
+PLATFORMS
+  ruby
+
+DEPENDENCIES
+  bundler (~> 1.0.0)
+  gnuplot
+  rake
+  rdoc
+  rspec (~> 2.6.0)
index e4a4d7fbd99c02f9b1eb4dd68559d210ce39037b..a681d9f39c57743c4e227d5207107014c9221e28 100644 (file)
@@ -6,6 +6,7 @@ class ClassifiedData
   def initialize(data = Array.new, classification = nil)
     @data = data
     @classification = classification
   def initialize(data = Array.new, classification = nil)
     @data = data
     @classification = classification
+    self
   end
 end
 
   end
 end
 
@@ -18,6 +19,15 @@ def load_data(filename, only_numbers = false)
   rows
 end
 
   rows
 end
 
+def matches_to_numeric(rows)
+  rows.map do |row|
+    d = row.data
+    ClassifiedData.new([d[0].to_f, yes_no(d[1]), yes_no(d[2]), 
+             d[5].to_f, yes_no(d[6]), yes_no(d[7]), 
+             match_count(d[3], d[8])],
+            row.classification)
+  end
+end
 
 def plot_age_matches(rows)
   Gnuplot.open do |gp|
 
 def plot_age_matches(rows)
   Gnuplot.open do |gp|
@@ -77,7 +87,7 @@ def dot_product_classify(point, averages)
   end
 end
 
   end
 end
 
-def yesno(v)
+def yes_no(v)
   if v == 'yes' then 1
   elsif v == 'no' then -1
   else 0
   if v == 'yes' then 1
   elsif v == 'no' then -1
   else 0
@@ -92,4 +102,71 @@ def miles_distance(a1, a2)
   0
 end
 
   0
 end
 
+def scale_data_set(rows)
+  # Could be many rows, so still make one pass through the data rather than 
+  # using Array#max and #min for each data field
+  lows = Array.new(rows[0].data.length, 999999999.0)
+  highs = Array.new(rows[0].data.length, -999999999.0)
+  rows.each do |row|
+    data = row.data
+    (data.length).times do |i|
+      lows[i]  = data[i] if data[i] < lows[i]
+      highs[i] = data[i] if data[i] > highs[i]
+    end
+  end
+  
+  scale_data = Proc.new do |row|
+    row.zip(lows, highs).map {|d| (d[0] - d[1]) / (d[2] - d[1]) }
+  end
+  
+  new_rows = rows.map do |row|
+    ClassifiedData.new(scale_data.call(row.data), row.classification)
+  end
+  
+  return new_rows, scale_data
+end
+
+# Usage:
+# numeric_matches = matches_to_numeric matches
+# scaled_set, scale_f = scale_data_set numeric_matches
+# averages = linear_train scaled_set
+# dot_product_classify(scale_f.call(numeric_matches[11].data), averages)
+
+def radial_basis(v1, v2, gamma = 20)
+  len = Math.sqrt((v1.zip v2).map {|c| (c[0] - c[1]) ** 2 }.reduce(:+))
+  Math.exp(-gamma * len)
+end
+
+def nonlinear_classify(point, rows, offset, gamma = 10)
+  match_sum = no_match_sum = 0.0
+  match_count = no_match_count = 0
+  rows.each do |row|
+    if row.classification == 1
+      match_sum += radial_basis(point, row.data, gamma)
+      match_count += 1
+    else
+      no_match_sum += radial_basis(point, row.data, gamma)
+      no_match_count += 1
+    end
+  end
+  y = match_sum / match_count - no_match_sum / no_match_count + offset
+  if y < 0
+    0
+  else
+    1
+  end
+end
 
 
+def nonlinear_offset(rows, gamma = 10)
+  matches = [] ; no_matches = []
+  rows.each do |r|
+    if r.classification == 1
+      matches << r.data
+    else
+      no_matches << r.data
+    end
+  end
+  sum_matches = matches.map {|v1| matches.map {|v2| radial_basis(v1, v2, gamma)}.reduce(:+)}.reduce(:+)
+  sum_no_matches = no_matches.map {|v1| no_matches.map {|v2| radial_basis(v1, v2, gamma)}.reduce(:+)}.reduce(:+)
+  sum_matches / matches.length ** 2 - sum_no_matches / no_matches.length ** 2
+end