From d1e95cca1a3c4816557d5e57b88b14619e5db8f2 Mon Sep 17 00:00:00 2001 From: Neil Smith Date: Fri, 30 Sep 2011 16:14:37 +0100 Subject: [PATCH] Done the nonlinear classifer --- Gemfile.lock | 25 ++++++++++++++++ lib/svm/svm.rb | 79 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 Gemfile.lock diff --git a/Gemfile.lock b/Gemfile.lock new file mode 100644 index 0000000..5c6621e --- /dev/null +++ b/Gemfile.lock @@ -0,0 +1,25 @@ +GEM + remote: http://rubygems.org/ + specs: + diff-lcs (1.1.3) + gnuplot (2.3.6) + rake (0.9.2) + rdoc (3.9.4) + rspec (2.6.0) + rspec-core (~> 2.6.0) + rspec-expectations (~> 2.6.0) + rspec-mocks (~> 2.6.0) + rspec-core (2.6.4) + rspec-expectations (2.6.0) + diff-lcs (~> 1.1.2) + rspec-mocks (2.6.0) + +PLATFORMS + ruby + +DEPENDENCIES + bundler (~> 1.0.0) + gnuplot + rake + rdoc + rspec (~> 2.6.0) diff --git a/lib/svm/svm.rb b/lib/svm/svm.rb index e4a4d7f..a681d9f 100644 --- a/lib/svm/svm.rb +++ b/lib/svm/svm.rb @@ -6,6 +6,7 @@ class ClassifiedData def initialize(data = Array.new, classification = nil) @data = data @classification = classification + self end end @@ -18,6 +19,15 @@ def load_data(filename, only_numbers = false) rows end +def matches_to_numeric(rows) + rows.map do |row| + d = row.data + ClassifiedData.new([d[0].to_f, yes_no(d[1]), yes_no(d[2]), + d[5].to_f, yes_no(d[6]), yes_no(d[7]), + match_count(d[3], d[8])], + row.classification) + end +end def plot_age_matches(rows) Gnuplot.open do |gp| @@ -77,7 +87,7 @@ def dot_product_classify(point, averages) end end -def yesno(v) +def yes_no(v) if v == 'yes' then 1 elsif v == 'no' then -1 else 0 @@ -92,4 +102,71 @@ def miles_distance(a1, a2) 0 end +def scale_data_set(rows) + # Could be many rows, so still make one pass through the data rather than + # using Array#max and #min for each data field + lows = Array.new(rows[0].data.length, 999999999.0) + highs = Array.new(rows[0].data.length, -999999999.0) + rows.each do |row| + data = row.data + (data.length).times do |i| + lows[i] = data[i] if data[i] < lows[i] + highs[i] = data[i] if data[i] > highs[i] + end + end + + scale_data = Proc.new do |row| + row.zip(lows, highs).map {|d| (d[0] - d[1]) / (d[2] - d[1]) } + end + + new_rows = rows.map do |row| + ClassifiedData.new(scale_data.call(row.data), row.classification) + end + + return new_rows, scale_data +end + +# Usage: +# numeric_matches = matches_to_numeric matches +# scaled_set, scale_f = scale_data_set numeric_matches +# averages = linear_train scaled_set +# dot_product_classify(scale_f.call(numeric_matches[11].data), averages) + +def radial_basis(v1, v2, gamma = 20) + len = Math.sqrt((v1.zip v2).map {|c| (c[0] - c[1]) ** 2 }.reduce(:+)) + Math.exp(-gamma * len) +end + +def nonlinear_classify(point, rows, offset, gamma = 10) + match_sum = no_match_sum = 0.0 + match_count = no_match_count = 0 + rows.each do |row| + if row.classification == 1 + match_sum += radial_basis(point, row.data, gamma) + match_count += 1 + else + no_match_sum += radial_basis(point, row.data, gamma) + no_match_count += 1 + end + end + y = match_sum / match_count - no_match_sum / no_match_count + offset + if y < 0 + 0 + else + 1 + end +end +def nonlinear_offset(rows, gamma = 10) + matches = [] ; no_matches = [] + rows.each do |r| + if r.classification == 1 + matches << r.data + else + no_matches << r.data + end + end + sum_matches = matches.map {|v1| matches.map {|v2| radial_basis(v1, v2, gamma)}.reduce(:+)}.reduce(:+) + sum_no_matches = no_matches.map {|v1| no_matches.map {|v2| radial_basis(v1, v2, gamma)}.reduce(:+)}.reduce(:+) + sum_matches / matches.length ** 2 - sum_no_matches / no_matches.length ** 2 +end -- 2.34.1