#=============================================================================
# Monolithic FE^2
# Nils Lange, Geralf Huetter, Bjoern Kiefer
#   Nils.Lange@imfd.tu-freiberg.de, Geralf.Huetter@imfd.tu-freiberg.de, 
#   Bjoern.Kiefer@imfd.tu-freiberg.de
# distributed under CC BY-NC-SA 4.0 license
# (https://creativecommons.org/licenses/by-nc-sa/4.0/)
# Reference: 
#   N. Lange, G. Huetter, B. Kiefer: "An efficient monolithic solution scheme for FE2 problems",
#   https://arxiv.org/abs/2101.01802
#
# Further information on the implementation, structure of the source code,
# examples and tutorials can be found in the file doc/documentation.pdf
#
#=============================================================================

from datetime import datetime
import multiprocessing as mp
import gc #for freeing memory
import sys
from scipy.optimize import nnls #scientific python nonnegative least square algorithm
import numpy as np #import numpy for numerical tools
import os
import ctypes as ct
import platform

#load some fortran routines for evaluating the training data: the compiled shared library
#AuxilaryRoutines.so (Linux) resp. AuxilaryRoutines.dll (Windows) is expected next to this file
if platform.system()=='Linux':
    ending='so'
elif platform.system()=='Windows':
    ending='dll'
else:
    raise Exception('Unknown operating system.')
path_to_AuxilaryRoutines=os.path.join(os.path.dirname(os.path.abspath(__file__)),'AuxilaryRoutines.'+ending)
if not os.path.exists(path_to_AuxilaryRoutines):
    raise Exception('The compiled version of AuxilaryRoutines must be present in the abaqus_plugins/MonolithFE2 folder. Please compile the source first.')
try:
    AuxilaryRoutines = ct.CDLL(path_to_AuxilaryRoutines)
except OSError as err: #keep the loader's own message (e.g. a missing dependency) as the cause
    raise Exception('Problems occured when trying to load the AuxilaryRoutines library. Possibly something went wrong during compilation.') from err

def write_results_to_inputfile(path_to_inputfile,Method,displacement_basis=None,set_IP=None,weights=None):
    """Append the results of the training data evaluation to a MonolithFE2 input file.
    
    Method='ROM' appends the keyword line '*ROM_Modes, N=<number of modes>'
    followed by one line per mode (column of displacement_basis) whose
    entries are separated by ', '.
    
    Method='hyperROM' appends '*Integration_Points, N=<len(set_IP)>' followed
    by one line holding, for every original integration point k=1..len(set_IP),
    abs(set_IP[k]) if set_IP[k] is negative and 0 otherwise; then
    '*IP_Weights, N=<len(weights)>' followed by one line with weights[:,0].
    
    The complete text is formatted before the file is opened, so invalid
    arguments neither leave a half-written section in the input file nor get
    misreported as a missing file.
    
    Raises
    ------
    Exception
        'Evaluation method must be ROM or hyperROM.' for any other Method;
        'Inputfile not found.' if the file cannot be opened for appending
        (mode 'a' creates a missing file, so this happens e.g. when its
        directory does not exist). The OSError is chained as the cause.
    """
    
    print ('Start writing results to input file.')
    
    if Method=='ROM':
        #one line per mode (column of the basis), entries separated by ', '
        lines=['*ROM_Modes, N='+str(np.shape(displacement_basis)[1])+'\n']
        lines+=[', '.join(str(value) for value in mode)+'\n' for mode in np.transpose(displacement_basis)]
    elif Method=='hyperROM':
        #for every original integration point: its position in the selected set (set_IP<0) or 0
        active_points=[str(abs(set_IP[k])) if set_IP[k]<0 else '0' for k in range(1,len(set_IP)+1)]
        lines=['*Integration_Points, N='+str(len(set_IP))+'\n',
               ', '.join(active_points)+'\n',
               '*IP_Weights, N='+str(len(weights))+'\n',
               ', '.join(str(weights[w,0]) for w in range(len(weights)))+'\n']
    else:
        raise Exception('Evaluation method must be ROM or hyperROM.')
    
    try:
        with open(path_to_inputfile,'a') as f:
            f.writelines(lines)
    except OSError as err:
        raise Exception('Inputfile not found.') from err

def get_u_snaps(n_snaps,ndof,ncpus):
    """Return the left singular vectors of the displacement snapshots.
    
    The snapshots are read and decomposed by the FORTRAN routine
    get_u_snapshots of AuxilaryRoutines. It returns the elastic resp.
    inelastic left singular vectors directly in two work buffers and reports
    how many of each are valid.
    
    Parameters
    ----------
    n_snaps : 1D ndarray of ctypes.c_int
        Number of snapshots per training set (as returned by get_nbr_snaps).
    ndof : int
        Length of one displacement snapshot.
    ncpus : int
        Number of CPUs handed on to the FORTRAN routine.
    
    Returns
    -------
    ndarray of shape (ndof, n_el+n_inel), Fortran order
        The valid elastic left singular vectors followed by the valid
        inelastic ones.
    """
    n_sets=len(n_snaps)
    #work buffers: one column per training set for the elastic part, the remaining snapshots for
    #the inelastic part (presumably one elastic snapshot per set - confirm with the FORTRAN source)
    elastic=np.zeros((ndof,n_sets),order='F',dtype=ct.c_double)
    inelastic=np.zeros((ndof,np.sum(n_snaps,dtype=int)-n_sets),order='F',dtype=ct.c_double)
    count_el=ct.c_int(0) #set by FORTRAN: number of valid elastic singular vectors
    count_inel=ct.c_int(0) #set by FORTRAN: number of valid inelastic singular vectors
    
    double_ptr=ct.POINTER(ct.c_double)
    rows,cols_el=elastic.shape
    cols_inel=inelastic.shape[1]
    AuxilaryRoutines.get_u_snapshots(elastic.ctypes.data_as(double_ptr),ct.c_int(rows),ct.c_int(cols_el),
                                     inelastic.ctypes.data_as(double_ptr),ct.c_int(cols_inel),
                                     n_snaps.ctypes.data_as(ct.POINTER(ct.c_int)),ct.c_int(ncpus),
                                     ct.byref(count_el),ct.byref(count_inel))
    
    n_el=count_el.value
    n_inel=count_inel.value
    basis=np.zeros((ndof,n_el+n_inel),order='F',dtype=ct.c_double)
    basis[:,:n_el]=elastic[:,:n_el]
    if n_inel>0:
        basis[:,n_el:]=inelastic[:,:n_inel]
    return basis

def get_f_snaps(n_snaps,nIPs,ncpus):
    """Return the left singular vectors of the force snapshots.
    
    The FORTRAN routine get_f_snapshots of AuxilaryRoutines reads the force
    snapshots of the original integration points into a work buffer (one
    column per snapshot). It returns the left singular vectors directly in
    that buffer and reports how many are valid; only these leading columns
    are returned.
    
    Parameters
    ----------
    n_snaps : 1D ndarray of ctypes.c_int
        Number of snapshots per training set (as returned by get_nbr_snaps).
    nIPs : int
        Number of original integration points (length of one snapshot).
    ncpus : int
        Number of CPUs handed on to the FORTRAN routine.
    """
    n_total=np.sum(n_snaps,dtype=int)
    work=np.zeros((nIPs,n_total),dtype=ct.c_double,order='F')
    n_valid=ct.c_int(0) #set by FORTRAN: number of left singular vectors stored in work
    
    n_rows,n_cols=work.shape
    double_ptr=ct.POINTER(ct.c_double)
    AuxilaryRoutines.get_f_snapshots(work.ctypes.data_as(double_ptr),ct.c_int(n_rows),ct.c_int(n_cols),
                                     n_snaps.ctypes.data_as(ct.POINTER(ct.c_int)),ct.c_int(len(n_snaps)),
                                     ct.byref(n_valid),ct.c_int(ncpus))
    
    return work[:,:n_valid.value]

def empirical_cubature_method(J,b,NGP,ncpus):
    """Select hyper integration points and their weights with the ECM (Hernandez et al.).
    
    Greedy loop: starting from the point with the largest column norm, the
    candidate whose normalized column of J correlates best with the current
    residual is added, and the weights of all selected points are obtained
    from an unrestricted least-squares fit J_z*weights=b. If any weight is
    negative, a nonnegative least-squares fit (scipy nnls) is used instead.
    Points that then receive a zero weight are dropped again and excluded
    from the candidates until an unrestricted fit is nonnegative again.
    
    Parameters
    ----------
    J : 2D array (NRHS x NGP_total)
        Basis matrix; its columns belong to the available integration points.
        The last row is left out of the normalization/correlation step
        (J_tilde). In evaluate_training_data the other rows are the transposed
        force modes and the last row is constant.
    b : 2D array (NRHS x 1)
        Right hand side (in evaluate_training_data: J times the original
        integration point weights).
    NGP : int
        Number of hyper integration points to find.
    ncpus : int
        Number of CPUs handed on to the FORTRAN routines.
    
    Returns
    -------
    set_IP : dict
        Maps every original integration point number k (1-based) to an int:
        a negative value -p means k is the p-th selected point (column p-1 of
        J_z, row p-1 of weights); a positive value is its index among the
        remaining candidates; 0 marks a point excluded after a zero weight.
    weights : 2D array (m x 1)
        Weights of the selected points.
    
    Raises
    ------
    Exception
        If fewer than NGP points are found within 10*NGP iterations.
    """
    
    NRHS=np.shape(J)[0] #number of rows of J; corresponds to the maximal number of choosable hyper integration points
    NGP_total=np.shape(J)[1] #number of all available integration points
    
    B=np.array([i for i in b[:,0]]) #1D copy of b, because the nnls solver needs a 1D right hand side
    r=np.copy(b) #initialize the residual
    
    #normalize the columns of J without its last row -> J_tilde; columns with (almost) zero
    #norm stay zero in J_tilde and get J_norm=0
    J_tilde=np.zeros((NRHS-1,NGP_total),order='F',dtype=ct.c_double) #normalized J -> J_tilde
    J_norm=np.zeros(NGP_total,dtype=ct.c_double) 
    for i in range(NGP_total):
        J_norm[i]=np.linalg.norm(J[:NRHS-1,i])
        if J_norm[i]>0.0000001:
            J_tilde[:,i]=J[:NRHS-1,i]/J_norm[i]
        else:
            J_norm[i]=0.0
    
    #Allocate memory for the scalar product between normalized J and residual r:
    #_full has one entry per original point, the other one is compacted (entry K-1 <-> candidate with set_IP==K)
    correlation_J_tilde_r=np.zeros((NGP_total),order='F',dtype=ct.c_double)
    correlation_J_tilde_r_full=np.zeros((NGP_total),order='F',dtype=ct.c_double)
    
    #Allocate memory for J_z (column p-1 holds the column of J of the p-th selected point)
    J_z=np.asfortranarray(np.zeros((NRHS,NGP),order='F',dtype=ct.c_double))
    
    iter_max=10*NGP #maximum number of iterations
    iteration=0
    
    excluded_points=[] #points that got a zero weight in the nonnegative fit; excluded to avoid an infinite loop
    
    m=0 #current number of hyperintegration points
    beginning=True
    set_IP={}
    i=np.argmax(J_norm) #the first point is the one with the largest column norm
    K=1
    #number the points: the first point gets -1 (selected), all others the candidate indices 1,2,...
    for k in range(NGP_total):
        if k==i:
            set_IP[k+1]=-1
        else:
            set_IP[k+1]=K
            K=K+1
    
    print('-------- ECM algorithm --------')
    
    while (m<NGP and iteration<iter_max): #loop until NGP points with weights are found (or iter_max is reached)
        
        #look if points must be excluded (set_IP=0) and renumber the remaining candidates consecutively
        if len(excluded_points)>0:
            set_excluded_points=list(set(excluded_points))
            if len(set_excluded_points)>0:
                for k in set_excluded_points:
                    set_IP[k]=0
                K=1
                for k in range(1,NGP_total+1):
                    if set_IP[k]>0:
                        set_IP[k]=K
                        K=K+1
        
        if beginning: #at the beginning the first point is already chosen, just solve the least square problem
            beginning=False
        else: #find the next hyperintegration point
            #correlation of every normalized column with the residual; presumably J_tilde^T*r over the
            #first NRHS-1 entries of r (opaque FORTRAN routine - confirm)
            AuxilaryRoutines.multiplication(J_tilde.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(np.shape(J_tilde)[0]),ct.c_int(np.shape(J_tilde)[1]),
                                            r.ctypes.data_as(ct.POINTER(ct.c_double)),correlation_J_tilde_r_full.ctypes.data_as(ct.POINTER(ct.c_double)),
                                            ct.c_int(ncpus))
            #gather the correlations of the candidates into the compacted vector
            for k in range(1,NGP_total+1):
                if set_IP[k]>0:
                    correlation_J_tilde_r[set_IP[k]-1]=correlation_J_tilde_r_full[k-1]
            
            #NOTE(review): only the entries of the current candidates were refreshed above; when points are
            #excluded there are fewer than NGP_total-m candidates, so this slice also covers entries left over
            #from earlier iterations - confirm that a stale entry can never win the argmax
            i=np.argmax(correlation_J_tilde_r[:NGP_total-m])
            #move the point to the hyper integration set (set_IP=-(m+1)), renumber the remaining candidates
            K=1
            for k in range(1,NGP_total+1):
                if set_IP[k]>0:
                    if set_IP[k]==i+1:
                        set_IP[k]=-m-1
                    else:
                        set_IP[k]=K
                        K=K+1
        #Build up J_z from the columns of the selected points
        for k in range(1,NGP_total+1):
            if set_IP[k]<0:
                J_z[:,abs(set_IP[k])-1]=J[:,k-1]
        #unrestricted least square solving for the weights of the m+1 selected points
        
        weights=np.zeros((m+1,1),order='F',dtype=ct.c_double)
        AuxilaryRoutines.least_squares(J_z.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(np.shape(J_z)[0]),ct.c_int(m+1),b.ctypes.data_as(ct.POINTER(ct.c_double)),
                                       weights.ctypes.data_as(ct.POINTER(ct.c_double)),ct.c_int(ncpus))
        
        if np.any(weights<0.0): #check whether any weights are smaller than zero
            print('Nonnegative Problem')
            #restricted non negative least square solving
            [Weights,rnorm]=nnls(J_z[:,0:m+1],B)
            for o in range(m+1): #put weights in same format as above
                weights[o,0]=Weights[o]
            if np.any(weights==0.0): #drop the points with zero weight and remember them for exclusion
                o=-1
                K=1
                Weights=np.copy(weights)
                weights=np.zeros((np.count_nonzero(weights),1),dtype=ct.c_double)
                #kept points are renumbered -1,-2,... in the order of the original point numbers k
                for k in range(1,NGP_total+1):
                    if set_IP[k]<0:
                        if Weights[abs(set_IP[k])-1,0]==0.0: #zero weight: back to the candidates, excluded next iteration
                            set_IP[k]=K
                            K=K+1
                            excluded_points.append(k)
                        else:
                            weights[abs(o)-1,0]=Weights[abs(set_IP[k])-1,0]
                            set_IP[k]=o
                            o=o-1
                    else:
                        set_IP[k]=K
                        K=K+1
                #Build up J_z again for the renumbered selected points
                for k in range(1,NGP_total+1):
                    if set_IP[k]<0:
                        J_z[:,abs(set_IP[k])-1]=J[:,k-1]
        else: #unrestricted solution is nonnegative: lift the exclusion, excluded points become candidates again
            if len(excluded_points)>0:
                K=1
                for k in range(1,NGP_total+1):
                    if set_IP[k]>=0:
                        set_IP[k]=K
                        K=K+1
            excluded_points=[]
        #compute current residual
        m=len(weights) #number of points currently carrying a weight
        r=b-np.matmul(J_z[:,:m],weights)
        print('iteration '+str(iteration)+', residual: '+str(np.linalg.norm(r)))
        iteration=iteration+1
    
    if (iteration==iter_max and m<NGP):
        raise Exception('Could not find '+str(NGP)+' hyper integration points in '+str(iteration)+' iterations!')
    
    return set_IP,weights
    
def get_nbr_snaps(kind):
    """Count the stored snapshots of every training set in the working directory.
    
    The training sets are the files <prefix>1.txt, <prefix>2.txt, ... with
    prefix 'training_data-u-' (kind='displacement') or 'training_data-f-'
    (kind='force'). The search stops at the first missing file number, and
    every line of a file counts as one snapshot.
    
    Returns
    -------
    n_snaps : 1D ndarray of ctypes.c_int
        Number of lines (snapshots) in each training data file.
    ndof : int
        Number of comma separated entries in the last line read (last line of
        the last nonempty file), i.e. the length of one snapshot.
    
    Raises
    ------
    Exception
        If kind is unknown, no training data file exists, or all files are empty.
    """
    
    filename={'displacement':'training_data-u-','force':'training_data-f-'}
    if kind not in filename:
        raise Exception('Unknown kind of training data: '+str(kind))
    
    n_snaps=[]
    last_line=None
    while True:
        try:
            f=open(filename[kind]+str(len(n_snaps)+1)+'.txt')
        except FileNotFoundError: #the training sets are numbered consecutively; the first gap ends the search
            break
        with f:
            count=0
            for line in f:
                count=count+1
                last_line=line
        n_snaps.append(count)
    
    if len(n_snaps)==0:
        raise Exception('No training data found!')
    if last_line is None:
        raise Exception('Training data files contain no snapshots!')
    
    ndof=len(last_line.split(','))
    
    return np.array(n_snaps,dtype=ct.c_int,order='F'),ndof
    
def _read_matrix_from_txt(filename):
    #read a matrix stored row-wise as comma separated values (one row per line);
    #raises OSError if the file cannot be opened and ValueError if it holds no valid matrix
    with open(filename,'r') as f:
        rows=[[float(entry) for entry in line.split(',')] for line in f]
    matrix=np.array(rows,order='F',dtype=ct.c_double)
    if matrix.ndim!=2:
        raise ValueError('File '+filename+' does not contain a matrix.')
    return matrix

def _write_matrix_to_txt(filename,matrix):
    #store a matrix row-wise as comma separated values, one row per line
    with open(filename,'w') as f:
        for row in matrix:
            f.write(','.join(str(entry) for entry in row)+'\n')

def evaluate_training_data(path_to_inputfile,n_modes,NGP,ncpus,Method='ROM',from_stored_SVD=False,store_SVD=True):
    """Evaluate the training data and append the results to the MonolithFE2 input file.
    
    The training data and stored SVD files are looked up in the directory of
    path_to_inputfile. The process working directory is changed to that
    directory and not restored.
    
    Method='ROM': compute (or, with from_stored_SVD, read 'LSV_u_Snaps.txt')
    the left singular vectors of the displacement snapshots and append the
    first n_modes of them as the displacement basis.
    Method='hyperROM': compute (or read 'LSV_F_Snaps.txt') the left singular
    vectors of the force snapshots, select NGP hyper integration points with
    the empirical cubature method ('IP_Weights.txt' holds the original
    integration point weights) and append the points and weights.
    
    Parameters
    ----------
    path_to_inputfile : str
        Input file the results are appended to (absolute or relative path).
    n_modes : int
        Number of displacement modes (Method='ROM'), must be >= 1.
    NGP : int
        Number of hyper integration points (Method='hyperROM'), must be >= 1.
    ncpus : int
        Number of CPUs handed on to the FORTRAN routines, 1 <= ncpus <= mp.cpu_count().
    Method : str
        'ROM' or 'hyperROM'.
    from_stored_SVD : bool
        Read the singular vectors from the stored text file. If that file is
        missing or unreadable, they are computed instead.
    store_SVD : bool
        Store freshly computed singular vectors in the text file.
    
    Raises
    ------
    Exception
        On invalid arguments, missing or inconsistent integration point
        weights, or errors raised by the called routines.
    """
    
    if (ncpus>mp.cpu_count()):
        raise Exception('Not enough processes available')
    elif (ncpus<1):
        raise Exception('Number of cpus must be greater equal one.')
    
    print ('Start Evaluating the training data.')
    print (datetime.now())
    
    #make the path absolute before changing the directory: a bare file name would otherwise give
    #os.chdir('') and a relative path would be resolved against the new directory when writing
    path_to_inputfile=os.path.abspath(path_to_inputfile)
    #change to the directory with the training data
    os.chdir(os.path.dirname(path_to_inputfile))
    
    if Method=='ROM':
        
        if n_modes<1: #checked before the (possibly expensive) SVD
            raise Exception('Number of inelastic modes must be greater equal one!')
        
        LSV_u_Snaps=None
        if from_stored_SVD: #read SVD of displacement snapshots from file
            print('Start reading SVD of displacement snapshots from file.')
            try:
                LSV_u_Snaps=_read_matrix_from_txt('LSV_u_Snaps.txt')
            except (OSError,ValueError):
                print('No stored SVD of displacement snapshots found.')
        
        if LSV_u_Snaps is None:
            #get number of the training sets available
            n_snaps,ndof=get_nbr_snaps('displacement')
            #get elastic and inelastic left singular vectors using a FORTRAN routine
            LSV_u_Snaps=get_u_snaps(n_snaps,ndof,ncpus)
            #save the singular value decomposition to a textfile
            if store_SVD:
                _write_matrix_to_txt('LSV_u_Snaps.txt',LSV_u_Snaps)
        
        if n_modes>np.shape(LSV_u_Snaps)[1]:
            raise Exception('Number of modes larger than number of availabe left singular vectors('+str(np.shape(LSV_u_Snaps)[1])+')!')
        
        #the displacement basis is the set of the first n_modes dominant left singular vectors
        displacement_basis=LSV_u_Snaps[:,:n_modes]
        
        #finally write the displacement basis to the inputfile
        write_results_to_inputfile(path_to_inputfile,Method,displacement_basis=displacement_basis)
        
    elif Method=='hyperROM':
        
        if NGP<1: #checked before the (possibly expensive) SVD
            raise Exception('Number of integration points must be greater equal one.')
        
        LSV_F_Snaps=None
        if from_stored_SVD: #read SVD of force snapshots from file
            print('Start reading SVD of force snapshots from file.')
            try:
                LSV_F_Snaps=_read_matrix_from_txt('LSV_F_Snaps.txt')
            except (OSError,ValueError):
                print('No stored SVD of force snapshots found.')
        
        if LSV_F_Snaps is None:
            #get number of the training sets available
            n_snaps,nIPs=get_nbr_snaps('force')
            #get left singular vectors of the force snapshots using a FORTRAN routine
            LSV_F_Snaps=get_f_snaps(n_snaps,nIPs,ncpus)
            if store_SVD:
                _write_matrix_to_txt('LSV_F_Snaps.txt',LSV_F_Snaps)
        
        print('GP selection')
        
        if NGP>np.shape(LSV_F_Snaps)[1]+1:
            raise Exception('number of chosen GPs ('+str(NGP)+') higher than available GPs.('+str(np.shape(LSV_F_Snaps)[1]+1)+')')
        
        #read integration point weights (first line, comma separated)
        try:
            with open('IP_Weights.txt','r') as f:
                IPWeights=np.array([[float(w)] for w in f.readlines()[0].split(',')],dtype=ct.c_double,order='F')
        except (OSError,IndexError,ValueError) as err:
            raise Exception('File with integration point weights not found.') from err
        
        n_IPs=np.shape(LSV_F_Snaps)[0]
        if np.shape(IPWeights)[0]!=n_IPs:
            raise Exception('Number of integration point weights ('+str(np.shape(IPWeights)[0])+') does not match the number of integration points ('+str(n_IPs)+').')
        
        #set up the basis matrix (first NGP-1 force modes plus a constant row) and rhs for the ECM algorithm
        J=np.zeros((NGP,n_IPs),dtype=ct.c_double,order='F')
        J[:NGP-1,:]=np.transpose(LSV_F_Snaps[:,:NGP-1]); J[-1:,:]=1.0/np.sqrt(np.shape(IPWeights)[0])
        b=np.matmul(J,IPWeights) #right hand side in ECM
        
        #finally call the ECM function
        set_IP,weights=empirical_cubature_method(J,b,NGP,ncpus)
        
        #write the results to the MonolithFE2 Inputfile
        write_results_to_inputfile(path_to_inputfile,Method,set_IP=set_IP,weights=weights)
        
    else:
        raise Exception('Evaluation method must be ROM or hyperROM.')
    
    print ('Finished.')
    print (datetime.now())
    
def _parse_flag(value):
    #convert a command line flag (0/1, true/false, yes/no) to bool; bool defaults pass through unchanged
    if isinstance(value,bool):
        return value
    text=value.strip().lower()
    if text in ('true','yes'):
        return True
    if text in ('false','no'):
        return False
    return bool(int(text))

def _parse_command_line(argv):
    """Turn ``keyword=value`` command line arguments into keyword arguments of evaluate_training_data.
    
    Only the first '=' separates keyword and value, so values (e.g. paths) may
    contain '='. Unspecified keywords keep their defaults.
    
    Raises
    ------
    Exception
        For an unknown keyword, a keyword without value, or a missing path_to_inputfile.
    """
    keyword_dictionary={"path_to_inputfile":None,"n_modes":0,"NGP":0,"ncpus":1,"Method":'ROM',"from_stored_SVD":False,"store_SVD":True}
    
    for argument in argv:
        key,separator,value=argument.partition("=")
        if not(key in keyword_dictionary):
            raise Exception("Keyword "+key+" misspelled.")
        if not separator:
            raise Exception("Keyword "+key+" needs a value ("+key+"=...).")
        keyword_dictionary[key]=value
    
    if keyword_dictionary["path_to_inputfile"] is None:
        raise Exception("Keyword path_to_inputfile is required.")
    
    return {"path_to_inputfile":keyword_dictionary["path_to_inputfile"],
            "n_modes":int(keyword_dictionary["n_modes"]),
            "NGP":int(keyword_dictionary["NGP"]),
            "ncpus":int(keyword_dictionary["ncpus"]),
            "Method":keyword_dictionary["Method"],
            "from_stored_SVD":_parse_flag(keyword_dictionary["from_stored_SVD"]),
            "store_SVD":_parse_flag(keyword_dictionary["store_SVD"])}

if __name__=='__main__':
    
    #if the script is started directly from terminal, call evaluate_training_data with the supplied arguments
    evaluate_training_data(**_parse_command_line(sys.argv[1:]))
