#!/usr/bin/env python3
"""
Create a list of regions by splitting a reference based on the amount of data in bam files.
Uses the `bai` index of the bam files. Useful for submitting jobs of equal size to a cluster.
"""

import sys
import os
import argparse
import time
import logging

import struct
import numpy as np
from scipy import interpolate


# Threshold used when neither -v nor -q is given on the command line.
DEFAULT_LOGGING_LEVEL = logging.INFO
# Upper clamp for the threshold: enough -q flags leave only CRITICAL messages.
MAX_LOGGING_LEVEL = logging.CRITICAL


def setup_logger(verbose_level):
    """Configure root logging with a threshold derived from *verbose_level*.

    Each unit of ``verbose_level`` lowers the threshold by 10 (one standard
    logging level step) from ``DEFAULT_LOGGING_LEVEL``; negative values raise
    it. The resulting threshold is clamped to ``[0, MAX_LOGGING_LEVEL]``.
    Note that ``logging.basicConfig`` does nothing if the root logger already
    has handlers.
    """
    message_format = ('%(levelname)s %(asctime)s [%(module)s:%(lineno)s %(funcName)s] :: '
                      '%(message)s')
    threshold = DEFAULT_LOGGING_LEVEL - 10 * verbose_level
    threshold = min(MAX_LOGGING_LEVEL, threshold)
    threshold = max(0, threshold)
    logging.basicConfig(format=message_format, level=threshold)


def _bai_path(bamfn):
    """Return the BAI index path to use for the BAM file *bamfn*.

    ``<bamfn>.bai`` is preferred (the name this script has always used). If it
    does not exist but ``<bamfn without .bam>.bai`` does (e.g. ``sample.bai``
    next to ``sample.bam``), that path is returned instead. Otherwise
    ``<bamfn>.bai`` is returned so that opening it reports the expected name.
    """
    primary = bamfn + '.bai'
    if os.path.exists(primary):
        return primary
    root, ext = os.path.splitext(bamfn)
    if ext.lower() == '.bam' and os.path.exists(root + '.bai'):
        return root + '.bai'
    return primary


def _read_bai_linear_index(baifn, n_ref_expected):
    """Read the linear index of every reference sequence in a BAI file.

    The binning index is skipped. All BAI integers are little-endian, so
    explicit '<' struct formats are used (native order breaks on big-endian
    hosts).

    Returns:
        A list with one int64 array of linear-index offsets per reference
        (empty when a reference has no linear-index entries).

    Raises:
        ValueError: bad magic string, reference count different from
            *n_ref_expected*, or a truncated/corrupt linear index.
    """
    linear = []
    with open(baifn, 'rb') as fh:
        if fh.read(4) != b'BAI\x01':
            raise ValueError("{}: not a BAI file (bad magic)".format(baifn))

        n_ref, = struct.unpack('<i', fh.read(4))
        if n_ref != n_ref_expected:
            raise ValueError("{}: bam index has {} chroms but fasta index has {}; "
                             "they must have the same number of chroms".format(
                                 baifn, n_ref, n_ref_expected))

        for _ in range(n_ref):
            # skip over the binning index: each bin is a uint32 id, an int32
            # chunk count, then n_chunk pairs of uint64 offsets (16 bytes each)
            n_bin, = struct.unpack('<i', fh.read(4))
            for _ in range(n_bin):
                _bin_id, n_chunk = struct.unpack('<Ii', fh.read(8))
                fh.seek(n_chunk * 16, os.SEEK_CUR)

            # linear index: n_intv uint64 offsets
            n_intv, = struct.unpack('<i', fh.read(4))
            if n_intv < 0:
                raise ValueError("{}: negative linear index size".format(baifn))
            raw = fh.read(n_intv * 8)
            if len(raw) != n_intv * 8:
                raise ValueError("{}: truncated linear index".format(baifn))
            linear.append(np.frombuffer(raw, dtype='<u8').astype(np.int64))
    return linear


def _split_chrom(chrom, length, cumsz, interval_size, target_size):
    """Split one chromosome into regions holding about *target_size* data each.

    *cumsz[i]* is the cumulative data amount (summed over all BAMs) at the
    start of linear-index interval *i*; interval *i* begins at
    ``i * interval_size``.

    Returns a list of ``(chrom, start, end)`` tuples tiling ``[0, length]``.
    Always returns at least one region, even when no data was found.
    """
    # tiny chroms (fewer than two intervals) cannot be interpolated
    if len(cumsz) < 2:
        return [(chrom, 0, length)]

    # sanity check: cumulative data must not decrease at the last point
    assert cumsz[-1] >= cumsz[-2], "decreasing cumulative data size for " + chrom

    pos = np.arange(0, cumsz.shape[0]) * interval_size

    # estimate total data size for the chrom by extrapolating to its length
    data_at = interpolate.interp1d(pos, cumsz, fill_value='extrapolate', kind='linear')
    ds_total = data_at([length])[0]

    # At least one region: previously a chrom with no data gave 0 regions and
    # was silently left out of the output.
    num_regions = max(1, int(np.ceil(ds_total / target_size)))
    if num_regions == 1:
        return [(chrom, 0, length)]

    # approx equal DATA SIZE regions: invert the cumulative curve at evenly
    # spaced data amounts
    pos_at = interpolate.interp1d(cumsz, pos, fill_value='extrapolate', kind='linear')
    dsx = np.linspace(0, ds_total, num=num_regions + 1, endpoint=True, dtype=int)
    cuts = pos_at(dsx).astype(int)
    # ensure we exactly hit the endpoints
    cuts[0] = 0
    cuts[-1] = length

    return [(chrom, cuts[i], cuts[i + 1]) for i in range(num_regions)]


def Main(argv):
    """Print regions that split the reference into chunks of roughly equal BAM data.

    The data density along each reference sequence is estimated from the
    linear index of each BAM's ``.bai`` file and summed over all BAMs. Each
    sequence is then cut where the cumulative amount crosses multiples of
    ``--target-data-size``. Regions are printed to stdout as
    ``chrom<TAB>start<TAB>end``.

    Args:
        argv: argument list to parse, or None to use ``sys.argv[1:]``.

    Returns:
        0 on success. Exits with status 2 when no BAM files are given or a
        size argument is not positive.

    Raises:
        ValueError: a ``.bai`` file is malformed or its number of reference
            sequences differs from the fasta index.
        OSError: a BAM list, the fasta index, or a ``.bai`` cannot be opened.
    """
    tic_total = time.time()

    # parse arguments
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('bamfiles', metavar='BAMFILE', nargs='*')
    parser.add_argument('-L', '--bam-list', nargs='*')
    parser.add_argument('-r', '--reference-fai', help="reference fasta index file", required=True)
    parser.add_argument('-s', '--target-data-size', default='100e6', help="target combined data size of bam files in each region (MB)")
    parser.add_argument('--bai-interval-size', default=16384, type=int, help="Size in baseparis of each interval in the bam index (bai).")
    parser.add_argument('-v', '--verbose', action='count', default=0,
                        help="increase logging verbosity")
    parser.add_argument('-q', '--quiet', action='count', default=0,
                        help="decrease logging verbosity")
    args = parser.parse_args(argv)

    # setup logger
    setup_logger(verbose_level=args.verbose-args.quiet)
    if argv is not None:
        logging.warning('Using passed arguments: '+str(argv))
    logging.info('args: '+str(args))

    # additional argument parsing and datatype handling
    if not args.bamfiles and not args.bam_list:
        logging.error("Must provide an BAMFILE and/or --bam-list argument")
        sys.exit(2)
    if args.bai_interval_size <= 0:
        logging.error("--bai-interval-size must be positive")
        sys.exit(2)
    # Scale before truncating so fractional values (e.g. '0.5') are not
    # rounded down to 0 first.
    # NOTE(review): the quantities compared against this target are sums of
    # differences of BGZF *virtual* offsets (compressed offset << 16), so the
    # effective unit is not literally bytes/MB; kept as-is for compatibility.
    args.target_data_size = int(float(args.target_data_size)*1000000)
    if args.target_data_size <= 0:
        logging.error("--target-data-size must be positive")
        sys.exit(2)
    logging.info('target-data-size: '+str(args.target_data_size)+' bytes')

    # read bam-lists if provided ('#' starts a comment, blank lines ignored)
    if args.bam_list:
        for bamlistfile in args.bam_list:
            with open(bamlistfile, 'r') as fh:
                for line in fh:
                    name = line.split('#')[0].strip()
                    if name:
                        args.bamfiles.append(name)

    # read the reference fasta index (column 1: name, column 2: length)
    fai_chrom = []
    fai_len = []
    with open(args.reference_fai, 'r') as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            fields = line.split('\t')
            fai_chrom.append(str(fields[0]))
            fai_len.append(int(fields[1]))

    # per chrom: number of linear-index intervals covering it (ceil division)
    n_intvs = [-(-length // args.bai_interval_size) for length in fai_len]
    # per chrom: cumulative data size at each interval, summed over all bams
    icumsz = [np.zeros((n,), dtype=np.int64) for n in n_intvs]

    for bamfn in args.bamfiles:
        baifn = _bai_path(bamfn)
        logging.info("processing: "+baifn)
        linear = _read_bai_linear_index(baifn, len(fai_len))
        for ci, ioff in enumerate(linear):
            if ioff.shape[0] == 0:
                continue
            n = icumsz[ci].shape[0]
            delta = ioff - ioff[0]
            if delta.shape[0] < n:
                # intervals past the last indexed one: extend with a value just
                # above the last offset (vectorized; was a quadratic append loop)
                pad = np.full(n - delta.shape[0], delta[-1] + 1, dtype=np.int64)
                delta = np.concatenate((delta, pad))
            elif delta.shape[0] > n:
                logging.warning("{}: linear index of {} has {} intervals, more than the {} "
                                "implied by the fasta index; extra intervals ignored".format(
                                    baifn, fai_chrom[ci], delta.shape[0], n))
            icumsz[ci] += delta[:n]

    ## make the list of regions
    regions = []
    for chrom, length, cumsz in zip(fai_chrom, fai_len, icumsz):
        regions.extend(_split_chrom(chrom, length, cumsz,
                                    args.bai_interval_size, args.target_data_size))

    ## Output regions file
    for r in regions:
        print(*r, sep='\t')

    logging.info("Number of chroms: {}".format(len(fai_len)))
    logging.info("Number of splits: {}".format(len(regions)-len(fai_len)))
    logging.info("Number of regions: {}".format(len(regions)))

    logging.info("Done: {:.2f} sec elapsed".format(time.time()-tic_total))
    return 0



#########################################################################
# Main loop hook... if run as script run main, else this is just a module
if __name__ == '__main__':
    # If a module-level TESTING_ARGS global has been defined, parse it instead
    # of the command line; a missing global yields None, which makes argparse
    # fall back to sys.argv.
    sys.exit(Main(argv=globals().get('TESTING_ARGS')))