#
#     Copyright (C) 2018 CCP-EM
#
#     This code is distributed under the terms and conditions of the
#     CCP-EM Program Suite Licence Agreement as a CCP-EM Application.
#     A copy of the CCP-EM licence can be obtained by writing to the
#     CCP-EM Secretary, RAL Laboratory, Harwell, OX11 0FA, UK.
#


import os
from loc_scale import np_locscale_fft
import mrcfile
from ccpem_core.ccpem_utils import ccpem_argparser
from ccpem_core import ccpem_utils
from ccpem_core import process_manager
from ccpem_core.tasks import task_utils
from ccpem_core.tasks.refmac import refmac_task
from ccpem_core import settings
from ccpem_core.tasks.bin_wrappers import mapmask
from ccpem_core.tasks.bin_wrappers import fft


# MRC header axis code (mapc/mapr/maps value 1, 2 or 3) -> axis letter.
mrc_axis = dict(enumerate('xyz', start=1))

class LocScale(task_utils.CCPEMTask):
    '''
    CCPEM LocScale Task.

    Runs LocScale local amplitude sharpening of a target map against a
    reference map. If no reference map is given, one is calculated from the
    reference model (pdbset -> Refmac structure-factor calculation -> FFT).
    Input maps whose header axis order is not x,y,z are first reordered with
    MapMask.
    '''
    task_info = task_utils.CCPEMTaskInfo(
        name='LocScale',
        author='A. J. Jakobi, M. Wilmanns, C. Sachse',
        version='1.1',
        description=(
            'Local amplitude sharpening based target structure'),
        # NOTE(review): the implicit string concatenation below yields
        # '...target structureCoefficient' (no space) - confirm the
        # intended text.
        short_description=(
            'Local amplitude sharpening based target structure'
            'Coefficient'),
        documentation_link='https://git.embl.de/jakobi/LocScale/wikis/home',
        references=None)
    # External programs used by this task; resolved when the class body is
    # executed (i.e. at import time).
    commands = {
        'loc_scale_python': ['ccpem-python',
                             os.path.realpath(np_locscale_fft.__file__)],
        'refmac': settings.which(program='refmac5'),
        'pdbset':  settings.which(program='pdbset')}

    def __init__(self,
                 database_path=None,
                 args=None,
                 args_json=None,
                 pipeline=None,
                 job_location=None,
                 parent=None):
        super(LocScale, self).__init__(database_path=database_path,
                                       args=args,
                                       args_json=args_json,
                                       pipeline=pipeline,
                                       job_location=job_location,
                                       parent=parent)

    def parser(self):
        '''
        Return the argument parser describing the LocScale task inputs.
        '''
        parser = ccpem_argparser.ccpemArgParser()
        #
        job_title = parser.add_argument_group()
        job_title.add_argument(
            '-job_title',
            '--job_title',
            help='Short description of job',
            metavar='Job title',
            type=str,
            default=None)
        #
        job_location = parser.add_argument_group()
        job_location.add_argument(
            '-job_location',
            '--job_location',
            help='Directory to run job',
            metavar='Job location',
            type=str,
            default=None)
        #
        target_map = parser.add_argument_group()
        target_map.add_argument(
            '-target_map',
            '--target_map',
            help='Target map to be auto sharpened (mrc format)',
            metavar='Target map',
            type=str,
            default=None)
        #
        reference_model_map = parser.add_argument_group()
        reference_model_map.add_argument(
            '-reference_map',
            '--reference_map',
            help=(
                'Reference map (mrc format). If provided this will be '
                'used as reference otherwise model will be'),
            metavar='Reference map',
            type=str,
            default=None)
        #
        mask_map = parser.add_argument_group()
        mask_map.add_argument(
            '-mask_map',
            '--mask_map',
            help='Mask map (mrc format)',
            metavar='Mask map',
            type=str,
            default=None)
        #
        parser.add_argument(
            '-resolution',
            '--resolution',
            help='Resolution of target map (Angstrom)',
            metavar='Resolution',
            type=float,
            default=None)
        #
        parser.add_argument(
            '-pixel_size',
            '--pixel_size',
            help=('Pixel size in Angstrom, if not defined it will be '
                  'calculated from the header'),
            metavar='Pixel size',
            type=float,
            default=None)
        #
        parser.add_argument(
            '-reference_model',
            '--reference_model',
            help='Reference model (PDB format)',
            metavar='Reference model',
            type=str,
            default=None)
        #
        parser.add_argument(
            '-lib_in',
            '--lib_in',
            help='Ligand dictionary (cif format)',
            metavar='Ligand dictionary',
            type=str,
            default=None)
        #
        # NOTE(review): with plain argparse, type=bool turns any non-empty
        # command-line string (including 'False') into True; confirm that
        # ccpemArgParser handles boolean options specially.
        parser.add_argument(
            '-refine_bfactors',
            '--refine_bfactors',
            help='Refine reference structure B-factors using Refmac',
            metavar='Refine B-factors',
            type=bool,
            default=True)
        #
        parser.add_argument(
            '-window_size',
            '--window_size',
            help=('Window size in pixel, if not given 7 * map resolution '
                  '/ pixel size used'),
            metavar='Window size',
            type=int,
            default=None)
        #
        parser.add_argument(
            '-use_mpi',
            '--use_mpi',
            help='Use mpi for parallel processing',
            metavar='Use MPI',
            type=bool,
            default=True)
        #
        parser.add_argument(
            '-n_mpi',
            '--n_mpi',
            help='Number of mpi nodes',
            metavar='MPI nodes',
            type=int,
            default=2)
        #
        return parser

    @staticmethod
    def _read_axis_order(map_path):
        '''
        Return the (mapc, mapr, maps) header axis codes of an MRC map as ints.
        '''
        with mrcfile.open(map_path, 'r') as mrc:
            return (int(mrc.header.mapc),
                    int(mrc.header.mapr),
                    int(mrc.header.maps))

    @staticmethod
    def _xyz_size(crs_size, axes):
        '''
        Reorder header sizes from column/row/section order to X/Y/Z order.

        Args:
            crs_size: (nx, ny, nz) from the MRC header, i.e. the number of
                columns, rows and sections.
            axes: (mapc, mapr, maps) header axis codes (1=X, 2=Y, 3=Z).

        Returns:
            Tuple of sizes along X, Y, Z. If axes is not a permutation of
            (1, 2, 3) the header order is returned unchanged.
        '''
        if sorted(axes) != [1, 2, 3]:
            return tuple(crs_size)
        xyz = [None, None, None]
        for size, axis in zip(crs_size, axes):
            xyz[axis - 1] = size
        return tuple(xyz)

    def _to_xyz_order(self, map_path, pl, axes=None):
        '''
        Return a path to map_path in x,y,z axis order.

        If the header axis order is not (1, 2, 3), a MapMask process writing
        'xyz_<basename>' into the job location is appended to pl (and stored
        as self.mapmask_process); the path to that output is returned.
        Otherwise map_path is returned unchanged and pl is not modified.

        Args:
            map_path: MRC map to check.
            pl: pipeline list that new processes are appended to.
            axes: (mapc, mapr, maps) if already known; read from the file
                header when None.
        '''
        if axes is None:
            axes = self._read_axis_order(map_path)
        if tuple(axes) == (1, 2, 3):
            return map_path
        xyz_name = 'xyz_' + os.path.basename(map_path)
        self.mapmask_process = mapmask.MapMask(
            job_location=self.job_location,
            fast='X',
            medium='Y',
            slow='Z',
            mapin1=map_path,
            mapout=xyz_name)
        pl.append([self.mapmask_process.process])
        return os.path.join(self.job_location, xyz_name)

    def run_pipeline(self, job_id=None, db_inject=None):
        '''
        Build and start the LocScale pipeline.

        Processes are queued in this order:
          1. Reference map: if none was given, calculate one from the
             reference model (pdbset, Refmac sfcalc, FFT); otherwise reorder
             the given map to x,y,z if needed.
          2. Target map, reordered to x,y,z if needed.
          3. Mask map (optional), reordered to x,y,z if needed.
          4. LocScale.

        Args:
            job_id: passed through to process_manager.CCPEMPipeline.
            db_inject: passed through to process_manager.CCPEMPipeline.
        '''
        # Read the target map geometry. If no pixel size was supplied, use
        # the mean voxel size from the header.
        with mrcfile.open(self.args.target_map(), 'r') as mrc:
            map_crs_size = (int(mrc.header.nx),
                            int(mrc.header.ny),
                            int(mrc.header.nz))
            map_axes = (int(mrc.header.mapc),
                        int(mrc.header.mapr),
                        int(mrc.header.maps))
            if self.args.pixel_size() is None:
                self.args.pixel_size.value = (mrc.voxel_size.x +
                                              mrc.voxel_size.y +
                                              mrc.voxel_size.z) / 3.0

        # TODO: the refine_bfactors option is not yet used; Refmac B-factor
        # refinement of the reference model is not part of this pipeline.
        pl = []
        if self.args.reference_map() is None:
            self.reference_map = os.path.join(
                self.job_location,
                'model_reference.mrc')

            # Set the PDB cell using the target map
            self.process_pdb_set = refmac_task.PDBSetCell(
                command=self.commands['pdbset'],
                job_location=self.job_location,
                name='Set PDB cell',
                pdb_path=self.args.reference_model(),
                map_path=self.args.target_map())
            pl.append([self.process_pdb_set.process])

            self.pdb_set_path = os.path.join(
                self.job_location,
                'pdbset.pdb')

            # Calculate mtz from pdb
            self.refmac_sfcalc_crd_process = refmac_task.RefmacSfcalcCrd(
                job_location=self.job_location,
                pdb_path=self.pdb_set_path,
                lib_in=self.args.lib_in(),
                resolution=self.args.resolution())
            self.refmac_sfcalc_mtz = os.path.join(
                self.job_location,
                'sfcalc_from_crd.mtz')
            pl.append([self.refmac_sfcalc_crd_process.process])

            # Convert the calculated mtz to a map. The FFT output is written
            # in X,Y,Z order, so the grid must be given in X,Y,Z order
            # rather than the header's column/row/section order.
            map_nx, map_ny, map_nz = self._xyz_size(map_crs_size, map_axes)
            self.fft_process = fft.FFT(
                job_location=self.job_location,
                map_nx=map_nx,
                map_ny=map_ny,
                map_nz=map_nz,
                fast='X',
                medium='Y',
                slow='Z',
                mtz_path=self.refmac_sfcalc_mtz,
                map_path=self.reference_map)
            pl.append([self.fft_process.process])
        else:
            self.reference_map = self._to_xyz_order(
                self.args.reference_map(), pl)

        # Temporary fix: LocScale inputs are reordered to x,y,z axis order
        self.target_map = self._to_xyz_order(
            self.args.target_map(), pl, axes=map_axes)

        self.mask_map = self.args.mask_map()
        if self.mask_map is not None:
            self.mask_map = self._to_xyz_order(self.mask_map, pl)

        # LocScale process
        self.loc_scale_process = LocScaleWrapper(
            job_location=self.job_location,
            target_map=self.target_map,
            resolution=self.args.resolution(),
            pixel_size=self.args.pixel_size(),
            reference_model_map=self.reference_map,
            use_mpi=self.args.use_mpi(),
            n_mpi=self.args.n_mpi(),
            mask_map=self.mask_map,
            output_map=None,
            window_size=self.args.window_size(),
            name='LocScale')
        pl.append([self.loc_scale_process.process])

        # pipeline
        self.pipeline = process_manager.CCPEMPipeline(
            pipeline=pl,
            job_id=job_id,
            args_path=self.args.jsonfile,
            location=self.job_location,
            db_inject=db_inject,
            database_path=self.database_path,
            taskname=self.task_info.name,
            title=self.args.job_title.value)
        self.pipeline.start()

class LocScaleWrapper(object):
    '''
    Wrapper for LocScale process.

    Builds the np_locscale_fft command line and wraps it in a
    process_manager.CCPEMProcess exposed as ``self.process``.
    '''
    commands = {
        'loc_scale_python': ['ccpem-python',
                             os.path.realpath(np_locscale_fft.__file__)]}

    def __init__(self,
                 job_location,
                 target_map,
                 resolution,
                 pixel_size,
                 reference_model_map,
                 use_mpi=False,
                 n_mpi=None,
                 output_map=None,
                 mask_map=None,
                 window_size=None,
                 name=None):
        '''
        Args:
            job_location: directory the process runs in; the output map is
                written there.
            target_map: map to be sharpened.
            resolution: map resolution (Angstrom); only used to derive the
                default window size.
            pixel_size: pixel size (Angstrom).
            reference_model_map: reference map.
            use_mpi: run under ccpem-mpirun when True and n_mpi > 1.
            n_mpi: number of MPI processes; None (or <= 1) runs serially.
            output_map: output file name joined onto job_location
                (default 'loc_scale.mrc').
            mask_map: optional mask map.
            window_size: window size in pixels; defaults to
                int(7 * resolution / pixel_size).
            name: process name; defaults to the class name.
        '''
        self.job_location = ccpem_utils.get_path_abs(job_location)
        self.name = name if name is not None else self.__class__.__name__
        if window_size is None:
            window_size = int((7.0 * resolution) / pixel_size)
        if output_map is None:
            output_map = 'loc_scale.mrc'
        output_map = os.path.join(self.job_location,
                                  output_map)

        self.args = [
            '--em_map', os.path.abspath(target_map),
            '--model_map', os.path.abspath(reference_model_map),
            '--apix', pixel_size,
            '--window_size', window_size,
            '--outfile', output_map,
            '--verbose', 'True']

        if mask_map is not None:
            # Absolute path, like the other input maps, so a relative path
            # still resolves when the process runs in job_location.
            self.args += ['--mask', os.path.abspath(mask_map)]

        # Use mpi for multithreaded runs; n_mpi may be None, which would
        # raise TypeError on comparison under Python 3.
        if use_mpi and n_mpi is not None and n_mpi > 1:
            command = ['ccpem-mpirun',
                       '-np',
                       str(n_mpi)]
            # mpirun launches the python interpreter + script, so they
            # become the leading arguments.
            self.args = self.commands['loc_scale_python'] + self.args
            self.args += ['--mpi']
            self.process = process_manager.CCPEMProcess(
                name=self.name,
                command=command,
                args=self.args,
                location=self.job_location,
                stdin=None)
        else:
            self.process = process_manager.CCPEMProcess(
                name=self.name,
                command=self.commands['loc_scale_python'],
                args=self.args,
                location=self.job_location,
                stdin=None)
