In [31]:
import itertools
import collections
from functools import lru_cache

In [2]:
def value_of(elements):
    """Return the summed 'value' of an iterable of item dicts (0 when empty)."""
    return sum(map(lambda item: item['value'], elements))
 
def weight_of(elements):
    """Return the summed 'weight' of an iterable of item dicts (0 when empty)."""
    return sum(map(lambda item: item['weight'], elements))

In [32]:
# Immutable item record. Being a tuple it is hashable, so a collection of items
# can be a frozenset and thus an lru_cache key for the recursive solvers.
Element = collections.namedtuple('Element', 'weight value')

In [3]:
def dp_count(elements, weight_limit):
    """Bottom-up 0/1 knapsack that maximises the *number* of items taken.

    Args:
        elements: sequence of dicts with an integer 'weight' key.
        weight_limit: maximum total weight (an exact fit is allowed).

    Returns:
        (best_count, count_table, back_refs) where
          count_table[i, w] is the most items choosable from the first i
          elements with total weight <= w (0 <= i <= len(elements),
          0 <= w <= weight_limit), and
          back_refs[i, w] (i >= 1) is the cell count_table[i, w] was derived
          from: (i-1, w) if element i-1 was skipped, (i-1, w - weight) if it
          was taken. Ties are resolved in favour of taking the element.
    """
    count_table = {(0, cap): 0 for cap in range(weight_limit + 1)}
    back_refs = {}

    for row, item in enumerate(elements, start=1):
        item_weight = item['weight']
        for cap in range(weight_limit + 1):
            skip = count_table[row - 1, cap]
            if item_weight > cap:
                # Item cannot fit at this capacity: inherit the row above.
                count_table[row, cap] = skip
                back_refs[row, cap] = (row - 1, cap)
                continue
            take = count_table[row - 1, cap - item_weight] + 1
            count_table[row, cap] = max(skip, take)
            back_refs[row, cap] = (row - 1, cap) if skip > take else (row - 1, cap - item_weight)

    return count_table[len(elements), weight_limit], count_table, back_refs

In [37]:
@lru_cache(maxsize=None)
def recursive_count(elements, weight_limit):
    """Top-down, memoised 0/1 knapsack that maximises the number of items taken.

    Args:
        elements: frozenset of items with a ``weight`` attribute (e.g. Element).
            It must be hashable because results are cached by lru_cache.
        weight_limit: remaining capacity (an exact fit is allowed).

    Returns:
        A list of chosen items whose total weight is <= weight_limit.
        NOTE: lru_cache hands back the same list object for repeated calls
        with equal arguments, so callers must not mutate the result.
    """
    if not elements:
        return []
    # next(iter(...)) yields the same pivot as list(elements)[0] without
    # copying the whole set on every (memoised) call.
    this_element = next(iter(elements))
    other_elements = elements - {this_element}
    if this_element.weight > weight_limit:
        return recursive_count(other_elements, weight_limit)
    with_this = recursive_count(other_elements, weight_limit - this_element.weight)
    without_this = recursive_count(other_elements, weight_limit)
    # Keep the pivot only if it strictly increases the count.
    if len(with_this) + 1 > len(without_this):
        return [this_element] + with_this
    return without_this

In [5]:
def dp_value(elements, weight_limit):
    """Bottom-up 0/1 knapsack that maximises the total item value.

    Args:
        elements: sequence of dicts with integer 'weight' and 'value' keys.
        weight_limit: maximum total weight (an exact fit is allowed).

    Returns:
        (best_value, value_table, back_refs) where value_table[i, w] is the
        best value obtainable from the first i elements within capacity w,
        and back_refs[i, w] (i >= 1) is the cell it came from: (i-1, w) if
        element i-1 was skipped, (i-1, w - weight) if it was taken. Ties are
        resolved in favour of taking the element.
    """
    value_table = dict.fromkeys(((0, cap) for cap in range(weight_limit + 1)), 0)
    back_refs = {}

    for row, item in enumerate(elements, start=1):
        item_weight = item['weight']
        for cap in range(weight_limit + 1):
            skip = value_table[row - 1, cap]
            if item_weight > cap:
                # Too heavy for this capacity: carry the previous row's answer.
                value_table[row, cap] = skip
                back_refs[row, cap] = (row - 1, cap)
            else:
                take = value_table[row - 1, cap - item_weight] + item['value']
                value_table[row, cap] = max(skip, take)
                back_refs[row, cap] = (row - 1, cap) if skip > take else (row - 1, cap - item_weight)

    return value_table[len(elements), weight_limit], value_table, back_refs

In [19]:
# Scratch cells: try out the frozenset operations that recursive_count /
# recursive_valuefs rely on (picking an element, removing one).
fs = frozenset([1, 2, 3])
fs

frozenset({1, 2, 3})

In [21]:
# First element in the frozenset's iteration order (how the recursive solvers pick a pivot).
list(fs)[0]

1

In [23]:
# Removing an element returns a new frozenset; fs itself is unchanged.
fs.difference(frozenset([1]))

frozenset({2, 3})

In [47]:
@lru_cache(maxsize=None)
def recursive_valuefs(elements, weight_limit):
    """Top-down, memoised 0/1 knapsack that maximises the total item value.

    Args:
        elements: frozenset of items with ``weight`` and ``value`` attributes
            (e.g. Element). It must be hashable because results are cached by lru_cache.
        weight_limit: remaining capacity (an exact fit is allowed).

    Returns:
        frozenset of the chosen items (total weight <= weight_limit). When
        including the pivot does not strictly increase the value, the solution
        without it is kept.
    """
    if not elements:
        return frozenset()
    # next(iter(...)) yields the same pivot as list(elements)[0] without
    # copying the whole set on every (memoised) call.
    this_element = next(iter(elements))
    other_elements = elements - {this_element}
    if this_element.weight > weight_limit:
        return recursive_valuefs(other_elements, weight_limit)
    with_this = recursive_valuefs(other_elements, weight_limit - this_element.weight)
    without_this = recursive_valuefs(other_elements, weight_limit)
    # with_this is drawn from other_elements, so it never contains the pivot;
    # the pivot's value therefore simply adds to with_this's total.
    with_value = this_element.value + sum(e.value for e in with_this)
    if with_value > sum(e.value for e in without_this):
        return with_this | {this_element}
    return without_this

In [7]:
def display_table(table, suppress_zero=True, width=4):
    """Print a 2-D table stored as a ``{(row, col): int}`` dict as a text grid.

    Rows and columns run from 0 to the largest row/column index present in
    the keys. Every (row, col) pair in that range must be present.

    Args:
        table: dict mapping (row, col) -> int.
        suppress_zero: if True, zero cells are printed as a right-aligned '.'
            so that non-zero entries stand out.
        width: minimum character width of every cell (default 4).

    Prints nothing for an empty table.
    """
    if not table:
        return
    rows = max(r for r, _ in table)
    columns = max(c for _, c in table)

    def format_cell(value):
        # Pad the '.' placeholder to the same width as numbers so columns line up.
        if suppress_zero and value == 0:
            return '.'.rjust(width)
        return '{:{}d}'.format(value, width)

    for r in range(rows + 1):
        print(' '.join(format_cell(table[r, c]) for c in range(columns + 1)))

In [8]:
def backtrace(table):
    """Follow back-references from the bottom-right cell down to row 0.

    Args:
        table: back-reference dict mapping (row, col) -> (previous_row,
            previous_col), e.g. the back_refs returned by dp_count / dp_value.

    Returns:
        dict of the visited cells with row >= 1, each mapped to the cell it
        points back to. Empty when ``table`` is empty.
    """
    if not table:
        # dp_count / dp_value return empty back_refs when given no elements.
        return {}
    row = max(r for r, _ in table)
    col = max(c for _, c in table)
    path = {}
    while row > 0:
        path[row, col] = table[row, col]
        row, col = table[row, col]
    return path

In [9]:
def traced_table(base, backtrace):
    """Return a copy of ``base`` in which every cell not in ``backtrace`` is 0.

    Cells that are keys of ``backtrace`` keep their ``base`` value. The key
    order of ``base`` is preserved.
    """
    masked = {}
    for cell, entry in base.items():
        masked[cell] = entry if cell in backtrace else 0
    return masked

In [10]:
def greedy_fill(elements, weight_limit):
    """Count the items packed lightest-first until the next one would exceed the limit.

    For non-negative weights, taking the lightest items first maximises the
    number of items, so this should agree with dp_count's best count.

    Args:
        elements: iterable of dicts with an integer 'weight' key.
        weight_limit: maximum total weight; a total exactly equal to it fits.

    Returns:
        The number of items packed (0 when there are none or none fit).
    """
    running_totals = itertools.accumulate(sorted(e['weight'] for e in elements))
    # <= (not <): a load that reaches the limit exactly is still feasible,
    # matching dp_count, which only rejects an item when weight > capacity.
    return sum(1 for _ in itertools.takewhile(lambda total: total <= weight_limit, running_totals))

In [11]:
def greedy_value_vpw(elements, weight_limit):
    """Total value packed greedily in order of decreasing value-per-weight.

    Items are taken best ratio first until the next one would push the total
    weight over ``weight_limit``; packing stops there (later items are not
    tried).

    Args:
        elements: iterable of dicts with integer 'weight' and 'value' keys.
        weight_limit: maximum total weight; a total exactly equal to it fits
            (consistent with dp_value).

    Returns:
        The total value of the packed items, or 0 when there are no elements
        or the first one already does not fit.
    """
    total_weight = 0
    total_value = 0
    for item in sorted(elements, key=lambda e: e['value'] / e['weight'], reverse=True):
        if total_weight + item['weight'] > weight_limit:
            break
        total_weight += item['weight']
        total_value += item['value']
    return total_value

In [12]:
def greedy_value_w(elements, weight_limit):
    """Total value packed greedily lightest-first.

    Items are taken in order of increasing weight until the next one would
    push the total weight over ``weight_limit``; packing stops there.

    Args:
        elements: iterable of dicts with integer 'weight' and 'value' keys.
        weight_limit: maximum total weight; a total exactly equal to it fits
            (consistent with dp_value).

    Returns:
        The total value of the packed items, or 0 when there are no elements
        or the lightest one already does not fit.
    """
    total_weight = 0
    total_value = 0
    for item in sorted(elements, key=lambda e: e['weight']):
        if total_weight + item['weight'] > weight_limit:
            break
        total_weight += item['weight']
        total_value += item['value']
    return total_value

In [13]:
# Each non-blank line of the data file holds an integer weight followed by an
# integer value. The `with` block closes the file deterministically.
with open('../../data/09-bags.txt') as bags_file:
    elements = []
    for line in bags_file:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines, e.g. trailing ones
        elements.append({'weight': int(fields[0]), 'value': int(fields[1])})
# Knapsack capacity used by every solver below.
weight_limit = 5000

In [33]:
# Element namedtuples are hashable, so the items can form a frozenset, which the
# lru_cache-decorated recursive solvers need as a cache key.
hashable_elements = frozenset(
    Element(weight=e['weight'], value=e['value']) for e in elements
)
# A set silently merges items with identical (weight, value); if that happened,
# the recursive solvers would be working on fewer items than the DP solvers.
assert len(hashable_elements) == len(elements), \
    'items with duplicate (weight, value) were merged into one'

In [14]:
# Exact DP: maximum number of items within weight_limit; ct is the count table,
# br its back-references.
# NOTE(review): `value` holds a count here and is reassigned to dp_value's result later.
value, ct, br = dp_count(elements, weight_limit)
value

15

In [15]:
# Greedy lightest-first item count, for comparison with dp_count's result.
greedy_fill(elements, weight_limit)

15

In [39]:
# Maximum item count from the memoised recursive solver (cross-check of dp_count).
len(recursive_count(hashable_elements, weight_limit))

15

In [16]:
# Exact DP: maximum total value within weight_limit; vt is the value table,
# vbr its back-references.
value, vt, vbr = dp_value(elements, weight_limit)
value

2383

In [17]:
# Heuristic: greedy lightest-first total value (compare with dp_value).
greedy_value_w(elements, weight_limit)

1801

In [18]:
# Heuristic: greedy by value-per-weight ratio (compare with dp_value).
greedy_value_vpw(elements, weight_limit)

2300

In [48]:
# Items chosen by the memoised recursive value solver.
recursive_valuefs(hashable_elements, weight_limit)

frozenset({Element(weight=301, value=134),
 Element(weight=314, value=166),
 Element(weight=320, value=154),
 Element(weight=336, value=190),
 Element(weight=337, value=140),
 Element(weight=340, value=172),
 Element(weight=353, value=191),
 Element(weight=356, value=153),
 Element(weight=359, value=171),
 Element(weight=365, value=177),
 Element(weight=381, value=166),
 Element(weight=382, value=185),
 Element(weight=414, value=189),
 Element(weight=434, value=195)})

In [49]:
# Total value of the recursive solution, to compare with dp_value's result.
sum(e.value for e in recursive_valuefs(hashable_elements, weight_limit))

2383

In [52]:
# Cells in each DP table: (len(elements) + 1) rows (item prefixes, including the
# empty one) times (weight_limit + 1) columns (capacities 0..weight_limit).
(len(elements) + 1) * (weight_limit + 1)

305061