#
#     Copyright (C) 2017 CCP-EM
#
#     This code is distributed under the terms and conditions of the
#     CCP-EM Program Suite Licence Agreement as a CCP-EM Application.
#     A copy of the CCP-EM licence can be obtained by writing to the
#     CCP-EM Secretary, RAL Laboratory, Harwell, OX11 0FA, UK.
#

'''
Run the TEMPy SMOC (or SCCC) scoring process.
    Take an input list of structures and score each against a single map.
    Output a csv table of the results.
'''

import sys
import os, re
import json, glob
import pandas as pd
from ccpem_core.TEMPy.MapParser import MapParser
from ccpem_core.TEMPy.StructureParser import PDBParser
from TEMPy.RigidBodyParser import RBParser
from ccpem_core.TEMPy.ScoringFunctions import ScoringFunctions
from ccpem_core import ccpem_utils
from ccpem_core.tasks.tempy.smoc import smoc_results
from collections import OrderedDict
from TEMPy.mapprocess import Filter
import mrcfile, numpy as np

def main(json_path, verbose=True, set_results=True):
    '''
    Run SMOC or SCCC scoring as configured by a JSON parameter file.

    Arguments:
        json_path: JSON parameter file.  When its 'rigid_body_path' is a
            RIBFIND output directory, the cluster file selected from it is
            written back into this file (see edit_argsfile).
        verbose: print section headers.
        set_results: set up the jsrview results viewer.  The string 'False'
            (as passed from the command line) also disables it.

    Writes smoc_score.txt (comma separated) or sccc_score.txt (semicolon
    separated) into the job location.
    '''
    # Process arguments from json parameter file
    with open(json_path, 'r') as json_file:
        args = json.load(json_file)
    map_path = args['map_path']
    map_resolution = args['map_resolution']
    pdb_path_list = args['input_pdbs']
    rigid_body_path = args['rigid_body_path']
    job_location = args['job_location']
    distance_or_fragment_selection = args['dist_or_fragment_selection']

    if args['use_smoc'] and args['auto_fragment_length']:
        # Larger (coarser) resolution values get a longer scoring window
        if map_resolution < 3.0:
            fragment_length = 5
        elif map_resolution < 5.0:
            fragment_length = 7
        else:
            fragment_length = 9
    else:
        fragment_length = args['fragment_length']

    if distance_or_fragment_selection != 'Distance':
        print('\nFragment length : {0}'.format(fragment_length))

    if job_location is None:
        job_location = os.getcwd()

    if not isinstance(pdb_path_list, list):
        pdb_path_list = [pdb_path_list]

    # Resolve the rigid body input: a RIBFIND output directory is reduced
    # to one cluster file, an existing file is used as given, anything else
    # is ignored.
    if rigid_body_path is not None:
        if os.path.isdir(rigid_body_path):
            rigid_body_path = _select_rigid_body_file(
                rigid_body_dir=rigid_body_path,
                map_resolution=map_resolution,
                pdb_path=pdb_path_list[0])
            if rigid_body_path is not None:
                edit_argsfile(json_path, rigid_body_path)
        elif not os.path.isfile(rigid_body_path):
            print('Warning: rigid body path not found, ignoring: {0}'.format(
                rigid_body_path))
            rigid_body_path = None

    # Command line callers pass set_results as a string
    show_results = set_results and set_results != 'False'

    if args['use_smoc']:
        if verbose:
            ccpem_utils.print_sub_header('Process SMOC scores')
        # Run smoc score
        smoc_df = process_smoc_scores(
            map_path=map_path,
            map_resolution=map_resolution,
            pdb_path_list=pdb_path_list,
            rigid_body_path=rigid_body_path,
            fragment_length=fragment_length,
            directory=job_location,
            distance_or_fragment_selection=distance_or_fragment_selection)

        # Save smoc scores as csv
        csv_path = os.path.join(job_location, 'smoc_score.txt')
        if verbose:
            ccpem_utils.print_sub_header('Save raw scores')
            ccpem_utils.print_sub_sub_header(csv_path)
        with open(csv_path, 'w') as path:
            smoc_df.to_csv(path)

        if show_results:
            # Set jsrview
            smoc_results.SMOCResultsViewer(
                smoc_dataframe=smoc_df,
                directory=job_location)
    else:
        if verbose:
            ccpem_utils.print_sub_header('Process SCCC scores')
        # Run sccc score
        sccc_df = process_sccc_scores(
            map_path=map_path,
            map_resolution=map_resolution,
            pdb_path_list=pdb_path_list,
            rigid_body_path=rigid_body_path,
            fragment_length=fragment_length,
            directory=job_location)

        # Save sccc scores as ';' separated table
        csv_path = os.path.join(job_location, 'sccc_score.txt')
        if verbose:
            ccpem_utils.print_sub_header('Save raw scores')
            ccpem_utils.print_sub_sub_header(csv_path)
        with open(csv_path, 'w') as path:
            sccc_df.to_csv(path, sep=';')

        if show_results:
            # Set jsrview
            smoc_results.SCCCResultsViewer(
                sccc_df,
                directory=job_location)


def _rigid_body_cutoff(map_resolution):
    '''Return the preferred RIBFIND cluster cutoff for a map resolution.'''
    if map_resolution < 10.0:
        return 100
    if map_resolution < 15.0:
        return 60
    # 15.0 or coarser, including values >= 100.0
    return 30


def _select_rigid_body_file(rigid_body_dir, map_resolution, pdb_path):
    '''
    Choose a RIBFIND cluster file for a model from an output directory.

    Candidates are files named '<model name without extension>_denclust_<N>.txt'.
    The smallest cutoff N that is >= the preferred cutoff for the map
    resolution is used; if every N is smaller, the largest N is used.

    Returns the full path of the chosen file, or None if there is none.
    '''
    pdb_stem = '.'.join(os.path.basename(pdb_path).split('.')[:-1])
    pattern = re.compile(re.escape(pdb_stem) + r'_denclust_([0-9]+)\.txt$')
    # cutoff -> actual file name, so the returned path is one that exists
    candidates = {}
    for file_name in os.listdir(rigid_body_dir):
        match = pattern.match(file_name)
        if match and os.path.isfile(os.path.join(rigid_body_dir, file_name)):
            candidates[int(match.group(1))] = file_name
    if not candidates:
        print('Warning: no {0}_denclust_*.txt file found in {1}'.format(
            pdb_stem, rigid_body_dir))
        return None
    cutoffs = sorted(candidates)
    target = _rigid_body_cutoff(map_resolution)
    sel_cutoff = next((c for c in cutoffs if c >= target), cutoffs[-1])
    return os.path.join(rigid_body_dir, candidates[sel_cutoff])

def edit_argsfile(args_file, rigid_body_file):
    '''
    Rewrite the JSON parameter file with a new 'rigid_body_path' value.

    All other keys are kept; the file is overwritten in place.
    '''
    with open(args_file, 'r') as handle:
        params = json.load(handle)
    # record the selected ribfind output
    params['rigid_body_path'] = rigid_body_file
    with open(args_file, 'w') as handle:
        json.dump(params, handle)


def process_smoc_scores(map_path,
                        map_resolution,
                        pdb_path_list,
                        rigid_body_path,
                        fragment_length,
                        directory,
                        distance_or_fragment_selection='Distance'):
    '''
    Score each model in pdb_path_list against one map with SMOC.

    Arguments:
        map_path: map file, opened with mrcfile and wrapped in a TEMPy Filter.
        map_resolution: map resolution passed to the scoring.
        pdb_path_list: model files to score.
        rigid_body_path: optional rigid body file (None for none).
        fragment_length: SMOC scoring window.
        directory: where the B-factor annotated models
            (<model name>_sc.pdb) are written.
        distance_or_fragment_selection: 'Distance' or fragment based
            scoring, see get_smoc_score.

    Returns a DataFrame with one column group per model and chain, keyed
    '<model file name>_<chain>' (or '<model file name>_' for a blank chain
    ID), holding resnum, resname, smoc, CAx, CAy and CAz columns.  Groups
    are joined side by side, so shorter chains are padded with NaN.  An
    empty DataFrame is returned when there is nothing to score.
    '''
    # NOTE(review): the mrcfile handle is never closed because the Filter
    # object may still reference it - confirm before adding a close().
    mrcobj = mrcfile.open(map_path, mode='r')
    em_map = Filter(mrcobj)
    em_map.set_apix_tempy()

    smoc_frames = []
    for pdb_path in pdb_path_list:
        pdb_id = os.path.basename(pdb_path)
        # Read structure
        structure = PDBParser.read_PDB_file(
            structure_id=pdb_id,
            filename=pdb_path,
            hetatm=False,
            water=False)

        # Get scores
        smoc_chain_scores, smoc_chain_CA = get_smoc_score(
            em_map=em_map,
            map_resolution=map_resolution,
            structure=structure,
            rigid_body_path=rigid_body_path,
            fragment_length=fragment_length,
            distance_or_fragment_selection=distance_or_fragment_selection)

        # Store the scores as B-factors so the model can be coloured by
        # score.  Best effort: a failure here must not stop the scoring.
        scored_pdb_file = os.path.join(
            directory, pdb_id.split('.')[0] + '_sc.pdb')
        try:
            ScoringFunctions().set_score_as_bfactor(
                structure_instance=structure,
                dict_scores=smoc_chain_scores,
                outfile=scored_pdb_file)
        except Exception as err:
            print('Warning: could not write {0}: {1}'.format(
                scored_pdb_file, err))

        # Put residue scores into one dataframe per chain
        for chain, scores in smoc_chain_scores.items():
            smoc_frames.append(
                _smoc_chain_frame(pdb_id, chain, scores, smoc_chain_CA))

    if not smoc_frames:
        # pd.concat raises on an empty list
        return pd.DataFrame()
    return pd.concat(smoc_frames, axis=1)


def _smoc_chain_frame(pdb_id, chain, scores, chain_CA):
    '''
    Build the SMOC DataFrame for one chain of one model.

    scores maps residue number -> SMOC score.  chain_CA[chain][resnum] is
    read as (resname, x, y, z); residues missing from it get a blank name
    and zero coordinates.
    '''
    group = pdb_id + ('_' if chain.isspace() else '_' + chain)
    resnums, resnames, smoc_values = [], [], []
    ca_x, ca_y, ca_z = [], [], []
    for resnum, score in scores.items():
        try:
            ca_info = chain_CA[chain][resnum]
        except KeyError:
            ca_info = (' ', 0.0, 0.0, 0.0)
        resnums.append(resnum)
        smoc_values.append(score)
        resnames.append(ca_info[0])
        ca_x.append(ca_info[1])
        ca_y.append(ca_info[2])
        ca_z.append(ca_info[3])
    # OrderedDict keeps the column order stable; a plain dict does not on
    # older Python/pandas versions
    columns = OrderedDict([
        ((group, 'resnum'), resnums),
        ((group, 'resname'), resnames),
        ((group, 'smoc'), smoc_values),
        ((group, 'CAx'), ca_x),
        ((group, 'CAy'), ca_y),
        ((group, 'CAz'), ca_z),
    ])
    return pd.DataFrame(columns)


def process_sccc_scores(map_path,
                        map_resolution,
                        pdb_path_list,
                        rigid_body_path,
                        fragment_length,
                        directory):
    '''
    Score the rigid bodies of each model against one map with SCCC.

    Arguments:
        map_path: map file, read with MapParser.readMRC.
        map_resolution: map resolution passed to the scoring.
        pdb_path_list: model files to score.
        rigid_body_path: rigid body file (required).
        fragment_length: unused; kept for symmetry with
            process_smoc_scores.
        directory: where the B-factor annotated models
            (<model name>_sc.pdb) are written.

    Returns a DataFrame with one row per rigid body, labelled with its
    residue segments, and one column per model file name (in input order).
    Rigid bodies without a score are marked '-'.

    Raises ValueError if rigid_body_path is None.
    '''
    if rigid_body_path is None:
        raise ValueError('SCCC scoring requires a rigid body file')
    em_map = MapParser.readMRC(map_path)

    dict_rbs = RBParser.read_rigid_body_file_as_dict(rigid_body_path)
    list_rbs = list(dict_rbs.keys())
    # Exactly one label per rigid body (even one spanning several chains),
    # so the index length matches the score lists built below
    row_labels = [_rigid_body_label(dict_rbs[rb]) for rb in list_rbs]

    dict_table = OrderedDict()
    for pdb_path in pdb_path_list:
        pdb_id = os.path.basename(pdb_path)
        print('Scoring {0}'.format(pdb_id))
        # Read structure
        structure = PDBParser.read_PDB_file(
            structure_id=pdb_id,
            filename=pdb_path,
            hetatm=False,
            water=False)

        # Get scores
        sccc_chain_scores, rigid_body_scores = get_sccc_score(
            em_map=em_map,
            map_resolution=map_resolution,
            structure=structure,
            rigid_body_path=rigid_body_path)

        # Store the scores as B-factors so the model can be coloured by
        # score.  Best effort: a failure here must not stop the scoring.
        scored_pdb_file = os.path.join(
            directory, pdb_id.split('.')[0] + '_sc.pdb')
        try:
            ScoringFunctions().set_score_as_bfactor(
                structure_instance=structure,
                dict_scores=sccc_chain_scores,
                outfile=scored_pdb_file)
        except Exception as err:
            print('Warning: could not write {0}: {1}'.format(
                scored_pdb_file, err))

        # One score per rigid body, in row_labels order
        list_scores = []
        for rb in list_rbs:
            try:
                list_scores.append(rigid_body_scores[rb])
            except KeyError:
                list_scores.append('-')
        dict_table[pdb_id] = list_scores
    return pd.DataFrame(dict_table, index=row_labels)


def _rigid_body_label(chain_segments):
    '''
    Format the segments of one rigid body as a single row label.

    chain_segments maps chain ID -> list of segments; each segment
    contributes '<seg[0]><chain> <seg[1]><chain>' and the pieces are joined
    with ','.  Segments that cannot be formatted are skipped.
    '''
    seg_strs = []
    for ch in chain_segments:
        ct_seg = 0
        for seg in chain_segments[ch]:
            try:
                seg_ch = [seg[0] + str(ch), seg[1] + str(ch)]
            except (TypeError, IndexError, KeyError):
                continue
            seg_str = ' '.join(seg_ch)
            # NOTE(review): this puts a line break after the 4th segment of
            # a chain and then after every 3rd one - kept as-is, confirm
            # whether every 3rd was intended.
            if ct_seg == 3:
                seg_str += '\n'
                ct_seg = 0
            ct_seg += 1
            seg_strs.append(str(seg_str))
    return ','.join(seg_strs)


def get_smoc_score(em_map,
                   map_resolution,
                   structure,
                   rigid_body_path=None,
                   fragment_length=9,
                   sigma_coeff=0.225,
                   distance_or_fragment_selection='Distance'):
    '''
    Calculate per-residue SMOC scores for one structure against a map.

    Arguments:
        em_map: map passed to ScoringFunctions.SMOC as map_target.
        map_resolution: passed as resolution_densMap.
        structure: structure instance to score.
        rigid_body_path: optional rigid body file (rigid_body_file).
        fragment_length: score window size (win).
        sigma_coeff: passed as sigma_map.
        distance_or_fragment_selection: 'Distance' disables fragment
            scoring; any other value enables it.

    Returns (chain_scores, chain_CA) from ScoringFunctions.SMOC; the middle
    value SMOC returns (chain residues) is discarded.
    '''
    use_fragments = distance_or_fragment_selection != 'Distance'
    smoc_output = ScoringFunctions().SMOC(
        map_target=em_map,
        resolution_densMap=map_resolution,
        structure_instance=structure,
        win=fragment_length,
        rigid_body_file=rigid_body_path,
        sigma_map=sigma_coeff,
        fragment_score=use_fragments,
        sigma_thr=2.5,
        calc_metric='smoc')
    chain_scores, _chain_residues, chain_CA = smoc_output
    return chain_scores, chain_CA

def get_sccc_score(em_map,
                   map_resolution,
                   structure,
                   rigid_body_path=None,
                   sigma_coeff=0.225):
    '''
    Calculate SCCC scores for the rigid bodies of one structure.

    Thin wrapper around ScoringFunctions.get_sccc; sigma_coeff is passed as
    sigma_map.  Returns the (chain_scores, rigid_body_scores) pair it
    produces.
    '''
    scorer = ScoringFunctions()
    sccc_by_chain, sccc_by_rigid_body = scorer.get_sccc(
        map_target=em_map,
        resolution=map_resolution,
        structure_instance=structure,
        rigid_body_file=rigid_body_path,
        sigma_map=sigma_coeff)
    return sccc_by_chain, sccc_by_rigid_body



def import_smoc_scores_from_csv(path):
    '''
    Load a score table previously saved with DataFrame.to_csv.

    The first column is used as the row index.
    '''
    score_table = pd.read_csv(path, index_col=0)
    return score_table


if __name__ == '__main__':
    # Usage: <script> [json_path [set_results]]
    cli_args = sys.argv[1:]
    if not cli_args:
        # No parameter file given: fall back to the bundled test parameters
        main(json_path='./test_data/unittest_args.json')
    elif len(cli_args) >= 2:
        main(json_path=cli_args[0], set_results=cli_args[1])
    else:
        main(json_path=cli_args[0])
