'''
This library file contains helper functions that simplify specifying the
matrices used by the scipy.optimize linprog() function.

Essentially you need to do the following:

1 - Specify the objective function with the coefficients() function
2 - Specify constraint equations one by one, by using:
    a) upperBound() constraint
    b) lowerBound() constraint
    c) equalityConstraint()
    d) balanceConstraint()
3 - Get the solution by calling lpSolve()
'''
import copy
import itertools

import numpy as np
from scipy.optimize import linprog
# Module-level state shared by all functions in this file.
# coefficients() reinitializes it, so only one LP problem is built at a time.
isMin = True  # False when coefficients() was told to maximize
c = []        # objective coefficients (stored negated when maximizing)
A_ub = []     # left-hand-side rows of the "<=" constraints
b_ub = []     # right-hand sides of the "<=" constraints
A_eq = []     # left-hand-side rows of the "=" constraints
b_eq = []     # right-hand sides of the "=" constraints
def coefficients(c_in, isMinProblem = True): # send a False flag to "maximize", default is "minimize"
    '''
    Start a new LP problem by setting the objective-function coefficients.

    Any constraints added for a previous problem are discarded. linprog
    always minimizes, so for a maximization problem the coefficients are
    stored negated (lpSolve() flips the sign of the optimum back).

    c_in         -- sequence of objective coefficients, one per variable
    isMinProblem -- True (default) to minimize, False to maximize

    For maximization a shallow copy of c_in is negated, so the caller's
    list/array is never modified.
    '''
    global isMin, c, A_ub, A_eq, b_ub, b_eq
    A_ub = []
    A_eq = []
    b_ub = []
    b_eq = []
    isMin = isMinProblem
    if isMinProblem:
        c = c_in
    else:
        # Negate a copy: negating c_in itself would silently flip the signs
        # in the caller's data (and flip them again on every repeated call).
        # copy.copy keeps the original type (list stays list, ndarray stays
        # ndarray), unlike slicing, which only creates a view of an ndarray.
        c = copy.copy(c_in)
        for i, coef in enumerate(c):
            c[i] = -coef

def upperBound(b,A_row):
    '''
    Add one "less than or equal" constraint:  A_row . x <= b.

    b     -- right-hand side value
    A_row -- one coefficient per variable; stored as given (not copied)
    '''
    A_ub.append(A_row)
    b_ub.append(b)

def lowerBound(b,A_row): # Define one constraint equation, x * A_row >= b
    '''
    Add one "greater than or equal" constraint:  A_row . x >= b.

    linprog only accepts "<=" inequalities, so the constraint is stored as
    the equivalent  (-A_row) . x <= -b.

    b     -- right-hand side value
    A_row -- one coefficient per variable; it is NOT modified, so the
             caller can safely reuse the same row for other constraints
    '''
    b_ub.append(-b)
    # Store a negated copy instead of flipping the caller's row in place.
    A_ub.append([-coef for coef in A_row])

def equalityConstraint(b,A_row):
    '''
    Add one equality constraint:  A_row . x = b.

    b     -- right-hand side value
    A_row -- one coefficient per variable; stored as given (not copied)
    '''
    A_eq.append(A_row)
    b_eq.append(b)

def balanceConstraint(b,ndx):
    '''
    Add an equality constraint given as a list of signed, 1-based variable indexes.

    This function assumes a logical numbering of the x-vector,
    x = [x_1,x_2,...,x_n], where n is the number of objective coefficients
    (so coefficients() must be called first). Each entry of ndx names one
    variable with coefficient +1, or -1 when the index is negative; all
    other coefficients are 0. b is the right-hand side. Example:

    b = 100 and ndx = [1,2,-4,7]

    corresponds to: x_1 + x_2 - x_4 + x_7 = 100

    If an index appears more than once, its last occurrence decides the sign.

    Raises ValueError for an index of 0 (there is no x_0) and IndexError for
    an index whose magnitude exceeds n. In both cases no constraint is added.
    '''
    A_row = np.zeros(len(c))
    for i in ndx:
        if i == 0:
            # abs(0)-1 == -1 would silently overwrite the LAST variable's
            # coefficient with sign(0) == 0.
            raise ValueError("balanceConstraint: variable indexes are 1-based; 0 is not allowed")
        A_row[abs(i)-1] = np.sign(i)
    # Append both halves only after the row is fully built, so a bad index
    # can never leave b_eq and A_eq with different lengths.
    b_eq.append(b)
    A_eq.append(A_row)
    

def lpSolve(quiet = False, *, bounds = (0, None)):
    '''
    Solve the LP built with coefficients() and the constraint functions.

    quiet  -- if True, print nothing
    bounds -- variable bounds, passed straight to scipy.optimize.linprog;
              the default (0, None) is linprog's own default (every x_i >= 0)

    Returns the scipy OptimizeResult. For a maximization problem res.fun is
    sign-corrected on success, so it holds the maximum of the original
    objective rather than the minimum of the negated one.
    '''
    # Pass None (not an empty list) for a constraint group that was never
    # used: linprog requires a 2-D A_ub/A_eq, and [] is 1-D, so the former
    # "no constraints at all" branch (bounds only) failed.
    has_ub = len(b_ub) > 0
    has_eq = len(b_eq) > 0
    res = linprog(
        c,
        A_ub=A_ub if has_ub else None,
        b_ub=b_ub if has_ub else None,
        A_eq=A_eq if has_eq else None,
        b_eq=b_eq if has_eq else None,
        bounds=bounds,
    )
    if not quiet:
        print("Message from LP solver:",res.message)
    if res.success:
        if not isMin:
            # The solver minimized -f(x); flip the sign to report max f(x).
            res.fun *= -1
        z = res.fun
        if not quiet:
            print("\nLP-solution\nZ =","{:8.2f}".format(z))
            for i, xi in enumerate(res.x, start=1):
                print("x_",i," =","{:7.2f}".format(xi),sep='')
    return res
'''

The get*() functions give read access to the internal problem data, for callers that need to process the objective or the constraint equations themselves.

'''
def getc():
    '''Return the stored objective coefficients (negated for a maximization problem).'''
    coefficients_row = c
    return coefficients_row

def getA_ub():
    '''Return the list of "<=" constraint rows (the live list, not a copy).'''
    rows = A_ub
    return rows

def getA_eq():
    '''Return the list of equality-constraint rows (the live list, not a copy).'''
    rows = A_eq
    return rows

def getb_ub():
    '''Return the right-hand sides of the "<=" constraints (the live list).'''
    rhs = b_ub
    return rhs

def getb_eq():
    '''Return the right-hand sides of the equality constraints (the live list).'''
    rhs = b_eq
    return rhs

def isFeasible(x, *, tol=1e-9): # Check if a solution is feasible...
    '''
    Return True if x satisfies every stored equality and "<=" constraint.

    x   -- candidate solution, one value per variable
    tol -- absolute tolerance for each comparison. An exact comparison
           rejects LP solutions that are off only by floating-point
           rounding (e.g. 0.1 + 0.2 != 0.3). With integer data, such as
           0/1 tuples, any tol < 1 gives exactly the same answers as an
           exact test.

    Only constraints added through the constraint functions are checked;
    variable bounds (such as x >= 0) are not.
    '''
    for A_row, b in zip(A_eq, b_eq):
        if abs(np.dot(x, A_row) - b) > tol:
            return False
    for A_row, b in zip(A_ub, b_ub):
        if np.dot(x, A_row) - b > tol:
            return False
    return True

def Z(x):
    '''
    Evaluate the stored objective row at x, i.e. the dot product of x and c.

    Note: for a maximization problem coefficients() stores c negated, so this
    returns the negated objective value (the quantity linprog minimizes),
    not the sign-corrected value that lpSolve() reports.
    '''
    value = np.dot(x, c)
    return value

'''

The all_binary_combinations() function generates all 2**n possible 0/1
assignments of an n-element x-vector, as a list of tuples.

'''
def all_binary_combinations(n):
    '''
    Return every 0/1 tuple of length n, as a list.

    The order is lexicographic: (0,...,0), (0,...,1), ..., (1,...,1), i.e.
    binary counting with the first element most significant. n = 0 yields [()].
    '''
    bits = (0, 1)
    combos = itertools.product(bits, repeat=n)
    return list(combos)
    

if __name__ == "__main__":
    # Quick demo of all_binary_combinations(): print every 0/1 assignment of
    # a 3-variable x-vector. Runs only when this file is executed directly,
    # never on import (replaces the former dead "if False:" block).
    n = 3  # number of binary digits / variables
    for x in all_binary_combinations(n):
        print(x)


