#Added by AGNEL PRAVEEN JOSEPH
#Generate difference maps with amplitude scaling/matching and dust removal

import sys
import mrcfile
from TEMPy.StructureParser import PDBParser
from TEMPy.StructureBlurrer import StructureBlurrer
import os
from TEMPy.class_arg import TempyParser
import mrcfile
from ccpem_core.mrcfile_utils import MapCompare
from TEMPy.MapProcess import Filter
from copy import deepcopy
import gc
import numpy as np
#from memory_profiler import profile


def main():
    #print help
    '''
    print 'use --help for help'
    print '-m/-m1 [map] for input map, -m1,-m2 for two input maps'
    print '-p/-p1 [pdb] for input pdb'
    print '-r [resolution]; -r1,r2 for the two map resolutions'
    print '-t [density threhold]; -t1,t2 for the two map thresholds'
    print '--nodust to disable dusting of difference maps'
    print '--softmask to apply softmask for masked map input'
    print '--refscale to consider second map (or model) amplitudes as reference for scaling the first map'
    print '-sw [shellwidth] to change the default shell width (1/Angstroms) for scaling (default 0.02)'
    print '-dp [probability] to provide a probability of finding dust among the difference map densities (default (>)0.2)'
    print '--nofilt to disable lowpass filter before amplitude scaling (not recommended)'
    print '--noscale to disable amplitude scaling (not recommended)'
    '''
    
    tp = TempyParser()
    tp.generate_args()
    #COMMAND LINE OPTIONS
    #map input
    m1 = tp.args.inp_map1
    m2 = tp.args.inp_map2
    m = tp.args.inp_map
    #map resolutions
    r1 = tp.args.res1
    r2 = tp.args.res2
    r = tp.args.res
    #map contours
    c1 = tp.args.thr1
    c2 = tp.args.thr2
    c = tp.args.thr
    #atomic model input
    p = tp.args.pdb
    p1 = tp.args.pdb1
    p2 = tp.args.pdb2
    #voxel size
    apix = tp.args.apix
    #apply contour mask to calculated difference map?
    contour_mask = True
    if tp.args.nomask: 
        contour_mask = False
    #Whether to smooth unmasked region? (for hard masks)
    msk = tp.args.softmask
    #whether to scale amplitudes
    flag_scale = True
    if tp.args.noscale: 
        flag_scale = False
        print 'Warning: scaling disabled!'
    #whether to lowpass filter before scaling
    flag_filt = True
    if tp.args.nofilt:
        flag_filt = False
    #whether to use the second map (model map) as reference
    refsc=tp.args.refscale
    if not tp.args.mode is None: 
        if tp.args.mode == 2: refsc=True
    # width of resolution shell
    sw = tp.args.shellwidth
    #whether to apply dust filter after difference
    flag_dust = True
    if tp.args.nodust: flag_dust = False
    randsize = 0.2
    if flag_dust:
        randsize = tp.args.dustprob
    #print tp.args
    #EXAMPLE RUN
    flag_example = False
    if len(sys.argv) == 1:
        path_example=os.path.join(os.getcwd(),EXAMPLEDIR)
        if os.path.exists(path_example)==True:
            print "%s exists" %path_example
        flag_example = True
    
    #calculate model contour
    def model_contour(p,res=4.0,emmap=False,t=-1.):
        modelmap,modelinstance = blur_model(p,res,emmap)
        contour = None
        if t != -1.0:
            print 'calculating contour'
            contour = t*modelmap.std()#0.0
        return modelmap,contour,modelinstance
    def blur_model(p,res=4.0,emmap=False):
        print 'reading the model'
        structure_instance=PDBParser.read_PDB_file('pdbfile',p,hetatm=False,water=False)
        print 'filtering the model'
        blurrer = StructureBlurrer()
        if res is None: sys.exit('Map resolution required..')
        modelmap = blurrer.gaussian_blur_real_space(structure_instance, 
                                                  res,sigma_coeff=0.187,
                                                  densMap=emmap,normalise=True) 
        return modelmap, structure_instance    
  
    #GET INPUT DATA
    output_synthetic_map = False
    #no map input
    if all(x is None for x in [m,m1,m2]):
        # for 2 models
        if None in [p1,p2]:
            sys.exit('Input two maps or a map and model, \
            map resolution(s) (required) and contours (optional)')
        Name1 = os.path.basename(p1).split('.')[0]
        emmap1,c1,p1inst = model_contour(p1,res=4.0,emmap=False,t=0.1)
        r1 = r2 = r = 4.0
        Name2 = os.path.basename(p2).split('.')[0]
        if c2 is None: 
            emmap2,c2,p2inst = model_contour(p2,res=r,emmap=False,t=0.1)
        else: emmap2,p2inst = blur_model(p2,res=r,emmap=False)
        flag_filt = False
        flag_scale = False
    elif None in [m1,m2]:
        # for one map and model, m = map, c1 = map contour, c2 = model contour
        print 'reading map'
        if m is None and m1 is not None: 
            m = m1
        #read map
        Name1 = os.path.basename(m).split('.')[0]
        mrcobj1=mrcfile.open(m,mode='r')
        emmap1 = Filter(mrcobj1)
        #set contour
        if c1 is None and c is None: 
            c1 = emmap1.calculate_map_contour(sigma_factor=1.5)
        elif c is not None: 
            c1 = c
            
        if r1 is None and r is None: 
            sys.exit('Input two maps or a map and model, \
            map resolution(s) (required) and contours (optional)')
        elif r1 is None: r1 = r
        
        if all(x is None for x in [p,p1,p2]): 
            sys.exit('Input two maps or a map and model, \
            map resolution(s) (required) and contours (optional)')
        elif None in [p1,p2]:
            if p is None and p2 is not None: p = p2
            elif p1 is not None: p = p1  
        r2 = 3.0
        #TODO : fix a model contour
        if r1 > 20.0: mt = 2.0
        elif r1 > 10.0: mt = 1.0
        elif r1 > 6.0: mt = 0.5
        else: mt = 0.1
        if c2 is None:
            Name2 = os.path.basename(p).split('.')[0]
            emmap2,c2,p2inst = model_contour(p,res=r1,emmap=emmap1,t=mt)
        #scale based on the model amplitudes
        refsc = True   
        output_synthetic_map = True
    else: 
        # For 2 input maps
        if None in [r1,r2]: 
            sys.exit('Input two maps, \
            their resolutions(required) and contours(optional)')
        print 'reading map1'
        Name1 = os.path.basename(m1).split('.')[0]
        mrcobj1 = mrcfile.open(m1,mode='r')
        emmap1 = Filter(mrcobj1)
        if c1 is None:
            c1 = emmap1.calculate_map_contour(sigma_factor=1.5)
        print 'reading map2' 
        Name2 = os.path.basename(m2).split('.')[0]
        mrcobj2 = mrcfile.open(m2,mode='r')
        emmap2 = Filter(mrcobj2)
        if c2 is None:
            c2 = emmap2.calculate_map_contour(sigma_factor=1.5)
            
    #MAIN CALCULATION
    gc.collect()
    #check if map objects are from mrcfile
    if emmap1.__class__.__name__ == 'Map':
        emmap1 = Filter(emmap1)
    if emmap2.__class__.__name__ == 'Map':
        emmap2 = Filter(emmap2)
    #whether to shift density to positive values
    '''
    c1 = (c1 - emmap1.min())
    c2 = (c2-emmap2.min())
    emmap1.fullMap = (emmap1.fullMap - emmap1.min())
    emmap2.fullMap = (emmap2.fullMap - emmap2.min())
    '''
    #emmap1._crop_box(c1,2)
    #emmap2._crop_box(c2,2)    
    #TODO: implement the soft masking with mapprocess
    '''
    #if a soft mask has to be applied to both maps
    if msk:
        print 'Applying soft mask'
        emmap1.fullMap = emmap1.soft_mask(c1)
        emmap2.fullMap = emmap2.soft_mask(c2)
    '''
    #scaled maps
    if flag_scale:
        print 'scaling'
        if refsc: print 'Using second model/map amplitudes as reference'
        # amplitude scaling independant of the grid
        emmap_1,emmap_2, dict_plot = MapCompare.amplitude_match(
            emmap1,emmap2,sw,max(r1,r2),lpfiltb=flag_filt,lpfilta=False,ref=refsc)
    else:
        emmap_1 = emmap1.copy(deep=False)
        emmap_2 = emmap2.copy(deep=False)
    gc.collect()
    #find a common box to hold both maps
    try:
        MapCompare.compare_grid(emmap1,emmap2)
        flag_samegrid = True
    except AssertionError: flag_samegrid = False
    #if grid dimensions are different
    if not flag_samegrid:
        spacing = max(max(emmap1.apix),max(emmap2.apix))
        grid_shape, new_ori = MapCompare.alignment_box(emmap1,emmap2,spacing)
    
            #resample scaled maps to the common grid
        #     if apix is None: spacing = max(r1,r2)*0.33
        #     else: spacing = apix
        #interpolate to common grids
        diff1 = emmap_1.interpolate_to_grid(grid_shape,spacing,new_ori,inplace=False)
        diff2 = emmap_2.interpolate_to_grid(grid_shape,spacing,new_ori,inplace=False)
    else:
        diff1 = emmap_1.copy()
        diff2 = emmap_2.copy()
    
    # get mask inside contour for the initial maps
    #modify in place
    emmap_1.fullMap[:] = (emmap1.fullMap>c1)*1.0
    emmap_2.fullMap[:] = (emmap2.fullMap>c2)*1.0
    if not flag_samegrid:
        mask1 = emmap_1.interpolate_to_grid(grid_shape,spacing,new_ori,inplace=False)
        mask2 = emmap_2.interpolate_to_grid(grid_shape,spacing,new_ori,inplace=False)
        mask1.threshold_map(0.1,inplace=True)
        mask2.threshold_map(0.1,inplace=True)
    else:
        mask1 = emmap_1.copy(deep=False)
        mask2 = emmap_2.copy(deep=False)
    del emmap_1, emmap_2
    
    ##min of minimums in the two scaled maps
    min1 = diff1.min()
    min2 = diff2.min()
    min_scaled_maps = min(min1,min2)
    #shift to positive values
    if (min1 <= 0. or min2 <= 0.):
        #make values non-zero
        min_scaled_maps = min_scaled_maps - 0.01*min_scaled_maps
        diff1.shift_density(-min_scaled_maps,inplace=True)
        diff2.shift_density(-min_scaled_maps,inplace=True)
    print 'calculating difference'

    # find difference map and apply contour mask
    scaledmap1 = diff1.copy()
    diff1.fullMap = (diff1.fullMap - diff2.fullMap)
    
    if flag_dust:
        scaledmap2 = diff2.copy()
        
    diff2.fullMap = (diff2.fullMap - scaledmap1.fullMap)
    
    if contour_mask:
        diff1.apply_mask(mask1.fullMap,inplace=True)
        diff2.apply_mask(mask2.fullMap,inplace=True)
    #dust filter
    if flag_dust:
        mask1.fullMap[:] = diff1.fullMap/scaledmap1.fullMap > 0.35
        diff1.apply_mask(mask1.fullMap,inplace=True)
        print 'dusting'
        diff1.label_patches(0.0,
                            prob=randsize,inplace=True)
        mask2.fullMap[:] = diff2.fullMap/scaledmap2.fullMap > 0.35
        diff2.apply_mask(mask2.fullMap,inplace=True)
        diff2.label_patches(0.0,
                            prob=randsize,inplace=True)
        diff1.apply_mask(diff1.fullMap>0.0,inplace=True)
        diff2.apply_mask(diff2.fullMap>0.0,inplace=True)
    del mask1, mask2, scaledmap1, scaledmap2
    gc.collect()
    if not flag_samegrid:
        #interpolate back to original grids
        diff1_inigrid = diff1.interpolate_to_grid(emmap1.fullMap.shape,emmap1.apix,
                                  emmap1.origin,inplace=False)
        diff2_inigrid = diff2.interpolate_to_grid(emmap2.fullMap.shape,emmap2.apix,
                                  emmap2.origin,inplace=False)
    else:
        diff1_inigrid = diff1.copy(deep=False)
        diff2_inigrid = diff2.copy(deep=False)
    
    mapfile_diff1 = Name1+'-'+Name2+'_diff.mrc'
    mapfile_diff2 = Name2+'-'+Name1+'_diff.mrc'
    diffmap1 = mrcfile.new(mapfile_diff1,overwrite=True)
    diffmap2 = mrcfile.new(mapfile_diff2,overwrite=True)
    diff1_inigrid.set_newmap_data_header(diffmap1)
    diff2_inigrid.set_newmap_data_header(diffmap2)
    diffmap1.close()
    diffmap2.close()

    # If PDB given write out synthetic map
    if output_synthetic_map:
        print 'Output synthetic map from : ', Name2
        synmap_file = Name2+'_syn.mrc'
        synmap = mrcfile.new(synmap_file,overwrite=True)
        emmap2.set_newmap_data_header(synmap)
        synmap.close()


if __name__ == "__main__":
    # Run the difference-map pipeline only when executed as a script.
    main()
