Initial commit
[pci.git] / lib / svm / svm.rb
1 require 'gnuplot'
2
# One labelled row: a list of field values plus the class label attached to it.
class ClassifiedData
  attr_accessor :data, :classification

  # @param data [Array] the row's field values; each instance gets its own
  #   fresh empty array when omitted
  # @param classification [Object, nil] the row's label, nil when not given
  def initialize(data = [], classification = nil)
    self.data = data
    self.classification = classification
  end
end
11
# Reads a comma-separated file into ClassifiedData rows.
#
# Each non-blank line is split on commas. Every field except the last becomes
# the row's data, and the last field is parsed with String#to_i as the
# row's classification.
#
# @param filename [String] path of the file to read
# @param only_numbers [Boolean] when true, data fields are converted with
#   String#to_f; otherwise they are kept as strings
# @return [Array<ClassifiedData>] one row per non-blank line, in file order
def load_data(filename, only_numbers = false)
  # File.foreach (unlike IO.foreach) never treats a leading "|" in the
  # filename as a command to execute.
  File.foreach(filename).each_with_object([]) do |line, rows|
    # Skip blank lines (e.g. a trailing newline at end of file). Otherwise
    # they turn into rows with no data and a spurious classification of 0.
    next if line.strip.empty?

    # Split once; the last field is the label, the rest are the data.
    *fields, label = line.chomp.split(',')
    data = only_numbers ? fields.map(&:to_f) : fields
    rows << ClassifiedData.new(data, label.to_i)
  end
end
20
21
# Opens gnuplot and draws the rows as two point series with no legend titles.
# The first series holds the data of rows classified 1, the second the data of
# rows classified 0. Rows with any other classification are not plotted.
#
# NOTE(review): each row's data array becomes one point, presumably
# data[0] on x and data[1] on y. Confirm that the "woman"/"man" axis labels
# match the column order of the input file.
#
# @param rows [Array<ClassifiedData>] rows to plot
def plot_age_matches(rows)
  Gnuplot.open do |gp|
    Gnuplot::Plot.new(gp) do |plot|
      plot.title "Ages of matches"
      plot.ylabel "man"
      plot.xlabel "woman"

      points_with_label = ->(label) { rows.select { |r| r.classification == label }.map(&:data) }
      matches = points_with_label.call(1)
      non_matches = points_with_label.call(0)

      # Build one DataSet per series, matches first, then non-matches.
      plot.data = [matches, non_matches].map do |series|
        Gnuplot::DataSet.new([series]) do |ds|
          ds.with = "points"
          ds.notitle
        end
      end
    end
  end
end
43
44
# Computes the element-wise mean data vector for each classification label.
#
# The vector length for a label is fixed by the first row seen with that label.
# Sums start from 0.0, so the averages are Floats even for integer data.
#
# @param rows [Array<ClassifiedData>] rows with numeric data arrays
# @return [Hash{Object => Array<Float>}] label => mean vector, with labels in
#   first-seen order; an empty hash for no rows
def linear_train(rows)
  totals = {}
  tally = Hash.new(0)

  rows.each do |row|
    label = row.classification
    acc = (totals[label] ||= Array.new(row.data.length, 0.0))
    row.data.each_with_index { |value, idx| acc[idx] += value }
    tally[label] += 1
  end

  totals.each_with_object({}) do |(label, acc), result|
    result[label] = acc.map { |total| total / tally[label] }
  end
end
65
# Returns the dot product (sum of element-wise products) of two vectors.
#
# @param v1 [Array<Numeric>] first vector
# @param v2 [Array<Numeric>] second vector, same length as v1
# @return [Numeric] the dot product; 0 for two empty vectors
# @raise [ArgumentError] if the vectors have different lengths
def dot_product(v1, v2)
  # zip silently truncates when v1 is shorter and pads with nil when it is
  # longer, so reject mismatched lengths explicitly.
  unless v1.length == v2.length
    raise ArgumentError, "vector length mismatch: #{v1.length} vs #{v2.length}"
  end

  v1.zip(v2).sum { |a, b| a * b }
end
69
# Assigns a point to class 0 or 1, whichever average vector it is closer to
# in Euclidean distance, using only dot products.
#
# score = p.a0 - p.a1 + (a1.a1 - a0.a0) / 2 is positive exactly when p is
# nearer to a0. A score of exactly 0 (equidistant) yields 1. With all-integer
# inputs the "/ 2" is integer division.
#
# @param point [Array<Numeric>] vector to classify
# @param averages [Hash] must provide mean vectors under keys 0 and 1
#   (presumably the output of linear_train)
# @return [Integer] 0 or 1
def dot_product_classify(point, averages)
  avg0 = averages[0]
  avg1 = averages[1]
  offset = (dot_product(avg1, avg1) - dot_product(avg0, avg0)) / 2
  score = dot_product(point, avg0) - dot_product(point, avg1) + offset
  score > 0 ? 0 : 1
end
79
# Encodes a yes/no answer numerically.
#
# @param v [Object] the answer, normally a String
# @return [Integer] 1 for 'yes', -1 for 'no', 0 for anything else
def yesno(v)
  return 1 if v == 'yes'
  return -1 if v == 'no'

  0
end
86
# Counts how many distinct interests two colon-separated interest lists share.
#
# @param interests1 [String] interests separated by ':', e.g. "skiing:knitting"
# @param interests2 [String] second list in the same format
# @return [Integer] number of distinct interests present in both lists
def match_count(interests1, interests2)
  shared = interests1.split(':') & interests2.split(':')
  shared.size
end
90
# Placeholder for the distance in miles between two locations.
# Not implemented: the arguments are ignored and every pair is reported as
# 0 miles apart.
#
# @return [Integer] always 0
def miles_distance(a1, a2)
  0 # TODO: replace with a real distance computation
end
94
95