#
#     Copyright (C) 2016 CCP-EM
#
#     This code is distributed under the terms and conditions of the
#     CCP-EM Program Suite Licence Agreement as a CCP-EM Application.
#     A copy of the CCP-EM licence can be obtained by writing to the
#     CCP-EM Secretary, RAL Laboratory, Harwell, OX11 0FA, UK.
#


import os
import json
from ccpem_core.ccpem_utils import ccpem_argparser
from ccpem_core import ccpem_utils
from ccpem_core import process_manager
from ccpem_core.tasks import task_utils
from ccpem_core.tasks.tempy.smoc import smoc_process
from ccpem_core.process_manager import job_register
from ccpem_core.tasks.ribfind import ribfind_task
from ccpem_core import settings

class SMOC(task_utils.CCPEMTask):
    '''
    CCPEM / TEMPy local model-to-map scoring task (SMOC and SCCC).

    Scores the fit of one or more atomic models into a map by running the
    TEMPy scoring script (smoc_process).  If rigid bodies are needed (SCCC
    is requested, or the map resolution is worse than 7.5 A) and no rigid
    body file was supplied, a RIBFIND job is added to the pipeline first to
    generate one.
    '''
    task_info = task_utils.CCPEMTaskInfo(
        name='TEMPy-LocScore',
        author='A. Joseph, M. Topf',
        version='1.1',
        description=(
            '''Local scoring using TEMPy library.<br><br>
            Input a map and one or more atomic models. The atomic models
            are expected to be fitted in the map.<br><br>
            Two methods are provided to evaluate atomic model fit.<br>
            SMOC: Local fragment Score based on Manders Overlap 
            Coefficient. Scores are calculated either on <br>
            - a local region within a distance from each amino acid or<br>
            - on overlapping fragments of amino acids along each chain.<br> 
            Distance or Fragment length can be adjusted (default setting 
            is based on map resolution).<br>
            SCCC: Cross correlation coefficient for each rigid body 
            segment. Each rigid body segment defined in the rigid body 
            file is scored. If a rigid body file is not provided, it is 
            calculated using RIBFIND. NOTE that the current implementation
            of RIBFIND doesnt support multiple chains, so a rigid body 
            file has to be uploaded for the calculation to work.<br>
            Rigid body file format (two lines):<br>
            10:A 20:A<br>
            30:B 40:B 60:B 70:B<br>
            indicates two rigid bodies, one formed of residues 10 
            to 20 from chain A, and another formed of residues 30 to 40 
            and 60 to 70 of chain B.<br><br>
            The choice of the scoring method is based on map resolution 
            by default. If resolution is better than 7.5A, SMOC score 
            is calculated.
            '''),
        short_description=(
            'SMOC: Local fragment Score based on Manders\' Overlap '
            'Coefficient\n'
            'SCCC: Cross correlation coefficient for each rigid body segment'),
        documentation_link='http://topf-group.ismb.lon.ac.uk/TEMPY.html',
        references=None)

    # Command lines used to build pipeline processes.
    # NOTE(review): 'ccpem-mapprocess' points at smoc_process, exactly like
    # 'ccpem-python' - confirm this is intended.
    commands = {
        'ccpem-python':
            ['ccpem-python', os.path.realpath(smoc_process.__file__)],
        'ccpem-ribfind': settings.which('ccpem-ribfind'),
        'ccpem-mapprocess':
            ['ccpem-python', os.path.realpath(smoc_process.__file__)]}

    def __init__(self,
                 database_path=None,
                 args=None,
                 args_json=None,
                 pipeline=None,
                 job_location=None,
                 parent=None):
        # All arguments are forwarded unchanged to CCPEMTask.
        super(SMOC, self).__init__(
            database_path=database_path,
            args=args,
            args_json=args_json,
            pipeline=pipeline,
            job_location=job_location,
            parent=parent)

    def parser(self):
        '''
        Build the argument parser for this task.

        Returns:
            ccpem_argparser.ccpemArgParser holding the SMOC/SCCC arguments.
        '''
        parser = ccpem_argparser.ccpemArgParser()
        #
        job_title = parser.add_argument_group()
        job_title.add_argument(
            '-job_title',
            '--job_title',
            help='Short description of job',
            metavar='Job title',
            type=str,
            default=None)
        #
        job_location = parser.add_argument_group()
        job_location.add_argument(
            '-job_location',
            '--job_location',
            help='Directory to run job',
            metavar='Job location',
            type=str,
            default=None)
        #
        map_path = parser.add_argument_group()
        map_path.add_argument(
            '-map_path',
            '--map_path',
            help='Input map (mrc format)',
            metavar='Input map',
            type=str,
            default=None)
        #
        parser.add_argument(
            '-input_pdbs',
            '--input_pdbs',
            help=('Input pdb(s); '
                  'ensure same residue numbering in all'),
            metavar='Input PDB',
            type=str,
            nargs='*',
            default=None)
        #
        parser.add_argument(
            '-input_pdb_chains',
            '--input_pdb_chains',
            help='Input PDB chain(s)',
            metavar='Input PDB chain selections',
            type=str,
            nargs='*',
            default=None)
        #
        parser.add_argument(
            '-map_resolution',
            '--map_resolution',
            help=('Resolution of map (Angstrom). Maximum resolution recommended '
                  'for use of SMOC is 7.5 Angstrom'),
            metavar='Map resolution',
            type=float,
            default=None)
        #
        parser.add_argument(
            '-use_smoc',
            '--use_smoc',
            help=('''SMOC: calculate Segment Manders\' Overlap Coefficient  \n'''
                  ''' Scores overlapping fragments and outputs per residue plots'''),
            metavar='SMOC score',
            type=bool,
            default=True)
        #
        parser.add_argument(
            '-use_sccc',
            '--use_sccc',
            help=('''SCCC: calculate Segment Cross Correlation Coefficient \n'''
                  ''' Scores rigid body fits and outputs score per rigid body'''),
            metavar='SCCC score',
            type=bool,
            default=False)
        #
        parser.add_argument(
            '-rigid_body_path',
            '--rigid_body_path',
            help=('Input rigid body file. '
                  'Can be used with SMOC as well'),
            metavar='Input rigid body',
            type=str,
            nargs='*',
            default=None)
        #
        ribfind_cutoff = parser.add_argument_group()
        ribfind_cutoff.add_argument(
            '-ribfind_cutoff',
            '--ribfind_cutoff',
            help=('Set Ribfind cutoff for rigid body elements.  Default '
                  'of 100 corresponds to one rigid body per secondary '
                  'structure element'),
            metavar='Ribfind cutoff',
            type=int,
            default=100)
        #
        create_rigid_body = parser.add_argument_group()
        create_rigid_body.add_argument(
            '-create_rigid_body',
            '--create_rigid_body',
            help=('Generate rigid bodies based on '
                  'secondary structure elements'),
            metavar='Auto rigid body',
            type=bool,
            default=False)
        #
        dist_or_fragment_selection = parser.add_argument_group()
        dist_or_fragment_selection.add_argument(
            '-dist_or_fragment_selection',
            '--dist_or_fragment_selection',
            help=('''Calculate SMOC score on neighboring voxels 
                    selected by: distance or sequence fragment length'''),
            metavar='Residue neighborhood',
            choices=['Distance', 'Fragment'],
            type=str,
            default='Distance')
        #
        parser.add_argument(
            '-local_distance',
            '--local_distance',
            help='''Distance over which SMOC score is calculated 
                    for each residue (Angstroms)''',
            metavar='Local distance for scoring',
            type=float,
            default=5.0)
        #
        parser.add_argument(
            '-auto_local_distance',
            '--auto_local_distance',
            help='''Automatically set distance for SMOC score calculation
                    for each residue''',
            metavar='Auto distance (SMOC)',
            type=bool,
            default=True)
        #
        parser.add_argument(
            '-auto_fragment_length',
            '--auto_fragment_length',
            help='Automatically set number of residues in averaging window',
            metavar='Auto window (SMOC)',
            type=bool,
            default=True)
        #
        parser.add_argument(
            '-fragment_length',
            '--fragment_length',
            help='Number of residues in averaging window',
            metavar='Fragment length',
            type=int,
            default=9)
        #
        return parser

    def ribfind_job(self, db_inject=None, job_title=None):
        '''
        Register a RIBFIND job and create (but do not start) its process.

        The created process is stored on ``self.ribfind_process``.  RIBFIND
        is given only the first input PDB.

        Args:
            db_inject: passed through to job_register.job_register.
            job_title: title for the RIBFIND job.
        '''
        # Register the new job in the project directory (the database's
        # directory) or, without a database, next to this job.
        if self.database_path is None:
            path = os.path.dirname(self.job_location)
        else:
            path = os.path.dirname(self.database_path)

        job_id, job_location = job_register.job_register(
            db_inject=db_inject,
            path=path,
            task_name=ribfind_task.Ribfind.task_info.name)

        # Task args for RIBFIND; only the first model is used
        args = ribfind_task.Ribfind().args
        args.job_title.value = job_title
        pdbs = self.args.input_pdbs()
        if not isinstance(pdbs, list):
            pdbs = [pdbs]
        args.input_pdb.value = pdbs[0]
        args_json_string = args.output_args_as_json(
            return_string=True)

        # Command-line args for the ccpem-ribfind launcher
        process_args = ['--no-gui',
                        "--args_string='{0}'".format(args_json_string),
                        '--job_location={0}'.format(job_location)]
        if self.database_path is not None:
            process_args.append('--project_location={0}'.format(
                os.path.dirname(self.database_path)))
        if job_id is not None:
            process_args.append('--job_id={0}'.format(job_id))

        self.ribfind_process = process_manager.CCPEMProcess(
            name='Ribfind auto run task',
            command=self.commands['ccpem-ribfind'],
            args=process_args,
            location=job_location,
            stdin=None)

    def _processed_map_path(self):
        '''
        Return ``<input map path without extension>_processed.mrc``, the
        location where a processed copy of the input map is expected.
        '''
        return os.path.splitext(
            os.path.abspath(self.args.map_path.value))[0] + '_processed.mrc'

    def edit_argsfile(self, args_file):
        '''
        Update a JSON args file with the outputs of preceding steps.

        Always writes the current ``rigid_body_path`` (user supplied or the
        expected RIBFIND output).  If a map-processing task is attached, has
        a job location, and the processed map exists, the args file and
        ``self.args.map_path`` are switched to the processed map.  In that
        case the original path is kept in ``self.unprocessed_map``.

        Args:
            args_file: path to the JSON args file; rewritten in place.
        '''
        with open(args_file, 'r') as f:
            json_args = json.load(f)
        json_args['rigid_body_path'] = self.args.rigid_body_path.value
        # mapprocess_task is not initialised by this class (it may be
        # attached externally), so a missing attribute is treated as None.
        mapprocess_task = getattr(self, 'mapprocess_task', None)
        if (mapprocess_task is not None and
                mapprocess_task.job_location is not None):
            # The processed map is written next to the input map, not into
            # the map-processing job directory.
            processed_map_path = self._processed_map_path()
            if os.path.isfile(processed_map_path):
                json_args['map_path'] = processed_map_path
                self.unprocessed_map = self.args.map_path.value
                self.args.map_path.value = processed_map_path
        with open(args_file, 'w') as f:
            json.dump(json_args, f)

    def check_processed_map(self):
        '''
        Show the processed map path in the ``map_input`` widget.

        This happens only when a map-processing task is attached, the
        processed map exists, and a ``map_input`` widget is attached.
        '''
        processed_map_path = self._processed_map_path()
        if (os.path.isfile(processed_map_path) and
                getattr(self, 'mapprocess_task', None) is not None):
            # map_input is a GUI widget, not defined by this class; skip
            # quietly when the task runs without one.
            map_input = getattr(self, 'map_input', None)
            if map_input is not None:
                map_input.value_line.setText(processed_map_path)

    @staticmethod
    def _ribfind_cutoff_for_resolution(resolution, default):
        '''
        Choose a RIBFIND cutoff from the map resolution (Angstrom).

        Returns 100 below 10 A, 60 below 15 A, and 30 otherwise.  Returns
        ``default`` when the resolution is unknown (None).
        '''
        if resolution is None:
            return default
        if resolution < 10.0:
            return 100
        if resolution < 15.0:
            return 60
        # Previously resolutions >= 100 A left the cutoff unassigned
        # (UnboundLocalError); every coarser map now gets the lowest cutoff.
        return 30

    def run_pipeline(self, job_id=None, db_inject=None):
        '''
        Build and start the SMOC/SCCC pipeline.

        A RIBFIND step runs first when rigid bodies are needed and no rigid
        body file was given.  Rigid bodies are needed when SCCC is
        requested or the map resolution is worse than 7.5 A.  The expected
        RIBFIND output then becomes ``rigid_body_path``.  Next the job's
        args.json is updated and the scoring process is appended.

        Args:
            job_id: database job id; also appended to the RIBFIND job title.
            db_inject: passed to job registration and to the pipeline.
        '''
        pl = []
        resolution = self.args.map_resolution.value
        # An unknown resolution does not by itself request rigid bodies
        # (comparing None with a float raises TypeError on Python 3).
        low_resolution = resolution is not None and resolution > 7.5
        if ((low_resolution or self.args.use_sccc.value) and
                self.args.rigid_body_path.value is None):
            ribfind_job_title = 'Auto run'
            if job_id is not None:
                ribfind_job_title += ' {0}'.format(job_id)
            self.ribfind_job(db_inject=db_inject,
                             job_title=ribfind_job_title)
            pl.append([self.ribfind_process])

            # TODO: tune cutoffs based on resolution
            self.args.ribfind_cutoff.value = \
                self._ribfind_cutoff_for_resolution(
                    resolution, self.args.ribfind_cutoff.value)
            # Expected RIBFIND output for the chosen cutoff (zero-padded).
            self.args.rigid_body_path.value = os.path.join(
                self.ribfind_process.location,
                'rigid_body_{0:0>3}.txt'.format(
                    self.args.ribfind_cutoff.value))

        # Update the args file read by the scoring process
        args_file = os.path.join(self.job_location, 'args.json')
        self.check_processed_map()
        self.edit_argsfile(args_file)

        # SMOC and SCCC run the same command; only the process name differs
        score_name = 'SMOC score' if self.args.use_smoc.value else 'SCCC score'
        self.smoc_wrapper = SMOCWrapper(
            command=self.commands['ccpem-python'],
            job_location=self.job_location,
            name=score_name)
        pl.append([self.smoc_wrapper.process])

        self.pipeline = process_manager.CCPEMPipeline(
            pipeline=pl,
            job_id=job_id,
            args_path=self.args.jsonfile,
            location=self.job_location,
            database_path=self.database_path,
            db_inject=db_inject,
            taskname=self.task_info.name,
            title=self.args.job_title.value)
        self.pipeline.start()

class SMOCWrapper(object):
    '''
    Wrapper for the TEMPy SMOC/SCCC scoring process.

    Builds a CCPEMProcess that runs ``command`` in ``job_location``.  The
    process receives the path of the job's ``args.json`` as its args.

    Args:
        command: command to run; must not be None.
        job_location: job directory, normalised via
            ccpem_utils.get_path_abs.
        name: process name; defaults to the class name.

    Raises:
        ValueError: if ``command`` is None.
    '''
    def __init__(self,
                 command,
                 job_location,
                 name=None):
        # Validate explicitly: an assert would be stripped under python -O.
        # Checked first so nothing else is done for an unusable command.
        if command is None:
            raise ValueError('SMOCWrapper requires a command to run')
        self.job_location = ccpem_utils.get_path_abs(job_location)
        self.name = name if name is not None else self.__class__.__name__

        # Path of the JSON args file inside the job directory
        self.args = os.path.join(self.job_location, 'args.json')

        self.process = process_manager.CCPEMProcess(
            name=self.name,
            command=command,
            args=self.args,
            location=self.job_location,
            stdin=None)
