"""
Beam search (width-limited breadth-first search) for the Number Partitioning Problem.

The Number Partitioning Problem (NPP) is a combinatorial optimization
problem where: given a set of numbers, find a partition into two
subsets such that the difference between the sum of their elements is
minimal.

This file contains a set of functions for doing tree search for this
problem, using beam search: a breadth-first search in which at most a
given number of nodes (beams) is kept at each level.  The search can be
complete or incomplete; it is interrupted if the CPU time used exceeds
the limit allowed.

Copyright (c) by Joao Pedro PEDROSO and Mikio KUBO, 2009
"""
from bisect import *
from graphtools import adjacent
from npp import mk_part, differencing_construct, longest_processing_time
from npp_dfs import mk_part_II, display, Infinity
from chrono import clock
import sys
# Raise the interpreter recursion limit (required for the large benchmarks).
sys.setrecursionlimit(10000)
# Set to True to trace every beam handled by the search on stdout.
LOG = False
# NOTE: init, limit and count are module globals shared by the search
# functions.  They are assigned by differencing_bfs() before each search and
# are deliberately not defined at import time (a "global" statement at module
# level has no effect, so none is needed here).
class Beam:
    """hold beam information (one node of the search tree).

    There is no __init__: the search code assigns the attributes below
    after creating the object.
      label - sorted list with pairs [(wi,i),...], where wi is the weight of item i
      disjoin - edges [(i,j),...] indicating that i and j must be in different partitions
      join - edges [(i,j),...] indicating that i and j must be in the same partition
      remain - sum of weights for current list of items
      DEV - number of "same partition" branches taken up to this node
      UB - an upper bound (obtained with the differencing method)

    Beams are ordered by DEV first, ties being broken by UB.
    """
    def __cmp__(self, other):
        """Three-way comparison by (DEV, UB); returns -1, 0 or 1."""
        # explicit comparisons are used because int(self.UB-other.UB)
        # does not work for long ints
        if self.DEV < other.DEV:
            return -1
        elif self.DEV > other.DEV:
            return 1
        if self.UB < other.UB:
            return -1
        elif self.UB > other.UB:
            return 1
        return 0

    def __lt__(self, other):
        """Rich "less than" with the same order as __cmp__.

        list.sort() only needs this operator; defining it keeps sorting of
        beams working on interpreters that ignore __cmp__.
        """
        return self.__cmp__(other) < 0

    def __cmp____(self, other):
        """Alternative three-way comparison that looks at UB only.

        The name is not a special method name, so it is never invoked
        implicitly; call it explicitly if this ordering is wanted.
        """
        # does not work for long ints return int(self.UB-other.UB)
        if self.UB < other.UB:
            return -1
        elif self.UB > other.UB:
            return 1
        return 0

    def __str__(self):
        """Return a multi-line dump of the beam, one attribute per line."""
        lines = [
            "label: " + str(self.label),
            "disj: " + str(self.disjoin),
            "join: " + str(self.join),
            "remain: " + str(self.remain),
            "dev: " + str(self.DEV),
        ]
        try:
            lines.append("UB: " + str(self.UB))
        except AttributeError:
            # UB has not been assigned to this beam: omit its line
            pass
        return "\n".join(lines) + "\n"

def differencing_UB(label_):
    """Objective value given by the differencing method (fast version).

    label_ is a list of pairs (weight, index) sorted by increasing weight;
    only the weights are used and label_ itself is left untouched.  The two
    largest values are repeatedly replaced by their difference until a
    single value remains, which is returned.
    """
    weights = [weight for weight, _ in label_]
    while len(weights) > 1:
        largest = weights.pop()
        runner_up = weights.pop()
        insort(weights, largest - runner_up)
    return weights.pop()

def differencing_edges(label_):
    """Edges produced by the differencing method.

    label_ is a list of pairs (weight, index) sorted by increasing weight;
    a copy is processed, so label_ is left untouched.  At each step the two
    largest entries are replaced by their difference (kept under the index
    of the larger one) and the pair of indices is recorded.  Returns the
    list of recorded pairs [(i1,i2),...], in the order they were created.
    """
    pending = list(label_)
    disjoin_edges = []
    while len(pending) > 1:
        big, big_index = pending.pop()
        small, small_index = pending.pop()
        # edge will force the two items in different partitions
        disjoin_edges.append((big_index, small_index))
        insort(pending, (big - small, big_index))
    return disjoin_edges


def npp_bs(n, label, disjoin, join, remain, bestobj, LB, bmax):
    """beam search for number partitioning, based on differencing method
    Arguments:
      n - number of items (on the current list)
      label - sorted list with pairs [(wi,i),...], where wi is the weight of item i
      disjoin - edges [(i,j),...] indicating that i and j must be in different partitions
      join - edges [(i,j),...] indicating that i and j must be in the same partition
      bestobj - objective value for the best known solution
      remain - sum of weights for current list of items
      LB - known lower bound (0 for even sum of weights, 1 for odd)
    """

    global init, limit, count
    opt = True


    # construct initial beam
    b = Beam()
    b.label = label
    b.disjoin = disjoin	# edges that force vertices to be in separate partitions
    b.join = join	# edges that force vertices to be in the same partition
    b.remain = remain	# remaining items/differences
    b.DEV = 0		# number of sums (right branches) up to the current node
    b.UB = None		# upeer bound

    # initial list of beams
    B = [b]
    while B != []:
        sys.stdout.flush()
        if count % 1 == 0 and clock() > limit:
            print "cpu time exceeded..."
            opt = False
            break

        if LOG:
            print 
            print 
            print 
            print "NEW ITERATION n=%d" % n, len(B), "BEAMS"
            print 
        cut = []
        i = -1
        for b in B:
            i += 1
            count += 1
            if LOG:
                print count, "branches, current beam:"
                print b

            # check if current branch can be cut
            d1,i1 = b.label.pop()
            b.slack = (2*d1 - b.remain)	# difference between largest item and sum of others
            if b.slack >= 0 or abs(b.slack) == LB:	# first item is larger that the sum of the others
                # best solution that can be achieved = slack
                if b.slack >= bestobj:
                    if LOG:	print "CUT:", display(b.slack), ">", display(bestobj)
                    cut.append(i)
                    continue
                else:
                    print "UPDATING BEST: A ***",
                    print display(b.slack), "<", display(bestobj), "cpu:", clock()-init, "/", limit-init, "\t", count, "nodes"
                    sys.stdout.flush()
                    # first item must be in a partition, and remaining items in another:
                    bdisj = list(b.disjoin)
                    for d2,i2 in b.label:
                        bdisj.append((i1,i2))
                    bestobj = abs(b.slack)
                    best = (bestobj, list(bdisj), list(b.join))
                    if bestobj == LB:
                        # optimal solution found: first element == sum others (+-1)
                        print "optimal solution found",
                        return True, best
                # no better solution can be obtained from here
                cut.append(i)
                continue
            insort(b.label, (d1, i1))	# restore first element in label list

            # calculate upper bounds
            if b.UB == None:
                b.UB = differencing_UB(b.label)
                if LOG:
                    print "updating UB, current beam:"
                    print b
                if b.UB < bestobj:
                    print "UPDATING BEST: B ***",
                    print display(b.UB), "<", display(bestobj), "cpu:", clock()-init, "/", limit-init, "\t", count, "nodes"
                    sys.stdout.flush()
                    best = (b.UB, b.disjoin+differencing_edges(b.label), list(b.join))
                    bestobj = b.UB
                    # print "best:", best

                    if bestobj == LB:
                        # optimal solution found: first element == sum others
                        print "optimal solution found with differencing"
                        return True, best
        for i in reversed(cut):
            B.pop(i)
        B.sort()
        if len(B) > bmax:
            opt = False
            del B[bmax:]

        sys.stdout.flush()
        if count % 1 == 0 and clock() > limit:
            print "cpu time exceeded..."
            opt = False
            break

        #print len(B), "/", bmax, "intermediate beams for n=", n
        #sys.stdout.flush()
        n -= 1
        # for 4 or less items, differencing is exact
        if n <= 3:	# 4, but n was already decremented
            break


        newB = []	# where to place the new beams
        for b in B:
            if LOG:
                print "SECOND PHASE, current beam:"
                print b

            # print "current list:", label, "disjoin:", disjoin, "join:", join
            d1,i1 = b.label.pop()
            d2,i2 = b.label.pop()

            # 
            # FIRST BRANCH: try the same as differencing heuristic
            # 
            insort(b.label, (d1-d2, i1))
            b.disjoin.append((i1,i2))	# edge will force the two items in different partitions
            b1 = Beam()
            b1.label = list(b.label)
            b1.disjoin = list(b.disjoin)
            b1.join = list(b.join)
            b1.remain = b.remain-2*d2	# -(d1+d2)+(d1-d2) = -2*d2
            b1.UB = b.UB
            b1.DEV = b.DEV
            newB.append(b1)
            if LOG:
                print "created new beam:"
                print b1

            # restore data structures
            pos = bisect_left(b.label,(d1-d2, i1))
            b.label.pop(pos)
            b.disjoin.pop()

            # 
            # SECOND BRANCH: try the other possibility: put i1 and i2 on the same partition
            # 
            if n <= 3:
                continue
            insort(b.label, (d1+d2, i1))
            b.join.append((i1,i2))	# to assure i1 and i2 will be in same partition
            b.UB = None
            b.DEV += 1
            if LOG:
                print "updated beam:"
                print b

        if n <= 3:
            B = newB
        else:
            B = newB + B

    return opt, best

def differencing_bfs(data, bmax, cpulimit):
    """depth-first search based on the differencing method:
       partition a list of items into two partitions
       * prepare data
       * call the recursive function for the actual depth-first search
       * return the two partitions obtained
    """
    global init, limit, count
    init = clock()
    limit = cpulimit
    count = 0

    # copy and sort data by decreasing order
    n = len(data)
    label = [(data[i],i) for i in range(n)]
    label.sort()

    # initialize data for the differencing method's graph
    bestobj = Infinity
    disjoin = []	# edges that force vertices to be in separate partitions
    join = []		# edges that force vertices to be in the same partition
    remain = sum(data)	# remaining items/differences
    LB = remain & 1	# LB=1 for odd sums, 0 for even

    # call beam search
    opt, (obj,disjoin,join) = npp_bs(n, label, disjoin, join, remain, bestobj, LB, bmax)

    # make the partition, based on the disjoin/join edges
    p1,p2 = mk_part_II(adjacent(range(n),disjoin),adjacent(range(n),join),[],[],0)
    # print "\n\nbest partition:", p1,p2
    # make a list with the weights for each partition
    d1 = [data[i] for i in p1]
    d2 = [data[i] for i in p2]
    # print "p1 indices:", p1, "weights:", d1, sum(d1)
    # print "p2 indices:", p2, "weights:", d2, sum(d2)
    print "objective:", display(obj)
    print "number of nodes on BFS:", count
    sys.stdout.flush()
    return opt, obj, d1, d2



if __name__ == '__main__':
    
    try:
        bmax = int(sys.argv[1])
    except IndexError:
        # print "usage:", sys.argv[0], "bmax filename"
        bmax = Infinity
        # bmax = 1000

    try:
        filename = sys.argv[2]
    except IndexError:
        filename = "INSTANCES/NPP/toy.dat"
        # filename = "INSTANCES/NPP/easy0070.dat"
        # filename = "INSTANCES/NPP/hard0020.dat"
        filename = "INSTANCES/NPP/n015d10e00.dat"
    try:
        f = open(filename)
    except IOError:
        print "file", filename, "could not be read"
        exit(-1)

    data = f.readlines()
    f.close()
    data = [int(i) for i in data]
    print "initial  data:", filename, ",\tlog2(sum+1)=", display(sum(data))
    print "bmax:", bmax

    # print "\ndifferencing_construct: case of two partitions"""
    # obj, d1, d2 = differencing_construct(data)
    # 
    # c = d1+d2
    # c.sort()
    # assert c == data
    # 
    # print "\nlongest_processing_time: case of two partitions"""
    # part = longest_processing_time(data,2)
    # pmax = 0
    # pmin = Infinity
    # for p in part:
    #     s = sum(p)
    #     print p, s
    #     pmin = min(pmin,s)
    #     pmax = max(pmax,s)
    # print "objective:", pmax - pmin

    print "\nbeam search"""
    opt, obj,d1,d2 = differencing_bfs(data,bmax,320)
    if opt == True:
        star = '*'
    else:
        star = ''
    c = d1+d2
    c.sort()
    assert abs(sum(d1)-sum(d2)) == obj
    assert c == data

    print "BeamSearch objective:", display(obj), star
    sys.stdout.flush()
