Done the nonlinear classifier
[pci.git] / lib / svm / svm.rb
1 require 'gnuplot'
2
# One example from a data set: a list of field values together with the
# class label assigned to it.
class ClassifiedData
  attr_accessor :data, :classification

  # data           - field values for this example (a fresh empty Array when
  #                  omitted).
  # classification - class label for this example (nil when omitted).
  def initialize(data = [], classification = nil)
    @data, @classification = data, classification
  end
end
12
# Reads a comma-separated file into an Array of ClassifiedData rows.
#
# For each line, every field except the last becomes the row's data and the
# last field, converted with to_i, becomes the row's classification.
#
# filename     - path of the file to read.
# only_numbers - when true, data fields are converted with to_f; otherwise
#                they are kept as Strings.
#
# Blank (empty or whitespace-only) lines are skipped so that a stray empty
# line does not turn into a row with no data and classification 0.
#
# Returns the Array of rows in file order.
def load_data(filename, only_numbers = false)
  rows = []
  # File.foreach, unlike IO.foreach, never interprets a leading "|" in the
  # name as a command to run.
  File.foreach(filename) do |line|
    next if line.strip.empty?

    # Split once; the last field is the label, the rest are the features.
    *features, label = line.chomp.split(',')
    features = features.map(&:to_f) if only_numbers
    rows << ClassifiedData.new(features, label.to_i)
  end
  rows
end
21
# Converts rows of raw fields into rows of numeric features.
#
# For each row the new data is, in order:
#   fields 0 and 5 converted with to_f,
#   fields 1, 2, 6 and 7 mapped through yes_no,
#   the match_count of fields 3 and 8.
# Fields 4 and 9 are not used. The classification is carried over unchanged.
#
# rows - Array of objects responding to data and classification.
#
# Returns a new Array of ClassifiedData.
def matches_to_numeric(rows)
  rows.map do |row|
    fields = row.data
    numeric = [
      fields[0].to_f, yes_no(fields[1]), yes_no(fields[2]),
      fields[5].to_f, yes_no(fields[6]), yes_no(fields[7]),
      match_count(fields[3], fields[8])
    ]
    ClassifiedData.new(numeric, row.classification)
  end
end
31
# Plots the rows with Gnuplot as two point series: first the data of rows
# whose classification is 1, then the data of rows whose classification is 0.
# Rows with any other classification are not plotted.
#
# rows - Array of objects responding to data and classification.
#        NOTE(review): presumably each data is a pair of ages matching the
#        axis labels below - confirm against the caller.
def plot_age_matches(rows)
  Gnuplot.open do |gp|
    Gnuplot::Plot.new(gp) do |plot|
      plot.title "Ages of matches"
      plot.ylabel "man"
      plot.xlabel "woman"

      # Matches (1) first, then non-matches (0).
      points_by_class = [1, 0].map do |label|
        rows.select { |r| r.classification == label }.map { |r| r.data }
      end

      plot.data = points_by_class.map do |points|
        Gnuplot::DataSet.new([points]) do |ds|
          ds.with = "points"
          ds.notitle
        end
      end
    end
  end
end
53
54
# Computes the per-class average of every data column.
#
# rows - Array of objects responding to data (numeric Array) and
#        classification.
#
# Returns a Hash mapping each classification seen in rows to an Array with
# the mean of each column over the rows of that class.
def linear_train(rows)
  totals = {}
  tallies = Hash.new(0)

  rows.each do |row|
    label = row.classification
    # The accumulator is sized from the first row seen for this class.
    column_totals = (totals[label] ||= Array.new(row.data.length, 0.0))
    row.data.each_with_index { |value, idx| column_totals[idx] += value }
    tallies[label] += 1
  end

  totals.each_with_object({}) do |(label, column_totals), means|
    means[label] = column_totals.map { |total| total / tallies[label] }
  end
end
75
# Dot product of two vectors: the sum of the pairwise products of their
# elements. Returns nil when v1 is empty.
def dot_product(v1, v2)
  products = v1.zip(v2).map { |a, b| a * b }
  products.reduce(:+)
end
79
# Classifies a point as 0 or 1 using the two class averages.
#
# point    - numeric Array.
# averages - Hash-like object indexed by 0 and 1, each giving the average
#            vector of that class (as returned by linear_train).
#
# Computes y = point . averages[0] - point . averages[1] + b, where
# b = (averages[1] . averages[1] - averages[0] . averages[0]) / 2.
#
# Returns 0 when y > 0, otherwise 1.
def dot_product_classify(point, averages)
  norm_difference = dot_product(averages[1], averages[1]) - dot_product(averages[0], averages[0])
  # Divide by a Float so the bias is not truncated when the averages are
  # Integers.
  b = norm_difference / 2.0
  y = dot_product(point, averages[0]) - dot_product(point, averages[1]) + b
  y > 0 ? 0 : 1
end
89
# Maps 'yes' to 1, 'no' to -1 and anything else to 0.
def yes_no(v)
  return 1 if v == 'yes'
  return -1 if v == 'no'

  0
end
96
# Counts the distinct entries that two ':'-separated lists have in common.
def match_count(interests1, interests2)
  first = interests1.split(':')
  second = interests2.split(':')
  (first & second).length
end
100
# Placeholder: both arguments are ignored and the result is always 0.
def miles_distance(a1, a2)
  0
end
104
# Rescales every data column to the range 0.0..1.0 using that column's
# minimum and maximum over all rows.
#
# rows - Array of objects responding to data (numeric Array) and
#        classification. The column count is taken from the first row.
#
# Scaling always happens in floating point, so Integer columns are not
# truncated by integer division. A column whose values are all equal has no
# range to scale by and is mapped to 0.0.
#
# Returns two values: an Array of new ClassifiedData with the scaled data,
# and a callable that applies the same scaling to any single data Array
# (invoke it with .call). For empty input the Array is empty and the
# callable returns its argument unchanged.
def scale_data_set(rows)
  return [], ->(row) { row } if rows.empty?

  # Could be many rows, so still make one pass through the data rather than
  # using Array#max and #min for each data field
  width = rows.first.data.length
  lows = Array.new(width, Float::INFINITY)
  highs = Array.new(width, -Float::INFINITY)
  rows.each do |row|
    row.data.each_with_index do |value, i|
      lows[i] = value if value < lows[i]
      highs[i] = value if value > highs[i]
    end
  end

  scale_data = lambda do |row|
    row.zip(lows, highs).map do |value, low, high|
      range = high - low
      # A zero range would otherwise give NaN or ZeroDivisionError.
      range.zero? ? 0.0 : (value - low).to_f / range
    end
  end

  new_rows = rows.map do |row|
    ClassifiedData.new(scale_data.call(row.data), row.classification)
  end

  return new_rows, scale_data
end
128
129 # Usage:
130 # numeric_matches = matches_to_numeric matches
131 # scaled_set, scale_f = scale_data_set numeric_matches
132 # averages = linear_train scaled_set
133 # dot_product_classify(scale_f.call(numeric_matches[11].data), averages)
134
# Radial similarity of two vectors: exp(-gamma * d), where d is the
# Euclidean distance between v1 and v2 (the square root of the summed
# squared differences, not the squared distance).
#
# Returns 1.0 for identical vectors and decays towards 0.0 as they move
# apart.
def radial_basis(v1, v2, gamma = 20)
  squared_diffs = v1.zip(v2).map { |a, b| (a - b)**2 }
  distance = Math.sqrt(squared_diffs.reduce(:+))
  Math.exp(-gamma * distance)
end
139
# Classifies a point as 0 or 1 by comparing its average radial_basis
# similarity to the rows of each class.
#
# point  - numeric Array to classify.
# rows   - Array of objects responding to data and classification; rows with
#          classification 1 form one class, all other rows the other class.
# offset - bias term, as returned by nonlinear_offset for the same rows and
#          gamma.
# gamma  - passed through to radial_basis.
#
# Computes y = avg(similarity to other rows) - avg(similarity to class-1
# rows) + offset, the same sign convention as dot_product_classify, so the
# value returned by nonlinear_offset is applied in the direction it was
# computed for. A class with no rows contributes an average of 0.0 instead
# of a NaN.
#
# Returns 0 when y > 0, otherwise 1.
def nonlinear_classify(point, rows, offset, gamma = 10)
  match_sum = no_match_sum = 0.0
  matches = no_matches = 0

  rows.each do |row|
    similarity = radial_basis(point, row.data, gamma)
    if row.classification == 1
      match_sum += similarity
      matches += 1
    else
      no_match_sum += similarity
      no_matches += 1
    end
  end

  match_avg = matches.zero? ? 0.0 : match_sum / matches
  no_match_avg = no_matches.zero? ? 0.0 : no_match_sum / no_matches

  y = no_match_avg - match_avg + offset
  y > 0 ? 0 : 1
end
159
# Computes the bias term for nonlinear_classify.
#
# rows  - Array of objects responding to data and classification; rows with
#         classification 1 form one class, all other rows the other class.
# gamma - passed through to radial_basis.
#
# For each class, averages radial_basis over every ordered pair of its rows
# (including each row with itself), then returns the class-1 average minus
# the other class's average. A class with no rows contributes 0.0 rather
# than raising.
#
# NOTE(review): dot_product_classify halves the analogous difference of
# squared norms; confirm whether this value is meant to be halved as well.
def nonlinear_offset(rows, gamma = 10)
  match_rows, no_match_rows = rows.partition { |r| r.classification == 1 }
  matches = match_rows.map(&:data)
  no_matches = no_match_rows.map(&:data)

  mean_self_similarity = lambda do |vectors|
    next 0.0 if vectors.empty?

    total = vectors.reduce(0.0) do |outer, v1|
      outer + vectors.reduce(0.0) { |inner, v2| inner + radial_basis(v1, v2, gamma) }
    end
    total / vectors.length**2
  end

  mean_self_similarity.call(matches) - mean_self_similarity.call(no_matches)
end