#=============================================================================
# Monolithic FE^2
# Nils Lange, Geralf Huetter, Bjoern Kiefer
#   Nils.Lange@imfd.tu-freiberg.de, Geralf.Huetter@imfd.tu-freiberg.de, 
#   Bjoern.Kiefer@imfd.tu-freiberg.de
# distributed under CC BY-NC-SA 4.0 license
# (https://creativecommons.org/licenses/by-nc-sa/4.0/)
# Reference: 
#   N. Lange, G. Huetter, B. Kiefer: "An efficient monolithic solution scheme
#                                     for FE2 problems",
#   DOI: https://doi.org/10.1016/j.cma.2021.113886
#   N. Lange, G. Huetter, B. Kiefer: "A monolithic hyper ROM FE2 method with
#                                     clustered training at finite deformations"
#   DOI: https://doi.org/10.1016/j.cma.2023.116522
#
# Further information on the implementation, structure of the source code,
# examples and tutorials can be found in the file doc/documentation.pdf
#
#=============================================================================

from datetime import datetime
import multiprocessing as mp
import gc #for freeing memory
import sys
from scipy.optimize import nnls #scientific python nonnegative least square algorithm
import numpy as np #import numpy for numerical tools
import os
import ctypes as ct
import platform

def write_results_to_inputfile(path_to_inputfile,Method,displacement_basis=None,set_ELEM=None,element_mult_factor=None):
    """Write the evaluated training results into the input file.
    
    Everything from the first line starting with '*END_OF_FILE' (case
    insensitive) on is dropped, the result block of the chosen Method is
    appended and the file is terminated with '*End_of_File' again.
    
    Parameters:
        path_to_inputfile   -- path of the input file that is rewritten
        Method              -- 'ROM' writes the modes, 'hyperROM' writes the active
                               elements and their multiplication factors; for any
                               other value only the end of file line is rewritten
        displacement_basis  -- 2D array with one mode per column (Method 'ROM')
        set_ELEM            -- dict with the keys 1..len(set_ELEM); a negative
                               value -p marks the element with that label as the
                               p-th active element (Method 'hyperROM')
        element_mult_factor -- array with one row per active element
                               (Method 'hyperROM')
    
    Raises:
        Exception -- if the input file cannot be read
    """
    
    print ('Start writing results to input file.')
    
    #only a failing file access is reported as missing input file, any other
    #error (e.g. inconsistent results) keeps its own type and message
    try:
        with open(path_to_inputfile,'r') as f:
            lines=f.readlines()
    except (IOError,OSError):
        raise Exception('Inputfile not found.')
    
    #keep everything in front of the end of file line
    content=[]
    for l in lines:
        if l[:12].upper()=='*END_OF_FILE':
            break
        content.append(l)
    
    if Method=='ROM':
        #write modes: header with the number of rows, then one line per mode
        n_rows=np.shape(displacement_basis)[0]
        n_modes=np.shape(displacement_basis)[1]
        content.append('*ROM_Modes, N='+str(n_rows)+'\n')
        for i in range(n_modes):
            content.append(', '.join([str(displacement_basis[j,i]) for j in range(n_rows)])+'\n')
    
    elif Method=='hyperROM':
        #labels of the active elements, ordered by their position in the hyper element set
        active_elements=[0]*len(element_mult_factor)
        for k in range(1,len(set_ELEM)+1):
            if set_ELEM[k]<0:
                active_elements[abs(set_ELEM[k])-1]=k
        content.append('*Active_Elements,N='+str(len(active_elements))+'\n')
        content.append(', '.join([str(k) for k in active_elements])+'\n')
        #the respective element multiplication factors; convert to plain floats,
        #because the repr of NumPy scalars is not a bare number in NumPy>=2
        factors=np.reshape(element_mult_factor,(np.shape(element_mult_factor)[0]))
        content.append(', '.join([repr(float(w)) for w in factors])+'\n')
    
    content.append('*End_of_File')
    
    #the file is only rewritten after the complete content could be assembled,
    #so a failure above leaves the original input file untouched
    with open(path_to_inputfile,'w') as f:
        f.writelines(content)

def get_u_snaps(n_snaps,ndof,ncpus):
    """Return the left singular vectors of the displacement snapshots.
    
    The FORTRAN routine get_u_snapshots fills one buffer for the elastic and one
    for the inelastic snapshots and reports how many columns of each buffer
    actually hold singular vectors. These columns are returned side by side
    (elastic ones first) in a single Fortran ordered array with ndof rows.
    """
    
    double_ptr=ct.POINTER(ct.c_double)
    n_sets=len(n_snaps)
    
    #buffers handed over to the FORTRAN routine
    buffer_el=np.zeros((ndof,n_sets),order='F',dtype=ct.c_double)
    buffer_inel=np.zeros((ndof,np.sum(n_snaps,dtype=int)-n_sets),order='F',dtype=ct.c_double)
    count_el=ct.c_int(0) #receives the actual number of elastic singular vectors
    count_inel=ct.c_int(0) #receives the actual number of inelastic singular vectors
    
    AuxilaryRoutines.get_u_snapshots(
        buffer_el.ctypes.data_as(double_ptr),ct.c_int(buffer_el.shape[0]),ct.c_int(buffer_el.shape[1]),
        buffer_inel.ctypes.data_as(double_ptr),ct.c_int(buffer_inel.shape[1]),
        n_snaps.ctypes.data_as(ct.POINTER(ct.c_int)),ct.c_int(ncpus),
        ct.byref(count_el),ct.byref(count_inel))
    
    #collect only the columns that were reported as filled
    n_el=count_el.value
    n_inel=count_inel.value
    modes=np.zeros((ndof,n_el+n_inel),order='F',dtype=ct.c_double)
    modes[:,:n_el]=buffer_el[:,:n_el]
    if n_inel>0:
        modes[:,n_el:n_el+n_inel]=buffer_inel[:,:n_inel]
    
    return modes

def get_f_snaps(n_snaps,NELEM_full,ncpus):
    """Return the left singular vectors of the force snapshots.
    
    The FORTRAN routine get_f_snapshots works on a buffer with NELEM_full rows
    and one column per snapshot and reports the number of singular vectors it
    left in there; only these leading columns are returned.
    """
    
    n_total=np.sum(n_snaps,dtype=int)
    snapshots=np.zeros((NELEM_full,n_total),dtype=ct.c_double,order='F')
    n_vectors=ct.c_int(0) #receives the actual number of singular vectors
    n_rows,n_cols=np.shape(snapshots)
    
    AuxilaryRoutines.get_f_snapshots(
        snapshots.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(n_rows),ct.c_int(n_cols),
        n_snaps.ctypes.data_as(ct.POINTER(ct.c_int)),ct.c_int(np.shape(n_snaps)[0]),
        ct.byref(n_vectors),ct.c_int(ncpus))
    
    return snapshots[:,:n_vectors.value]

def empirical_hyper_element_integration_method(J,b,NELEM,ncpus):
    """Get active elements and multiplication factors by the empirical hyper
    element integration method (EHEIM).
    
    Parameters:
        J     -- 2D array, one column per available element
        b     -- right hand side, 2D column array with as many rows as J
        NELEM -- number of hyper elements to be found
        ncpus -- handed over to the FORTRAN routines
    
    Returns:
        set_ELEM             -- dict: element label (starting at 1) -> status;
                                a negative value -p marks the p-th hyper element,
                                a positive value is the running number of the
                                element among the remaining candidates, 0 marks
                                a currently excluded element
        element_mult_factors -- column array with one factor per hyper element,
                                ordered by the position p
    
    Raises:
        Exception -- if NELEM is smaller than one or if the iteration limit
                     (10*NELEM) is reached before NELEM hyper elements are found
    """
    
    #without this check the loop below is skipped and no result would exist
    if NELEM<1:
        raise Exception('Number of hyper elements must be greater equal one.')
    
    NRHS=np.shape(J)[0] #corresponds to maximal number of choosable hyper elements
    NELEM_full=np.shape(J)[1] #number of all available elements
    
    B=np.array([i for i in b[:,0]]) #because nnls solver needs different format
    r=np.copy(b) #initialize the residual
    
    #normalized J -> J_tilde; only the first NRHS-1 rows of J take part, columns
    #with a (nearly) vanishing norm stay zero
    J_tilde=np.zeros((NRHS-1,NELEM_full),order='F',dtype=ct.c_double)
    J_norm=np.zeros(NELEM_full,dtype=ct.c_double) 
    for i in range(NELEM_full):
        J_norm[i]=np.linalg.norm(J[:NRHS-1,i])
        if J_norm[i]>0.0000001:
            J_tilde[:,i]=J[:NRHS-1,i]/J_norm[i]
        else:
            J_norm[i]=0.0
    
    #Allocate memory for scalar product between normalized J and residual r
    correlation_J_tilde_r=np.zeros((NELEM_full),order='F',dtype=ct.c_double) #in candidate numbering
    correlation_J_tilde_r_full=np.zeros((NELEM_full),order='F',dtype=ct.c_double) #in element numbering
    
    #Allocate memory for J_z (columns of J that belong to the hyper elements)
    J_z=np.asfortranarray(np.zeros((NRHS,NELEM),order='F',dtype=ct.c_double))
    
    iter_max=10*NELEM #maximum number of iterations
    iteration=0
    
    excluded_elements=[] #exclude some elements from the algorithm if they lead to negative multiplication factors and produce an infinite loop
    
    m=0 #current number of hyper elements
    beginning=True
    #start with the element of largest column norm as first hyper element, all
    #other elements are numbered consecutively as candidates
    set_ELEM={}
    i=np.argmax(J_norm)
    K=1
    for k in range(NELEM_full):
        if k==i:
            set_ELEM[k+1]=-1
        else:
            set_ELEM[k+1]=K
            K=K+1
    
    print('-------- EHEIM algorithm --------')
    
    while (m<NELEM and iteration<iter_max): #loop until all hyper elements and their corresponding multiplication factors are found
        
        #look if points must be excluded and accordingly update set_ELEM
        if len(excluded_elements)>0:
            set_excluded_elements=list(set(excluded_elements))
            if len(set_excluded_elements)>0:
                for k in set_excluded_elements:
                    set_ELEM[k]=0
                #renumber the remaining candidates consecutively
                K=1
                for k in range(1,NELEM_full+1):
                    if set_ELEM[k]>0:
                        set_ELEM[k]=K
                        K=K+1
        
        if beginning: #at the beginning (1 element) just solve the least square problem
            beginning=False
        else: #find the next hyper element
            #NOTE(review): r has NRHS rows while J_tilde has NRHS-1 rows; presumably
            #the routine only reads the leading NRHS-1 entries of r - confirm in
            #the FORTRAN source
            AuxilaryRoutines.multiplication(J_tilde.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(np.shape(J_tilde)[0]),ct.c_int(np.shape(J_tilde)[1]),
                                            r.ctypes.data_as(ct.POINTER(ct.c_double)),correlation_J_tilde_r_full.ctypes.data_as(ct.POINTER(ct.c_double)),
                                            ct.c_int(ncpus))
            #gather the values of the candidates in candidate numbering
            for k in range(1,NELEM_full+1):
                if set_ELEM[k]>0:
                    correlation_J_tilde_r[set_ELEM[k]-1]=correlation_J_tilde_r_full[k-1]
            
            i=np.argmax(correlation_J_tilde_r[:NELEM_full-m])
            #move the element to the hyper element set, remove it from the set with all other elements
            K=1
            for k in range(1,NELEM_full+1):
                if set_ELEM[k]>0:
                    if set_ELEM[k]==i+1:
                        set_ELEM[k]=-m-1
                    else:
                        set_ELEM[k]=K
                        K=K+1
        #Build up J_z
        for k in range(1,NELEM_full+1):
            if set_ELEM[k]<0:
                J_z[:,abs(set_ELEM[k])-1]=J[:,k-1]
        
        #unrestricted least square solving
        element_mult_factors=np.zeros((m+1,1),order='F',dtype=ct.c_double)
        AuxilaryRoutines.least_squares(J_z.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(np.shape(J_z)[0]),ct.c_int(m+1),b.ctypes.data_as(ct.POINTER(ct.c_double)),
                                       element_mult_factors.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(ncpus))
        
        if np.any(element_mult_factors<0.0): #check whether any multiplication factors are smaller than zero
            print('Nonnegative Problem')
            #restricted non negative least square solving
            [Element_mult_factors,rnorm]=nnls(J_z[:,0:m+1],B)
            for o in range(m+1): #put element_mult_factors in same format as above
                element_mult_factors[o,0]=Element_mult_factors[o]
            if np.any(element_mult_factors==0.0):
                #hyper elements with a vanishing factor are removed from the hyper
                #element set and remembered as excluded, the remaining hyper
                #elements and all candidates are renumbered
                o=-1
                K=1
                Element_mult_factors=np.copy(element_mult_factors)
                element_mult_factors=np.zeros((np.count_nonzero(element_mult_factors),1),dtype=ct.c_double)
                for k in range(1,NELEM_full+1):
                    if set_ELEM[k]<0:
                        if Element_mult_factors[abs(set_ELEM[k])-1,0]==0.0:
                            set_ELEM[k]=K
                            K=K+1
                            excluded_elements.append(k)
                        else:
                            element_mult_factors[abs(o)-1,0]=Element_mult_factors[abs(set_ELEM[k])-1,0]
                            set_ELEM[k]=o
                            o=o-1
                    else:
                        set_ELEM[k]=K
                        K=K+1
                #Build up J_z again
                for k in range(1,NELEM_full+1):
                    if set_ELEM[k]<0:
                        J_z[:,abs(set_ELEM[k])-1]=J[:,k-1]
        else:
            #all factors are admissible: formerly excluded elements (status 0)
            #become candidates again
            if len(excluded_elements)>0:
                K=1
                for k in range(1,NELEM_full+1):
                    if set_ELEM[k]>=0:
                        set_ELEM[k]=K
                        K=K+1
            excluded_elements=[]
        #compute current residual
        m=len(element_mult_factors)
        r=b-np.matmul(J_z[:,:m],element_mult_factors)
        print('iteration '+str(iteration)+', residual: '+str(np.linalg.norm(r)))
        iteration=iteration+1
    
    if (iteration==iter_max and m<NELEM):
        #report the requested number of hyper elements
        raise Exception('Could not find '+str(NELEM)+' hyper elements in '+str(iteration)+' iterations!')
    
    return set_ELEM,element_mult_factors
    
def get_nbr_snaps(kind):
    """Count the snapshots stored in the training data files.
    
    The files '<prefix>1.txt', '<prefix>2.txt', ... in the current working
    directory are read consecutively until the first file that cannot be
    opened; the prefix is 'training_data-u-' for kind 'displacement' and
    'training_data-f-' for kind 'force'. Every line counts as one snapshot.
    
    Returns:
        n_snaps -- integer array with the number of lines of each file
        ndof    -- number of comma separated entries of the last line read
    
    Raises:
        Exception -- for an unknown kind, if no file is found or if the files
                     found do not contain any line
    """
    
    filename={'displacement':'training_data-u-','force':'training_data-f-'}
    
    if kind not in filename:
        raise Exception('Unknown kind of training data: '+str(kind))
    
    n_snaps=[]
    last_line=None
    while(True):
        #only a file that cannot be opened ends the search, errors while
        #reading an existing file are not hidden
        try:
            f=open(filename[kind]+str(len(n_snaps)+1)+'.txt')
        except (IOError,OSError):
            break
        with f:
            n_lines=0 #get the number of snaps in the file
            for line in f:
                n_lines=n_lines+1
                last_line=line
        n_snaps.append(n_lines)
    
    if len(n_snaps)==0:
        raise Exception('No training data found!')
    if last_line is None:
        raise Exception('Training data files are empty!')
    
    n_snaps=np.array(n_snaps,dtype=ct.c_int,order='F')
    ndof=len(last_line.split(','))
    
    return n_snaps,ndof
    
def evaluate_training_data(path_to_inputfile,n_modes,NELEM,ncpus,Method='ROM'):
    """Evaluate the training data and write the results to the input file.
    
    The compiled AuxilaryRoutines library next to this file is loaded, the
    working directory is changed to the directory of the input file (where the
    training data files are read from) and, depending on Method, either the
    displacement basis ('ROM') or the hyper elements with their multiplication
    factors ('hyperROM') are determined and written to the input file.
    
    Parameters:
        path_to_inputfile -- path of the input file
        n_modes           -- number of modes of the displacement basis ('ROM')
        NELEM             -- number of hyper elements ('hyperROM')
        ncpus             -- number of cpus handed over to the FORTRAN routines
        Method            -- 'ROM' or 'hyperROM'
    
    Raises:
        Exception -- for an unsupported operating system, a missing or not
                     loadable library, inadmissible arguments or missing data
    """
    
    #the other functions of this module access the library through the module
    #level name AuxilaryRoutines, hence it has to be bound globally here
    global AuxilaryRoutines
    
    #load some fortran routines for evaluating the training data
    if platform.system()=='Linux':
        ending='so'
    elif platform.system()=='Windows':
        ending='dll'
    else:
        raise Exception('Unknown operating system.')
    path_to_AuxilaryRoutines=os.path.join(os.path.dirname(os.path.abspath(__file__)),'AuxilaryRoutines.'+ending)
    if not os.path.exists(path_to_AuxilaryRoutines):
        raise Exception('The compiled version of AuxilaryRoutines must be present in the abaqus_plugins/MonolithFE2 folder. Please compile the source first.')
    try:
        AuxilaryRoutines = ct.CDLL(path_to_AuxilaryRoutines)
    except OSError:
        raise Exception('Problems occured when trying to load the AuxilaryRoutines library. Possibly something went wrong during compilation.')
    
    if (ncpus>mp.cpu_count()):
        raise Exception('Not enough processes available')
    elif (ncpus<1):
        raise Exception('Number of cpus must be greater equal one.')
    
    print ('Start Evaluating the training data.')
    print (datetime.now())
    
    #change to the directory with the training data; resolve the path first, so
    #that a relative path still points to the input file after the change
    path_to_inputfile=os.path.abspath(path_to_inputfile)
    os.chdir(os.path.dirname(path_to_inputfile))
    
    if Method=='ROM':
        
        #get number of the training sets available
        n_snaps,ndof=get_nbr_snaps('displacement')
        
        #get elastic and inelastic snapshots using a FORTRAN routine
        LSV_u_Snaps=get_u_snaps(n_snaps,ndof,ncpus)
        
        if n_modes<1:
            raise Exception('Number of inelastic modes must be greater equal one!')
        elif n_modes>np.shape(LSV_u_Snaps)[1]:
            raise Exception('Number of modes larger than number of availabe left singular vectors('+str(np.shape(LSV_u_Snaps)[1])+')!')
        else:
            #get the displacement basis as the set of the first n_modes dominant left singular vectors
            displacement_basis=LSV_u_Snaps[:,:n_modes]
        
        #finally write the displacement basis to the inputfile
        write_results_to_inputfile(path_to_inputfile,Method,displacement_basis=displacement_basis)
        
    elif Method=='hyperROM':
        
        if NELEM<1:
            raise Exception('Number of hyper elements must be greater equal one!')
        
        #get number of the training sets available
        n_snaps,NELEM_full=get_nbr_snaps('force')
        
        #get force snapshots using a FORTRAN routine
        LSV_F_Snaps=get_f_snaps(n_snaps,NELEM_full,ncpus)
                
        print('Element selection')
        
        if NELEM>np.shape(LSV_F_Snaps)[1]+1:
            raise Exception('number of chosen hyper Elements ('+str(NELEM)+') higher than available Elements.('+str(np.shape(LSV_F_Snaps)[1]+1)+')')
        
        #read reference element volumes; only a failing file access is reported
        #as missing file, a malformed content keeps its own error message
        try:
            with open('Element_Volume.txt','r') as f:
                volume_line=f.readlines()[0]
        except (IOError,OSError):
            raise Exception('File with reference element volume not found.')
        Element_ref_Volume=np.array([float(w) for w in volume_line.split(',')],dtype=ct.c_double,order='F')
        
        #the first NELEM-1 rows of J are the dominant left singular vectors of
        #the force snapshots, the last row holds the reference element volumes
        J=np.zeros((NELEM,np.shape(LSV_F_Snaps)[0]),dtype=ct.c_double,order='F')
        J[:NELEM-1,:]=np.transpose(LSV_F_Snaps[:,:NELEM-1])
        J[-1:,:]=Element_ref_Volume
        b=np.matmul(J,np.ones((NELEM_full,1))) #right hand side in EHEIM
        
        #finally call the EHEIM function
        set_ELEM,element_mult_factor=empirical_hyper_element_integration_method(J,b,NELEM,ncpus)
        
        #write the results to the MonolithFE2 Inputfile
        write_results_to_inputfile(path_to_inputfile,Method,set_ELEM=set_ELEM,element_mult_factor=element_mult_factor)
        
    else:
        raise Exception('Evaluation method must be ROM or hyperROM.')
    
    print ('Finished.')
    print (datetime.now())
    
if __name__=='__main__':
    
    #if the script is started directly from terminal, call evaluate_training_data with the
    #supplied arguments, each of them given in the form keyword=value
    
    keyword_dictionary={"path_to_inputfile":None,"n_modes":0,"NELEM":0,"ncpus":1,"Method":'ROM'}
    
    for argument in sys.argv[1:]:
        #split at the first "=" only, since the value (e.g. a path) may contain further ones
        keyword,separator,value=argument.partition("=")
        if not(keyword in keyword_dictionary):
            raise Exception("Keyword "+keyword+" misspelled.")
        elif separator=="":
            raise Exception("Argument "+argument+" must be given in the form keyword=value.")
        else:
            keyword_dictionary[keyword]=value
    
    if keyword_dictionary["path_to_inputfile"] is None:
        raise Exception("Keyword path_to_inputfile must be supplied.")
    
    #finally call the actual evaluation
    evaluate_training_data(keyword_dictionary["path_to_inputfile"],int(keyword_dictionary["n_modes"]),
                           int(keyword_dictionary["NELEM"]),int(keyword_dictionary["ncpus"]),
                           keyword_dictionary["Method"])
