"""
median.py:  Using Gurobi for solving the k-median problem.

Problem:
open k facilities among m candidate sites and assign each of n customers
to exactly one open facility, minimizing the total (weighted) travel cost.

Copyright (c) by Joao Pedro PEDROSO and Mikio KUBO, 2010
"""
from gurobipy import *

def kmedian(m, n, c, k, strong=True):
    """Build (but do not solve) a Gurobi model for the k-median problem.

    Choose k of the m candidate facility sites to open and assign every one
    of the n customers to exactly one open facility, minimizing the total
    assignment cost sum(c[i,j] * x[i,j]).

    Parameters:
        m -- number of candidate facility sites, indexed 0..m-1
        n -- number of customers, indexed 0..n-1
        c -- mapping (i, j) -> cost of assigning customer i to facility j
        k -- number of facilities to open
        strong -- if true, link assignments to open facilities with one
                  constraint per pair, x[i,j] <= y[j]; otherwise use one
                  aggregated constraint per facility,
                  sum_i x[i,j] <= n * y[j].  Summing the per-pair
                  constraints yields the aggregated one, so the strong
                  form's LP relaxation is at least as tight.

    Returns:
        The Gurobi model.  The variable dictionaries are attached as
        ``model.__data = (x, y)``: x[i,j] is the binary "customer i is served
        by facility j" variable, y[j] the binary "facility j is open" one.
    """
    model = Model("k-median")
    x = {}
    y = {}
    for j in range(m):
        y[j] = model.addVar(obj=0, vtype="B", name="y[%s]" % j)
        for i in range(n):
            x[i, j] = model.addVar(obj=c[i, j], vtype="B", name="x[%s,%s]" % (i, j))
    # Integrate the pending variables before constraints reference them.
    model.update()

    # Every customer is served by exactly one facility.
    for i in range(n):
        model.addConstr(quicksum(x[i, j] for j in range(m)), "=", 1,
                        name="Assign[%s]" % i)

    # Customers may only be assigned to open facilities.
    if strong:
        for j in range(m):
            for i in range(n):
                model.addConstr(x[i, j], "<", y[j], name="Strong[%s,%s]" % (i, j))
    else:
        for j in range(m):
            model.addConstr(quicksum(x[i, j] for i in range(n)), "<",
                            LinExpr(n, y[j]), name="Weak[%s]" % j)

    # Exactly k facilities are opened.
    model.addConstr(quicksum(y[j] for j in range(m)), "=", rhs=k, name="k_median")

    model.update()
    model.__data = x, y
    return model


import math
import random
def distance(x1, y1, x2, y2):
    """Return the Euclidean distance between points (x1, y1) and (x2, y2).

    math.hypot avoids the intermediate overflow/underflow that squaring the
    coordinate differences directly would cause for very large or very small
    values.
    """
    return math.hypot(x2 - x1, y2 - y1)


def make_data(n):
    """Generate n random points in [0, 1) x [0, 1) and their distance table.

    Returns (c, x, y): x and y are lists holding the n point coordinates, and
    c maps every ordered pair (i, j), including i == j, to the Euclidean
    distance between point i and point j.
    """
    # All x-coordinates are drawn before any y-coordinate; this order fixes
    # which points a given random seed produces.
    xs = [random.random() for _ in range(n)]
    ys = [random.random() for _ in range(n)]

    costs = {(a, b): distance(xs[a], ys[a], xs[b], ys[b])
             for a in range(n)
             for b in range(n)}

    return costs, xs, ys

                
if __name__ == "__main__":
    import sys
    random.seed(67)
    n = 200
    c, x_pos, y_pos = make_data(n)
    m = n  # every customer location is also a candidate facility site
    k = 20
    model = kmedian(m, n, c, k, strong=True)
    # Optional: restrict Gurobi to a single thread.
    # model.Params.Threads = 1
    model.optimize()

    # Solution values (.X) can only be read once a solution exists, so stop
    # with a clear message instead of failing on the first .X access.
    if model.Status != GRB.OPTIMAL:
        print("No optimal solution found (Gurobi status %s)" % model.Status)
        sys.exit(1)

    EPS = 1.e-6  # binaries may come back as e.g. 0.9999999; use a tolerance
    x, y = model.__data
    edges = [(i, j) for (i, j) in x if x[i, j].X > EPS]
    nodes = [j for j in y if y[j].X > EPS]
    # Single-argument print(...) produces identical output on Python 2 and 3.
    print("Optimal value= %s" % model.ObjVal)
    print("Selected nodes: %s" % (nodes,))
    print("Edges: %s" % (edges,))
    print("max c: %s" % max(c[i, j] for (i, j) in edges))

    try:  # plot the result using networkx and matplotlib
        import networkx as NX
        import matplotlib.pyplot as P
        P.ion()  # interactive mode on
        G = NX.Graph()

        other = [j for j in y if j not in nodes]
        G.add_nodes_from(nodes)
        G.add_nodes_from(other)
        G.add_edges_from(edges)

        position = dict((i, (x_pos[i], y_pos[i])) for i in range(n))

        NX.draw(G, position, node_color='y', nodelist=nodes)
        NX.draw(G, position, node_color='g', nodelist=other)

        # raw_input exists only on Python 2; input is its Python 3 equivalent.
        try:
            read_line = raw_input
        except NameError:
            read_line = input
        read_line("press [enter] to continue")
    except ImportError:
        print("install 'networkx' and 'matplotlib' for plotting")


