#
#     Copyright (C) 2016 CCP-EM
#
#     This code is distributed under the terms and conditions of the
#     CCP-EM Program Suite Licence Agreement as a CCP-EM Application.
#     A copy of the CCP-EM licence can be obtained by writing to the
#     CCP-EM Secretary, RAL Laboratory, Harwell, OX11 0FA, UK.
#


import os
from ccpem_core.ccpem_utils import ccpem_argparser
from ccpem_core import ccpem_utils
from ccpem_core import process_manager
from ccpem_core.tasks import task_utils
from ccpem_core.tasks.tempy.difference_map import difference_map


class DifferenceMap(task_utils.CCPEMTask):
    '''
    CCPEM TEMPy difference map wrapper.

    Declares the task metadata and command-line arguments for the TEMPy
    difference map task, and runs it as a single-process CCPEM pipeline
    (see run_pipeline).
    '''
    task_info = task_utils.CCPEMTaskInfo(
        name='TEMPyDifferenceMap',
        author='A. Joseph, M. Topf, M. Winn',
        version='1.1',
        description=(
            '''Difference map calculation using TEMPy and LocScale libraries. The maps
            are matched by amplitude scaling in resolution shells before
            calculating difference.<br><br>
            Input two maps or a map and an atomic model.<br>
            - The maps have to be aligned before using as input
            for this task. If an atomic model is provided,
            it has to be fitted in the experimental map.<br><br>
            - Local or Global mode can be used for scaling. For local mode,
            scaling is done in local sliding windows and hence local resolution
            variations are considered for scaling. Local mode has an MPI option
            for faster runs.<br>
            - A fractional difference map is calculated by default (difference/initial).
            This can be disabled. A difference fraction threshold can be used to
            mask the difference map as well.<br>
            - Optionally, a dust filter can be applied to the difference maps.
            This is done to remove small isolated densities that usually
            correspond to noise in the differences. A fractional difference threshold
            is used to mask first before dusting. The threshold can be set.<br>
            <br>
            '''),
        short_description=(
            'Difference map calculation'),
        documentation_link='',
        references=None)

    # Command used to launch the difference map script with CCP-EM's python.
    commands = {'ccpem-python':
        ['ccpem-python', os.path.realpath(difference_map.__file__)]}

    def __init__(self,
                 database_path=None,
                 args=None,
                 args_json=None,
                 pipeline=None,
                 job_location=None,
                 parent=None):
        # All arguments are forwarded unchanged to the CCPEMTask base class.
        super(DifferenceMap, self).__init__(
             database_path=database_path,
             args=args,
             args_json=args_json,
             pipeline=pipeline,
             job_location=job_location,
             parent=parent)

    def parser(self):
        '''
        Build and return the task's argument parser.

        Adds the generic job options (title, location) and the two input
        maps with their resolutions, then the difference-map specific
        options via add_diffmap_specific_arguments.
        '''
        parser = ccpem_argparser.ccpemArgParser()
        #
        job_title = parser.add_argument_group()
        job_title.add_argument(
            '-job_title',
            '--job_title',
            help='Short description of job',
            metavar='Job title',
            type=str,
            default=None)
        #
        job_location = parser.add_argument_group()
        job_location.add_argument(
            '-job_location',
            '--job_location',
            help='Directory to run job',
            metavar='Job location',
            type=str,
            default=None)
        #
        map_path_1 = parser.add_argument_group()
        map_path_1.add_argument(
            '-map_path_1',
            '--map_path_1',
            help='Input map 1 (mrc format)',
            metavar='Input map 1',
            type=str,
            default=None)
        #
        map_resolution_1 = parser.add_argument_group()
        map_resolution_1.add_argument(
            '-map_resolution_1',
            '--map_resolution_1',
            help='''Resolution of map 1 (Angstrom)''',
            metavar='Resolution map 1',
            type=float,
            default=None)
        #
        map_path_2 = parser.add_argument_group()
        map_path_2.add_argument(
            '-map_path_2',
            '--map_path_2',
            help='Input map 2 (mrc format)',
            metavar='Input map 2',
            type=str,
            default=None)
        #
        map_resolution_2 = parser.add_argument_group()
        map_resolution_2.add_argument(
            '-map_resolution_2',
            '--map_resolution_2',
            help='''Resolution of map 2 (Angstrom)''',
            metavar='Resolution map 2',
            type=float,
            default=None)

        self.add_diffmap_specific_arguments(parser)
        return parser

    def add_diffmap_specific_arguments(self, parser):
        '''
        Add the difference-map specific options to ``parser``.

        Covers model input, map/model selection, scaling mode, MPI,
        masking, thresholding, dust filtering and output options.
        '''
        # NOTE(review): with standard argparse, type=bool converts any
        # non-empty string (including 'False') to True. Confirm that
        # ccpem_argparser handles the boolean options below differently.
        #
        pdb_path = parser.add_argument_group()
        pdb_path.add_argument(
            '-pdb_path',
            '--pdb_path',
            help=('Synthetic map will be created and difference maps '
                  'calculated using this'),
            metavar='Input PDB',
            type=str,
            default=None)
        #
        map_or_pdb_selection = parser.add_argument_group()
        map_or_pdb_selection.add_argument(
            '-map_or_pdb_selection',
            '--map_or_pdb_selection',
            help=('Select map or PDB'),
            metavar='Input selection',
            choices=['Map', 'Model'],
            type=str,
            default='Map')
        # Local or global scaling
        mode_selection = parser.add_argument_group()
        mode_selection.add_argument(
            '-mode_selection',
            '--mode_selection',
            help=('Select local or global scaling mode'),
            metavar='Mode selection',
            choices=['global', 'local', 'None'],
            type=str,
            default='global')
        parser.add_argument(
            '-noscale',
            '--noscale',
            help=('Do not scale for difference'),
            # Label matches the option's meaning: True disables scaling.
            metavar='Do not scale the maps?',
            type=bool,
            default=False)
        parser.add_argument(
            '-use_mpi',
            '--use_mpi',
            help=('Use MPI for local scaling'),
            metavar='Use MPI?',
            type=bool,
            default=False)
        # NOTE(review): n_mpi is parsed but run_pipeline does not pass it
        # on to TEMPyDifferenceMap.
        parser.add_argument(
            '-n_mpi',
            '--n_mpi',
            help=('Use n proc for mpi'),
            metavar='Use n proc',
            type=int,
            default=1)
        parser.add_argument(
            '-maskfile',
            '--maskfile',
            help=('Input mask map. Local scaling is calculated within this mask'),
            metavar='Input mask map',
            type=str,
            default=None)
        parser.add_argument(
            '-w',
            '--window_size',
            help=('Window size for local scaling (default 7*resolution)'),
            metavar='Window size',
            type=int,
            default=None)
        #
        parser.add_argument(
            '-refscale',
            '--refscale',
            help='''Use second map as reference for scaling''',
            metavar='Use map2 as reference?',
            type=bool,
            default=False)
        #
        parser.add_argument(
            '-threshold_fraction',
            '--threshold_fraction',
            help='''Threshold difference map by fractional difference (difference/initial).'''
                    ''' Set this for dusting (0.35 by default for dusting)''',
            metavar='Cutoff: fractional difference',
            type=float,
            default=None)
        #
        dust_filter = parser.add_argument_group()
        dust_filter.add_argument(
            '-dust_filter',
            '--dust_filter',
            help='''Remove dust beyond a threshold of fractional difference (default: 0.35)''',
            metavar='Dust filter',
            type=bool,
            default=False)
        #
        dust_filter.add_argument(
            '-dustprob',
            '--dustprob',
            help='''Probability of dust size (sizes divided in 20 bins) greater than?'''
                ''' Dusts (small sized densities) are usually more probable than useful differences''',
            metavar='Dust size probability',
            type=float,
            default=0.1
            )
        #
        frac_map = parser.add_argument_group()
        frac_map.add_argument(
            '-save_fracmap',
            '--save_fracmap',
            help='''Calculate fractional difference maps''',
            metavar='Calc fractional difference',
            type=bool,
            default=True)
        #
        map_alignment = parser.add_argument_group()
        map_alignment.add_argument(
            '-map_alignment',
            '--map_alignment',
            help='''Map alignment (~15mins) (optional)''',
            metavar='Align maps',
            type=bool,
            default=False)
        # Soft mask option is currently not enabled:
#         soft_mask = parser.add_argument_group()
#         soft_mask.add_argument(
#             '-soft_mask',
#             '--soft_mask',
#             help='''Apply soft mask to input maps (optional)''',
#             metavar='Apply soft mask',
#             type=bool,
#             default=False)

    def run_pipeline(self, job_id=None, db_inject=None):
        '''
        Create the TEMPy difference map process from the parsed arguments
        and start it in a single-step CCPEM pipeline.

        If 'Model' is selected and a PDB path is given, map 2 is cleared so
        that the model is the second input.
        '''
        use_model = (self.args.map_or_pdb_selection.value == 'Model' and
                     self.args.pdb_path.value is not None)
        if use_model:
            self.args.map_path_2.value = None

        # Generate process
        # NOTE(review): map_alignment (and the disabled soft_mask option)
        # are not passed on; TEMPyDifferenceMap does not forward them.
        self.tempy_difference_map = TEMPyDifferenceMap(
            job_location=self.job_location,
            command=self.commands['ccpem-python'],
            map_path_1=self.args.map_path_1.value,
            map_or_pdb=self.args.map_or_pdb_selection.value,
            map_path_2=self.args.map_path_2.value,
            map_resolution_1=self.args.map_resolution_1.value,
            map_resolution_2=self.args.map_resolution_2.value,
            pdb_path=self.args.pdb_path.value,
            mode=self.args.mode_selection.value,
            noscale=self.args.noscale.value,
            maskfile=self.args.maskfile.value,
            mpi=self.args.use_mpi.value,
            window_size=self.args.window_size.value,
            threshold_fraction=self.args.threshold_fraction.value,
            dust_filter=self.args.dust_filter.value,
            fracmap=self.args.save_fracmap.value,
            dust_prob=self.args.dustprob.value,
            ref_scale=self.args.refscale.value)

        # The pipeline consists of a single stage containing one process.
        pl = [[self.tempy_difference_map.process]]
        self.pipeline = process_manager.CCPEMPipeline(
            pipeline=pl,
            job_id=job_id,
            args_path=self.args.jsonfile,
            location=self.job_location,
            database_path=self.database_path,
            db_inject=db_inject,
            taskname=self.task_info.name,
            title=self.args.job_title.value)
        self.pipeline.start()

class TEMPyDifferenceMap(object):
    '''
    Wrapper for TEMPy difference map process.

    Builds the command-line argument list for the difference map script
    and wraps it in a process_manager.CCPEMProcess stored as
    ``self.process``.

    Map 1 is compared either against a second map (map_or_pdb='Map',
    requires map_path_2) or against an atomic model (map_or_pdb='Model',
    requires pdb_path).

    Raises:
        ValueError: if command is None, or if the input required by
            map_or_pdb (pdb_path or map_path_2) is None.

    NOTE(review): maskfile, mpi, contour_mask, map_alignment,
    map_contour_1 and map_contour_2 are accepted but are not currently
    forwarded to the script.
    '''
    def __init__(self,
                 job_location,
                 command,
                 map_path_1,
                 map_resolution_1,
                 map_path_2=None,
                 map_resolution_2=None,
                 pdb_path=None,
                 map_or_pdb='Map',
                 mode='global',
                 noscale=False,
                 maskfile=None,
                 mpi=False,
                 window_size=None,
                 threshold_fraction=None,
                 dust_filter=False,
                 fracmap=True,
                 contour_mask=False,
                 map_alignment=False,
                 map_contour_1=None,
                 map_contour_2=None,
                 dust_prob=None,
                 ref_scale = False,
                 name=None):
        # Validate inputs up front with explicit exceptions (asserts are
        # stripped under python -O). The selected second input must be
        # present, not just either one of pdb_path / map_path_2.
        if command is None:
            raise ValueError('TEMPyDifferenceMap: command must not be None')
        use_model = map_or_pdb == 'Model'
        if use_model and pdb_path is None:
            raise ValueError(
                "TEMPyDifferenceMap: pdb_path is required when "
                "map_or_pdb is 'Model'")
        if not use_model and map_path_2 is None:
            raise ValueError(
                "TEMPyDifferenceMap: map_path_2 is required when "
                "map_or_pdb is not 'Model'")

        self.job_location = ccpem_utils.get_path_abs(job_location)
        self.name = name
        if self.name is None:
            self.name = self.__class__.__name__
        self.map_path_1 = ccpem_utils.get_path_abs(map_path_1)
        # Both attributes always exist; only the selected one is set.
        self.pdb_path = None
        self.map_path_2 = None
        # Set args
        self.args = ['-m1', self.map_path_1,
                     '-r1', map_resolution_1]
        # Difference map 1 vs map 2 or PDB
        if use_model:
            self.pdb_path = ccpem_utils.get_path_abs(pdb_path)
            self.args += ['-p', self.pdb_path]
        else:
            self.map_path_2 = ccpem_utils.get_path_abs(map_path_2)
            self.args += ['-m2', self.map_path_2]
            self.args += ['-r2', map_resolution_2]
        # Scaling mode; any other value (e.g. 'None') adds no -mode flag.
        if mode in ['global', 'local']:
            self.args += ['-mode', mode]
        if noscale:
            self.args += ['--noscale']
        if ref_scale:
            self.args += ['--refscale']
        #
        if dust_filter:
            self.args += ['-dust']
        if not fracmap:
            self.args += ['--nofracmap']
        if threshold_fraction is not None:
            self.args += ['-tf', threshold_fraction]
        if window_size is not None:
            self.args += ['-w', window_size]
        if dust_prob is not None:
            self.args += ['-dp', dust_prob]
        # TODO (Agnel): map alignment is not implemented yet; when it is,
        # forward map_alignment here, e.g.:
        # if map_alignment:
        #     self.args += ['-XXX', '???']

        # Set process
        self.process = process_manager.CCPEMProcess(
            name=self.name,
            command=command,
            args=self.args,
            location=self.job_location,
            stdin=None)
