'''
PD = vector (list) of lists of immediate predecessors for each node in the precedence diagram
Nodes are enumerated 0, 1, 2, ... n-1

getP(PD) returns a vector of lists of all predecessors for each node (not only the immediate ones)
getS(PD) returns a vector of lists of all successors for each node (not only the immediate ones)



'''


import math
import numpy as np
import lpLib as lp
import time

def fillP_i(PD,i):
    '''
    Collect every (direct and indirect) predecessor of node (task) i.

    Walks the precedence diagram from node i to the left. Each newly found
    predecessor is appended, in depth-first discovery order, to the
    module-level list ``predecessors``. The caller must reset that list
    (``predecessors = []``) before the first call.

    A node that is already in ``predecessors`` is not expanded again. For an
    acyclic diagram, its own predecessors were collected when it was first
    discovered, so the result is unchanged. Skipping it avoids re-walking
    shared ancestors (the call count would otherwise grow with the number of
    paths) and makes the walk terminate if PD accidentally contains a cycle.
    '''
    global predecessors
    for j in PD[i]:
        if j in predecessors:
            continue  # already expanded: its ancestors are already collected
        predecessors.append(j)
        fillP_i(PD, j)
        
def getP(PD):
    '''
    Build the list of all predecessors of every node (task).

    Entry k of the returned list is the list of every direct and indirect
    predecessor of node k, as collected by fillP_i. The module-level list
    ``predecessors`` is reset before each node is processed, so afterwards it
    still refers to the list built for the last node.
    '''
    def _collect(node):
        global predecessors
        predecessors = []  # fresh accumulator that fillP_i appends to
        fillP_i(PD, node)
        return predecessors

    return [_collect(node) for node in range(len(PD))]

def isSuccessor(PD,candidate,mother,_seen=None):
    '''
    Return True if <candidate> is a (direct or indirect) successor of <mother>.

    Walks the precedence diagram to the left from <candidate>. If one of the
    predecessors equals <mother>, or is itself a successor of <mother>, the
    module-level flag ``successorFound`` is set to True and True is returned
    immediately.

    The caller must set ``successorFound = False`` before the first call.
    When nothing is found, the current value of that flag is returned, which
    raises NameError if the flag was never defined.

    _seen is internal bookkeeping for the recursion: it holds the nodes
    already explored during this search. Each node is expanded at most once,
    which avoids exponential re-walking of shared ancestors and guarantees
    termination on a cyclic diagram. Callers should not pass it.
    '''
    global successorFound
    if _seen is None:
        _seen = set()
    for j in PD[candidate]:
        if j == mother:
            successorFound = True
            return True
        if j in _seen:
            continue  # explored earlier in this search without reaching <mother>
        _seen.add(j)
        if isSuccessor(PD, j, mother, _seen):  # Continue to the left
            successorFound = True
            return True  # stop as soon as <mother> has been reached
    return successorFound

def getS(PD):
    '''
    Build the list of all successors of every node (task).

    Entry k of the returned list holds, in increasing node order, every node
    for which isSuccessor reports node k as a direct or indirect
    predecessor. The module-level flag ``successorFound`` is reset to False
    before every isSuccessor call, as that function requires.
    '''
    nodes = range(len(PD))

    def _descends_from(cand, mother):
        global successorFound
        successorFound = False  # isSuccessor reports through this flag
        return isSuccessor(PD, cand, mother)

    return [[cand for cand in nodes if _descends_from(cand, mother)]
            for mother in nodes]

def getE(t,c,P,tol=1e-9): # Creates and returns a vector of earliest stations for each node (task)
    '''
    Return the earliest station (0-based) for every task.

    E_i = ceil((t_i + sum of t_j over all predecessors j of i) / c) - 1.
    A task whose total work is zero gets index 0.

    t:   task times, indexed by task
    c:   cycle time (capacity of one station)
    P:   P[i] lists all predecessors of task i (as built by getP)
    tol: slack subtracted from the station ratio before rounding up. It keeps
         floating-point noise (e.g. 0.1 + 0.2 == 0.30000000000000004) from
         pushing an exact multiple of c into the next station.
    '''
    E = []
    for i, P_i in enumerate(P):
        work = math.fsum([t[i]] + [t[j] for j in P_i])  # accurate float sum
        # max(1, ...) keeps the rule "zero work -> first station"
        e = max(1, math.ceil(work / c - tol))
        E.append(e - 1)  # Note, indexing starts at 0 (and not 1) in Python!
    return E

def getL(t,c,L,m,tol=1e-9): # Creates and returns a vector of latest stations for each node (task)
    '''
    Return the latest station (0-based) for every task.

    L_i = m - ceil((t_i + sum of t_j over all successors j of i) / c).
    A task whose total work is zero gets the last station, index m - 1.

    t:   task times, indexed by task
    c:   cycle time (capacity of one station)
    L:   despite its name, L[i] lists all SUCCESSORS of task i (the main
         program passes S from getS). The argument itself is used; no
         module-level S is read.
    m:   number of stations
    tol: slack subtracted from the station ratio before rounding up. It keeps
         floating-point noise from overstating the stations needed.
    '''
    successors = L
    latest = []
    for i, S_i in enumerate(successors):
        work = math.fsum([t[i]] + [t[j] for j in S_i])  # accurate float sum
        # max(1, ...) keeps the rule "zero work -> last station"
        needed = max(1, math.ceil(work / c - tol))
        latest.append(m - needed)  # == (m + 1 - needed) - 1, i.e. 0-based
    return latest

def ndx(i,j,n):
    '''
    Map (task i, station j) to the position of x_ij in the flat x-vector.

    x_ij is stored in a vector rather than a matrix, laid out station by
    station: the n task variables of station 0 come first, then those of
    station 1, and so on.
    '''
    station_offset = j * n  # first position belonging to station j
    return station_offset + i

def occurenceConstraints(m,E,L):
    '''
    Require every task to be assigned exactly once inside its station window.

    For each task i one equality constraint is added through lpLib:
        sum_{j = E_i .. L_i} x_ij = 1

    m:    number of stations
    E, L: 0-based earliest / latest station of each task (getE / getL)

    Variables x_ij outside [E_i, L_i] are not pinned to zero here. The main
    program leaves them out of the enumeration via xNotFeasible instead.
    '''
    n = len(E)
    for i in range(n):
        A_row = np.zeros([n*m])
        for j in range(E[i], L[i] + 1):
            A_row[ndx(i, j, n)] = 1
        lp.equalityConstraint(1, A_row)
def canDoConstraints(n,W):
    '''
    Forbid assigning a task to a station that cannot perform it.

    W[j] lists the tasks station j is able to do. For every task i that is
    not in W[j], the equality x_ij = 0 is added through lpLib. Constraints
    are added task by task, and within a task station by station.

    n: number of tasks
    '''
    m = len(W)
    size = n * m
    for task in range(n):
        forbidden = [st for st, doable in enumerate(W) if task not in doable]
        for st in forbidden:
            row = np.zeros([size])
            row[ndx(task, st, n)] = 1
            lp.equalityConstraint(0, row)


def precedenceConstraints(PD,m,E,L):
    '''
    Add one precedence constraint per edge a -> b of the precedence diagram.

    For every immediate predecessor a of task b, the inequality
        sum_j (j+1) * x_aj  -  sum_k (k+1) * x_bk  <=  0
    is added through lpLib, i.e. the (1-based) station number of a must not
    exceed that of b.

    a's sum runs over stations E[a] .. max(L[a], E[b]) and b's sum over
    min(L[a], E[b]) .. L[b]. When the windows do not overlap, both sums are
    therefore widened.
    NOTE(review): the reason for this widening is not visible here — confirm.
    '''
    n = len(E)
    size = n * m
    for succ, direct_preds in enumerate(PD):
        for pred in direct_preds:
            row = np.zeros([size])
            pred_last = max(L[pred], E[succ])
            succ_first = min(L[pred], E[succ])
            for st in range(E[pred], pred_last + 1):
                row[ndx(pred, st, n)] = st + 1
            for st in range(succ_first, L[succ] + 1):
                row[ndx(succ, st, n)] = -(st + 1)
            lp.upperBound(0, row)
    
def cycleTimeConstraints(t,c,W):
    '''
    Keep every station's workload within the cycle time.

    For each station j one inequality is added through lpLib:
        sum_{i in W[j]} t_i * x_ij <= c
    Only the tasks listed in W[j] appear in station j's row.
    '''
    task_count, station_count = len(t), len(W)
    for station, doable in enumerate(W):
        row = np.zeros([task_count * station_count])
        for task in doable:
            row[ndx(task, station, task_count)] = t[task]
        lp.upperBound(c, row)

def SX(t,W,x):
    '''
    Return the total absolute deviation of station loads from the mean load.

    With m = len(W), the load of station j is
        load_j = sum(t_i * x_ij for i in W[j])
    and the mean load is sum(t) / m. The result is sum_j |load_j - mean|.
    The main program keeps the candidate with the smallest value.

    t: task times; W[j]: tasks station j can do; x: flat assignment vector
       indexed with ndx. Returns 0 when there are no stations.
    '''
    n = len(t)  # number of tasks, derived from t instead of a module-level n
    m = len(W)
    if m == 0:
        return 0
    mean_load = sum(t) / m  # loop-invariant, computed once
    s = 0
    for j, W_j in enumerate(W):
        load = sum(t[i] * x[ndx(i, j, n)] for i in W_j)
        s += abs(load - mean_load)
    return s
'''

Main program starts here

'''

# --- Problem data ----------------------------------------------------------
PD = [[],[0],[0],[0],[3],[1,2,4],[5],[6]]  # PD[i]: immediate predecessors of task i
W = [[0,1,2,3,4],[1,2,3,4,5,6],[5,6,7]]    # W[j]: tasks station j is able to do
#W = [[0,1,2,],[1,2,3,4,5,6],[5,6,7]]
t = [2.2,3.5,3.0,2.8,3.5,2.3,2.3,2.5]      # task times
c = 10.0                                   # cycle time
names = ["A","B","C","D","E","F","G","H"]  # task labels used in the printout
P = getP(PD)  # all predecessors of each task
S = getS(PD)  # all successors of each task
m = len(W) # Number of stations
n = len(PD) # Number of tasks
E = getE(t,c,P)
L = getL(t,c,S,m)

print("\nList of earliest and latest stations, given m =",m)
print("i  E_i L_i")
for i, name in enumerate(names):
    print(name+" ",E[i]," ",L[i])
# The lpLib library requires the problem to be defined in terms of # of variables;
# an all-zero coefficient vector of length n*m is passed as a dummy.
lp.coefficients(np.zeros([n*m]))
# Define all constraints:
occurenceConstraints(m,E,L)
precedenceConstraints(PD,m,E,L)
cycleTimeConstraints(t,c,W)

# Sort out variables that can never be 1, so they are left out of the
# enumeration below. x_ij is excluded when
#   - station j lies outside task i's earliest-latest window [E_i, L_i], or
#   - station j cannot do task i (i not in W[j]).
# (m == len(W), so one loop over W covers both checks.)
xNotFeasible = np.zeros([n*m])
for i in range(n):
    for j, W_j in enumerate(W):
        if j < E[i] or j > L[i] or i not in W_j:
            xNotFeasible[ndx(i,j,n)] = 1
start_time = time.time()
# Presumably every 0/1 vector over the remaining (non-excluded) variables
x_combinations = lp.all_binary_combinations(n*m-int(sum(xNotFeasible)))
end_time = time.time()
print("\nNumber of combinations",len(x_combinations))
print(f"Generated in: {end_time - start_time} seconds")

'''

The getFullX function takes a candidate x-vector (holding values only for the feasible
positions) and fills these values, in order, into the full x-vector; all non-feasible
positions stay 0. The resulting vector can then be used for complete validation of a solution.

Example:

xNotFeasible = [0,0,1,1,1,0,0,1]
x_comb = [1,0,1,0]

result        =[1,0,0,0,0,1,0,0]

'''

def getFullX(x_comb,xNotFeasible): # Fill only with feasible x-values
    '''
    Expand a reduced candidate vector into a full-length x-vector.

    Position k of the result receives the next unused value of x_comb when
    xNotFeasible[k] == 0, and stays 0.0 otherwise. Values of x_comb are
    consumed in order. x_comb must hold at least as many values as there are
    zero entries in xNotFeasible; otherwise IndexError is raised.
    '''
    full_x = np.zeros(len(xNotFeasible))
    open_positions = [k for k, blocked in enumerate(xNotFeasible) if blocked == 0]
    for source_idx, target_idx in enumerate(open_positions):
        full_x[target_idx] = x_comb[source_idx]
    return full_x

'''

Test every generated combination for feasibility, and keep the feasible one with the lowest SX

'''
# Brute-force search: expand each candidate, check it, and keep the one with
# the smallest SX. best_x stays [] if no candidate is feasible.
best_x = []
best_SX = 1e30
start_time = time.time()
for x_comb in x_combinations:
    x = getFullX(x_comb,xNotFeasible)
    # Cheap pre-filter: a valid assignment has exactly n ones (one per task)
    if sum(x) != n:
        continue
    if not lp.isFeasible(x):
        continue
    print("\nFeasable Solution found:")
    for j in range(m):
        print("Station =",str(j+1)+":")
        for i in range(n):
            if x[ndx(i,j,n)] == 1:
                print(names[i])
    sx = SX(t,W,x)  # evaluate once and reuse
    if sx < best_SX:
        best_SX = sx
        best_x = x
    print("SX =", sx)
end_time = time.time()  # stop the clock when the search itself is done

print("\n*** Best solution: ***")
if len(best_x) == 0:
    # Without this guard, indexing the empty best_x below raises IndexError
    print("No feasible solution found")
else:
    for j in range(m):
        print("Station =",str(j+1)+":")
        for i in range(n):
            if best_x[ndx(i,j,n)] == 1:
                print(names[i])
    print("SX(best x) =", best_SX)
print("\nNumber of combinations tested",len(x_combinations))
print(f"Validating solutions in: {end_time - start_time} seconds")
