#
#     Copyright (C) 2016 CCP-EM
#
#     This code is distributed under the terms and conditions of the
#     CCP-EM Program Suite Licence Agreement as a CCP-EM Application.
#     A copy of the CCP-EM licence can be obtained by writing to the
#     CCP-EM Secretary, RAL Laboratory, Harwell, OX11 0FA, UK.
#
import os
from ccpem_core.ccpem_utils import ccpem_argparser
from ccpem_core.settings import which
from ccpem_core import process_manager
from ccpem_core.tasks import task_utils
from ccpem_core.tasks.refmac import refmac_task
from ccpem_core.tasks.nautilus import nautilus_results


class Buccaneer(task_utils.CCPEMTask):
    '''
    CCPEM Buccaneer task.

    Assembles the chain of external processes (map to MTZ conversion,
    SFTools, freerflag, then repeated cbuccaneer / optional libg / refmac
    cycles) and hands it to a CCPEMPipeline for execution.
    '''
    task_info = task_utils.CCPEMTaskInfo(
        name='Buccaneer',
        author='Cowtan K',
        version='1.6.5',
        description=(
            'Buccaneer performs statistical chain tracing by identifying '
            'connected alpha-carbon positions using a likelihood-based density '
            'target.\n'
            'The target distributions are generated by a simulation '
            'calculation using a known reference structure for which '
            'calculated phases are available. The success of the method is '
            'dependent on the features of the reference structure matching '
            'those of the unsolved, work structure. For almost all cases, a '
            'single reference structure can be used, with modifications '
            'automatically applied to the reference structure to match its '
            'features to the work structure.  N.B. requires CCP4.'),
        short_description=(
            'Automated model building.  Requires CCP4'),
        documentation_link='http://www.ccp4.ac.uk/html/cbuccaneer.html',
        references=None)

    # Executables used by the pipeline, resolved once when the class is
    # defined.
    commands = {'cbuccaneer': which('cbuccaneer'),
                'refmac': which('refmac5'),
                'sftools': which('sftools'),
                'freerflag': which('freerflag'),
                'libg': which('libg')
                }

    def __init__(self,
                 database_path=None,
                 args=None,
                 args_json=None,
                 pipeline=None,
                 job_location=None,
                 verbose=False,
                 parent=None):
        '''
        Set up the task; all arguments except verbose are forwarded
        unchanged to the CCPEMTask constructor.

        verbose is stored on the instance and later passed to the
        CCPEMPipeline created in run_pipeline.
        '''
        super(Buccaneer, self).__init__(
            database_path=database_path,
            args=args,
            args_json=args_json,
            pipeline=pipeline,
            job_location=job_location,
            parent=parent)
        # run_pipeline reads self.verbose, so the caller's choice must be
        # recorded here rather than silently dropped.
        # NOTE(review): if CCPEMTask.__init__ accepts a verbose argument,
        # forwarding it in the call above would be the tidier fix - confirm.
        self.verbose = verbose

    def parser(self):
        '''
        Build and return the argument parser holding every option of the
        Buccaneer task together with its default value.
        '''
        # NOTE(review): several options use type=bool.  With the standard
        # argparse module any non-empty string (including 'False') converts
        # to True - confirm that ccpemArgParser handles these as intended.
        parser = ccpem_argparser.ccpemArgParser()
        #
        parser.add_argument(
            '-job_title',
            '--job_title',
            help='Short description of job',
            metavar='Job title',
            type=str,
            default=None)
        #
        parser.add_argument(
            '-mapin',
            '--input_map',
            help='''Target input map (mrc format)''',
            type=str,
            metavar='Input map',
            default=None)
        #
        parser.add_argument(
            '-resolution',
            '--resolution',
            help='''Resolution of input map (Angstrom)''',
            metavar='Resolution',
            type=float,
            default=None)
        #
        parser.add_argument(
            '-input_seq',
            '--input_seq',
            help=('Input sequence file in any common format '
                  '(e.g. pir, fasta)'),
            type=str,
            default=None)
        #
        parser.add_argument(
            '-extend_pdb',
            '--extend_pdb',
            help='Initial PDB model to extend (pdb format)',
            type=str,
            default=None)
        #
        parser.add_argument(
            '-ncycle',
            '--ncycle',
            help='Number of Buccaneer pipeline cycles',
            metavar='Build cycles',
            type=int,
            default=5)
        #
        parser.add_argument(
            '-ncycle_refmac',
            '--ncycle_refmac',
            help='Number of refmac cycles',
            metavar='Refine cycles',
            type=int,
            default=20)
        #
        parser.add_argument(
            '-ncycle_buc1st',
            '--ncycle_buc1st',
            help='Number of Buccaneer cycles in 1st pipeline cycle',
            metavar='1st Buccaneer cycles',
            type=int,
            default=5)
        #
        parser.add_argument(
            '-ncycle_bucnth',
            '--ncycle_bucnth',
            help='Number of Buccaneer cycles in subsequent pipeline cycle',
            metavar='N-th Buccaneer cycles',
            type=int,
            default=5)
        #
        parser.add_argument(
            '-map_sharpen',
            '--map_sharpen',
            metavar='Sharpen / blur',
            help=('B-factor to apply to map. Negative B-factor to '
                  'sharpen, positive to blur, zero to leave map as input'),
            type=float,
            default=0)
        #
        parser.add_argument(
            '-ncpus',
            '--ncpus',
            help='Number of cpus for Buccaneer',
            metavar='CPU',
            type=int,
            default=1)
        #
        parser.add_argument(
            '-lib_in',
            '--lib_in',
            help='Library dictionary file for ligands (lib/cif format)',
            metavar='Ligand dictionary',
            type=str,
            default=None)
        # added
        parser.add_argument(
            '-local_refinement_on',
            '--local_refinement_on',
            help='Refine around radius of supplied molecule only',
            metavar='Local refine',
            type=bool,
            default=False)
        #
        parser.add_argument(
            '-mask_radius',
            '--mask_radius',
            help='Distance around molecule the map should be cut.',
            type=float,
            default=3.0)
        #
        parser.add_argument(
            '-nonprotein_radius_on',
            '--nonprotein_radius_on',
            help=('Check to turn on nonprotein-radius option '
                  'to preserve nonprotein parts of the extend model input.'),
            metavar='Nonprotein-radius',
            type=bool,
            default=False)
        #
        parser.add_argument(
            '-nonprotein_radius',
            '--nonprotein_radius',
            help=('Set radius value for nonprotein-radius. '
                  'Default = 2.0 Angstrom'),
            metavar='Nonprotein mask radius',
            type=float,
            default=2)
        #
        parser.add_argument(
            '-libg',
            '--libg',
            help='Use LIBG to generate restraints for nucleic acids',
            metavar='Nucleic acids',
            type=bool,
            default=False)
        #
        parser.add_argument(
            '-libg_selection',
            '--libg_selection',
            help='Specify nucleotide ranges for libg (please refer '
                 'to libg documentation). If you ignore this field, '
                 'all nucleic acids will be selected for restraint generation',
            metavar='Nucleotide range (ignore to select all)',
            type=str,
            default='')
        #
        parser.add_argument(
            '-keywords',
            '--keywords',
            help=('Keywords for advanced options.  Select file or '
                  'text'),
            type=str,
            metavar='Keywords',
            default='')
        #
        parser.add_argument(
            '-refmac_keywords',
            '--refmac_keywords',
            help=('Refmac keywords for advanced options. Select file or '
                  'define text'),
            type=str,
            metavar='Refmac Keywords',
            default='')
        return parser

    def run_pipeline(self, job_id=None, run=True, db_inject=None):
        '''
        Generate job classes and process.  Run=false for reloading.

        The process list is: map to MTZ, SFTools, freerflag, then for each
        pipeline cycle a Buccaneer build, optional libg restraints, an
        optional local map to MTZ step and a refmac refinement.  When run
        is true the list is wrapped in a CCPEMPipeline and started.

        Raises ValueError if the ncycle argument is below 1, because the
        results handler needs the refinement process of the last cycle.
        '''
        ncycle = self.args.ncycle.value
        if ncycle < 1:
            raise ValueError(
                'Buccaneer needs at least one pipeline cycle '
                '(ncycle={0})'.format(ncycle))

        # Convert map to mtz (refmac)
        # set F columns label suffix
        bfactor = self.args.map_sharpen()
        if bfactor is None:
            # Unset sharpening means "leave the map as input"
            bfactor = 0
        if bfactor >= 0:
            sharp_array = None
            blur_array = [bfactor]
            if bfactor == 0:
                label_out_suffix = '0'
            else:
                label_out_suffix = 'Blur_{0:.2f}'.format(bfactor)
        else:
            sharp_array = [-1 * bfactor]
            blur_array = None
            label_out_suffix = 'Sharp_{0:.2f}'.format(-1 * bfactor)
        # Invert sharpening value for refmac; identical for every cycle so
        # computed once here
        sharp = -1.0 * bfactor

        self.process_maptomtz = refmac_task.RefmacMapToMtz(
            command=self.commands['refmac'],
            resolution=self.args.resolution.value,
            mode='Global',
            name='Map to MTZ',
            job_location=self.job_location,
            map_path=self.args.input_map.value,
            blur_array=blur_array,
            sharp_array=sharp_array)
        pl = [[self.process_maptomtz.process]]

        # Set MTZ SigF and FOM with SFTools
        self.process_sftools_buccaneer = SFToolsBuccaneer(
            command=self.commands['sftools'],
            job_location=self.job_location,
            hklin=self.process_maptomtz.hklout_path,
            name='Set MTZ SigF and SG')
        pl.append([self.process_sftools_buccaneer.process])

        # Create R free flags
        hklout = os.path.join(self.job_location,
                              'buccaneer.mtz')
        self.process_free_r_flags = FreeRFlags(
            command=self.commands['freerflag'],
            job_location=self.job_location,
            hklin=self.process_sftools_buccaneer.hklout,
            hklout=hklout,
            name='Set Rfree')
        pl.append([self.process_free_r_flags.process])

        # Save seq if sequence string is provided rather than file path.
        # The type is checked first so that an unset (None) sequence never
        # reaches os.path.exists.
        seq_input = self.args.input_seq.value
        if isinstance(seq_input, str) and not os.path.exists(seq_input):
            seq_path = os.path.join(self.job_location,
                                    'input.seq')
            with open(seq_path, 'w') as seq_file:
                seq_file.write(seq_input)
            self.args.input_seq.value = seq_path

        # Add optional refmac keywords
        refine_keywords = ''
        # make newligand noexit to prevent refmac from exiting if no ligand
        # dictionary found, very rare unless is a new ligand
        if self.args.lib_in.value:
            refine_keywords += 'MAKE NEWLigand Noexit\n'
        refine_keywords += add_refmac_keywords(self.args.refmac_keywords.value)

        # Run Buccaneer pipeline
        # Buccaneer->{1}->{2}->Refmac refine ->repeat...{1}libg, {2}maptomtz
        # local
        for i in range(1, ncycle + 1):
            # 1st pipeline cycle starts from the optional user model, later
            # cycles from the model refined in the previous cycle
            if i == 1:
                runbuc_ncycle = self.args.ncycle_buc1st.value
                pdbin_path = self.args.extend_pdb.value
            else:
                runbuc_ncycle = self.args.ncycle_bucnth.value
                pdbin_path = self.process_refine.pdbout_path

            self.process_buccaneer_pipeline = BuccaneerPipeline(
                command=self.commands['cbuccaneer'],
                job_location=self.job_location,
                hklin=self.process_free_r_flags.hklout,
                seqin=self.args.input_seq.value,
                label_out_suffix=label_out_suffix,
                ncycle=i,
                buc_cycle=runbuc_ncycle,
                resolution=self.args.resolution.value,
                pdbin=pdbin_path,
                pdbout=None,
                name='Buccaneer build {0}'.format(str(i)),
                job_title=self.args.job_title.value,
                ncpus=self.args.ncpus.value,
                nonprotein_on=self.args.nonprotein_radius_on.value,
                nonprotein_rad=self.args.nonprotein_radius.value,
                keywords=self.args.keywords.value)
            pl.append([self.process_buccaneer_pipeline.process])

            # Run libg for DNA/RNA restraints if True
            if self.args.libg.value:
                libg_kwargs = dict(
                    command=self.commands['libg'],
                    job_location=self.job_location,
                    name='DNA or RNA basepair restraints {0}'.format(
                        str(i)),
                    pdb_path=self.process_buccaneer_pipeline.pdbout)
                # Only pass a selection when the user actually typed one;
                # otherwise LibgRestraints is built without the u argument
                selection = self.args.libg_selection.value
                if selection and selection.strip():
                    libg_kwargs['u'] = selection.strip()
                self.process_libg = refmac_task.LibgRestraints(**libg_kwargs)
                pl.append([self.process_libg.process])
                libg_restraints_path = [self.process_libg.libg_restraints_path]
            else:
                libg_restraints_path = []

            # Run RefmacRefine (global or local)
            if self.args.local_refinement_on():
                self.process_maptomtz_local = refmac_task.RefmacMapToMtz(
                    command=self.commands['refmac'],
                    name='Map to MTZ (local) {0}'.format(str(i)),
                    resolution=self.args.resolution.value,
                    mode='Local',
                    job_location=self.job_location,
                    lib_path=self.args.lib_in.value,
                    map_path=self.args.input_map.value,
                    pdb_path=self.process_buccaneer_pipeline.pdbout,
                    blur_array=blur_array,
                    sharp_array=sharp_array,
                    mrad=self.args.mask_radius.value)
                pl.append([self.process_maptomtz_local.process])
                # Refinement parameters: Local
                mode = 'Local'
                name = 'Refmac refine (local) {0}'.format(str(i))
                pdb_path_a = self.process_maptomtz_local.pdbout_path
                # NOTE(review): relative file name, presumably written by
                # the local map to MTZ step into the working directory -
                # confirm against RefmacMapToMtz
                mtz_path_a = 'masked_fs.mtz'
            else:
                # Refinement parameters: Global
                mode = 'Global'
                name = 'Refmac refine (global) {0}'.format(str(i))
                pdb_path_a = self.process_buccaneer_pipeline.pdbout
                mtz_path_a = self.process_free_r_flags.hklout

            self.process_refine = refmac_task.RefmacRefine(
                command=self.commands['refmac'],
                job_location=self.job_location,
                pdb_path=pdb_path_a,
                mtz_path=mtz_path_a,
                pdbout_path=os.path.join(
                    self.job_location,
                    'refined{0}.pdb'.format(str(i))),
                resolution=self.args.resolution.value,
                mode=mode,  # 'Global or 'Local',
                name=name,  # 'Refmac refine (global) or (local) ' + str(i),
                sharp=sharp,
                ncycle=self.args.ncycle_refmac.value,
                output_hkl=True,
                lib_path=self.args.lib_in.value,
                symmetry_auto=False,
                keywords=refine_keywords,
                libg_restraints=self.args.libg.value,
                libg_restraints_path=libg_restraints_path)
            pl.append([self.process_refine.process])

        # Results handler receives the refinement process of the last cycle
        custom_finish = BuccaneerResultsOnFinish(
            pipeline_path=os.path.join(self.job_location, 'task.ccpem'),
            refine_process=self.process_refine)

        if run:
            os.chdir(self.job_location)
            self.pipeline = process_manager.CCPEMPipeline(
                pipeline=pl,
                job_id=job_id,
                args_path=self.args.jsonfile,
                location=self.job_location,
                database_path=self.database_path,
                db_inject=db_inject,
                taskname=self.task_info.name,
                title=self.args.job_title.value,
                verbose=self.verbose,
                on_finish_custom=custom_finish)
            self.pipeline.start()


class SFToolsBuccaneer(object):
    '''
    SFTools process that sets the space group to P 1 and adds the SIGFem
    and FOMem columns (both set to 1) to an MTZ file.

    The result is written to 'sftools.mtz' inside job_location and the
    prepared CCPEMProcess is available as self.process.
    '''

    def __init__(self,
                 job_location,
                 hklin,
                 command=which('sftools'),
                 name=None):
        assert command is not None
        self.job_location = job_location
        self.hklin = hklin
        self.hklout = os.path.join(job_location, 'sftools.mtz')
        # Fall back to the class name when no display name is supplied
        self.name = self.__class__.__name__ if name is None else name
        self.stdin = None
        # Command line arguments and keyword script for the process
        self.set_args()
        self.set_stdin()
        self.process = process_manager.CCPEMProcess(
            name=self.name,
            command=command,
            args=self.args,
            location=self.job_location,
            stdin=self.stdin)

    def set_args(self):
        '''SFTools takes no command line arguments here.'''
        self.args = list()

    def set_stdin(self):
        '''Compose the SFTools keyword script, one command per line.'''
        script_lines = [
            '',
            'READ {0}'.format(self.hklin),
            'SET SPACEGROUP',
            'P 1',
            'CALC Q COL SIGFem = 1',
            'CALC W COL FOMem = 1',
            'WRITE {0}'.format(self.hklout),
            'END',
        ]
        self.stdin = '\n'.join(script_lines)


class FreeRFlags(object):
    '''
    freerflag process: tags each reflection in an MTZ file with a flag for
    cross-validation.

    Output goes to hklout, or to 'freerflags.mtz' inside job_location when
    hklout is not given.  The prepared CCPEMProcess is available as
    self.process.
    '''

    def __init__(self,
                 job_location,
                 hklin,
                 command=which('freerflag'),
                 hklout=None,
                 free_r_frac=0.05,
                 name=None):
        assert command is not None
        self.job_location = job_location
        self.hklin = hklin
        # Default output file lives in the job directory
        self.hklout = (hklout if hklout is not None
                       else os.path.join(job_location, 'freerflags.mtz'))
        # Fall back to the class name when no display name is supplied
        self.name = self.__class__.__name__ if name is None else name
        self.free_r_frac = free_r_frac
        self.stdin = None
        # Command line arguments and keyword script for the process
        self.set_args()
        self.set_stdin()
        self.process = process_manager.CCPEMProcess(
            name=self.name,
            command=command,
            args=self.args,
            location=self.job_location,
            stdin=self.stdin)

    def set_args(self):
        '''Input and output MTZ files are passed on the command line.'''
        self.args = ['hklin', self.hklin, 'hklout', self.hklout]

    def set_stdin(self):
        '''Keyword script: fraction of reflections put in the free set.'''
        self.stdin = '\nFREERFRAC {0}\n'.format(self.free_r_frac)


class BuccaneerPipeline(object):
    '''
    Statistical protein chain tracing (N.B. runs cbuccaneer)

    Prepares the keyword script for one cbuccaneer run and wraps it in a
    CCPEMProcess, available as self.process.

    Arguments:
        command: cbuccaneer executable (must not be None).
        job_location: directory the process runs in; default output
            location.
        hklin: input MTZ file.
        seqin: sequence file; must exist.
        label_out_suffix: suffix of the Fout amplitude column to use.
        ncycle: index of the overall pipeline cycle, starting at 1; used in
            the default output names.
        buc_cycle: value of the cbuccaneer 'cycles' keyword.
        resolution: value of the 'resolution' keyword.
        pdbin: optional input model, added as 'pdbin' keyword.
        pdbout: output model; defaults to 'build<ncycle>.pdb' in
            job_location.
        name: process name; defaults to the class name.
        job_title: value of the 'title' keyword.
        ncpus: value of the 'jobs' keyword.
        nonprotein_on: when true, a 'nonprotein-radius' keyword is added.
        nonprotein_rad: radius written with the nonprotein-radius keyword.
        keywords: optional extra keyword text appended to the script.
    '''

    def __init__(self,
                 command,
                 job_location,
                 hklin,
                 seqin,
                 label_out_suffix,
                 ncycle=1,
                 buc_cycle=5,
                 resolution=2.0,
                 pdbin=None,
                 pdbout=None,
                 name=None,
                 job_title=None,
                 ncpus=1,
                 nonprotein_on=False,
                 nonprotein_rad=-1.0,
                 keywords=None):
        assert command is not None
        self.job_location = job_location
        self.hklin = hklin
        self.seqin = seqin
        assert os.path.exists(self.seqin)
        self.label_out_suffix = label_out_suffix
        # counter for overall BuccaneerPipeline cycles, start with 1
        self.ncycle = ncycle
        self.buc_cycle = buc_cycle
        self.resolution = resolution
        self.pdbin = pdbin
        self.pdbout = pdbout
        if self.pdbout is None:
            self.pdbout = os.path.join(
                job_location,
                'build' + str(ncycle) + '.pdb')
        self.name = name
        if self.name is None:
            self.name = self.__class__.__name__
        self.job_title = job_title
        self.ncpus = ncpus
        self.stdin = None
        self.stdin_extra = None
        self.nonprotein_on = nonprotein_on
        self.keywords = self._combine_keywords(
            keywords, nonprotein_on, nonprotein_rad)
        #
        self.set_args()
        self.set_stdin()
        #
        self.process = process_manager.CCPEMProcess(
            name=self.name,
            command=command,
            args=self.args,
            location=self.job_location,
            stdin=self.stdin)

    @staticmethod
    def _combine_keywords(keywords, nonprotein_on, nonprotein_rad):
        '''
        Return the user keywords with the nonprotein-radius keyword appended
        when nonprotein_on is true; otherwise return keywords unchanged.

        A missing (None) keyword text is treated as empty so that the
        nonprotein-radius option also works without user keywords.
        '''
        if not nonprotein_on:
            return keywords
        user_keywords = '' if keywords is None else keywords
        return user_keywords + '\nnonprotein-radius {0}\n'.format(
            nonprotein_rad)

    def set_args(self):
        '''cbuccaneer reads its keywords from standard input.'''
        self.args = ['-stdin']

    def set_stdin(self):
        '''
        Build the cbuccaneer keyword script in self.stdin.

        Reads the CLIBD environment variable to locate the reference
        structures (KeyError if it is not set).
        '''
        clibd = os.environ['CLIBD']
        # Use the EM reference structures if available
        pdbin_ref = os.path.join(
            clibd, 'reference_structures/reference-EMD-4116.pdb')
        mtzin_ref = os.path.join(
            clibd, 'reference_structures/reference-EMD-4116.mtz')
        if not (os.path.isfile(pdbin_ref) and os.path.isfile(mtzin_ref)):
            print('EM reference structures not found. Are you running '
                  'an up-to-date copy of CCP4?')
            print('Buccaneer will fall back to the crystallographic '
                  'reference structures')
            pdbin_ref = os.path.join(
                clibd, 'reference_structures/reference-1tqw.pdb')
            mtzin_ref = os.path.join(
                clibd, 'reference_structures/reference-1tqw.mtz')

        self.stdin = '''
title {0}
pdbin-ref {1}
mtzin-ref {2}
colin-ref-fo FP.F_sigF.F,FP.F_sigF.sigF
colin-ref-hl FC.ABCD.A,FC.ABCD.B,FC.ABCD.C,FC.ABCD.D
seqin {3}
mtzin {4}
colin-fo Fout{5},SIGFem
colin-free FreeR_flag
colin-phifom Pout0,FOMem
pdbout {6}
cycles {7}
anisotropy-correction
fast
correlation-mode
resolution {8}
xmlout program{9}.xml
jobs {10}
'''.format(self.job_title,
           pdbin_ref,
           mtzin_ref,
           self.seqin,
           self.hklin,
           self.label_out_suffix,
           self.pdbout,
           self.buc_cycle,
           self.resolution,
           self.ncycle,
           self.ncpus)
        # Add optional buccaneer keywords, stripped of surrounding white
        # space and terminated by a single new line
        if isinstance(self.keywords, str) and self.keywords != '':
            self.stdin += self.keywords.strip() + '\n'
        if self.pdbin is not None:
            self.stdin += 'pdbin {0}\n'.format(self.pdbin)


class BuccaneerResultsOnFinish(process_manager.CCPEMPipelineCustomFinish):
    '''
    Pipeline finish handler that generates the RVAPI results report.
    '''

    def __init__(self,
                 pipeline_path,
                 refine_process):
        '''
        Keep the pipeline file path used for the report, together with the
        refinement process handed over by the task.
        '''
        super(BuccaneerResultsOnFinish, self).__init__()
        self.refine_process = refine_process
        self.pipeline_path = pipeline_path

    def on_finish(self, parent_pipeline=None):
        '''
        Create the results viewer for the stored pipeline path.
        parent_pipeline is accepted but not used.
        '''
        report_source = self.pipeline_path
        nautilus_results.PipelineResultsViewer(pipeline_path=report_source)


def add_refmac_keywords(refkeywords):
    '''
    Return the additional refmac keywords as newline-terminated lines.

    Anything that is not a non-empty string yields an empty string.
    Otherwise surrounding white space is stripped and every line of the
    remaining text is emitted followed by a new line.
    '''
    # Guard clause: nothing to add for missing or empty keyword text
    if not isinstance(refkeywords, str) or refkeywords == '':
        return ''

    stripped = refkeywords.strip()
    return ''.join(
        entry + '\n' for entry in stripped.split('\n'))
