import struct
import re
import numpy as np
import time


def getplots(s):
    """Return the sorted histogram names contained in result dictionaries.

    ``s`` is either a single result dictionary or a list/tuple of them.
    For a dictionary, every key except the metadata keys listed below is
    treated as a histogram name.  For a sequence, only names present in
    every element are returned (an empty sequence yields ``[]``).
    Any other input type yields ``None``.
    """
    # Metadata keys written/read alongside the histograms; never plots.
    excluded = {
        'SHA', 'chi2a', 'iteration', 'msg', 'value', 'time',
        'ndo', 'randy', 'xi', 'runtime',
    }
    if isinstance(s, dict):
        return sorted(set(s) - excluded)
    if isinstance(s, (list, tuple)):
        if not s:
            # set.intersection() with no arguments raises TypeError.
            return []
        common = set.intersection(*(set(getplots(item)) for item in s))
        return sorted(common)
    return None


def read_record(fp, typ):
    """Read one Fortran-style unformatted sequential record from ``fp``.

    A record is a little-endian uint32 byte count, the payload, and the
    same byte count repeated.  The payload is decoded as repeated struct
    code ``typ``:

    * ``typ == 'c'``: returned as a text string (on Python 3 the bytes are
      decoded as latin-1, so each byte maps to exactly one character);
    * otherwise: a single value if the payload holds exactly one item,
      else a tuple of all items (an empty tuple for an empty record).

    Raises KeyError if the leading and trailing byte counts differ, and
    struct.error on truncated input or a payload that is not a whole
    number of items.
    """
    head = struct.unpack("<I", fp.read(4))[0]
    payload = fp.read(head)
    tail = struct.unpack("<I", fp.read(4))[0]
    if head != tail:
        raise KeyError("Record is not properly close %d v %d" % (head, tail))

    # Floor division: '/' gives a float on Python 3 and breaks typ * count.
    count = head // struct.calcsize(typ)
    values = struct.unpack("<" + typ * count, payload)
    if typ == 'c':
        text = b''.join(values)
        if not isinstance(text, str):
            # Python 3: bytes -> str.  On Python 2 bytes *is* str.
            text = text.decode('latin-1')
        return text
    return values[0] if count == 1 else values


def write_record(fp, typ, content):
    """Write ``content`` to ``fp`` as one Fortran-style unformatted record.

    The record is a little-endian uint32 byte count, the payload, and the
    byte count again.

    * ``typ == 'c'``: ``content`` is a string (or a list/tuple of string
      pieces); text is encoded as latin-1 so each character is one byte,
      mirroring ``read_record``.  An empty string writes an empty record.
    * otherwise: ``content`` is a scalar, a list/tuple of scalars, or a
      numpy array (flattened in C order); each item is packed with the
      little-endian struct code ``typ``.
    """
    def to_bytes(text):
        # Python 2 str is already bytes; Python 3 str / Python 2 unicode
        # must be encoded.
        if isinstance(text, bytes):
            return text
        return text.encode('latin-1')

    if typ == 'c':
        if isinstance(content, (list, tuple)):
            body = b''.join(to_bytes(piece) for piece in content)
        else:
            body = to_bytes(content)
    else:
        if isinstance(content, np.ndarray):
            items = content.ravel().tolist()
        elif isinstance(content, (list, tuple)):
            items = list(content)
        else:
            items = [content]
        # '<' means standard sizes with no alignment padding, so one pack
        # call yields the same bytes as packing each item separately.
        body = struct.pack("<" + typ * len(items), *items)

    size = struct.pack('<I', len(body))
    fp.write(size)
    fp.write(body)
    fp.write(size)


def guess_version(fp, inttype='i'):
    """Determine the file-format version and integer struct code of ``fp``.

    Files with a McMule header (17 bytes: a record whose payload is
    ``' McMule  '``) carry a version record such as ``'v3N'``.  The number
    is the version, and ``N``/``L`` select 4-byte ('i') or 8-byte ('q')
    integers.  In that case ``fp`` is left positioned just after the
    version record, i.e. at the first data record.

    Header-less files are version 1 or 2, read with ``inttype``.  They are
    told apart by probing for the extra double record that version 2
    stores after the histogram data.  For these files ``fp`` is rewound
    to 0 before returning.

    Returns ``(version, inttype)``.  Raises ValueError for a McMule header
    whose version string is not recognised.
    """
    fp.seek(0)
    magic = fp.read(17).replace(b'\0', b'')
    if magic == b'\t McMule  \t':
        version = read_record(fp, 'c').strip()
        match = re.match(r"v(\d+)([NL])", version)
        if match is None:
            raise ValueError("Unrecognised McMule version string %r" % (version,))
        vers, intiness = match.groups()
        return int(vers), {'N': 'i', 'L': 'q'}[intiness]

    # Header-less file: v1 vs. v2.  The offsets appear to add up record
    # sizes (4-byte markers on both sides of each payload) for a 5-character
    # SHA record, three doubles, a 17*50-double xi record and three integer
    # records before (nrq, nrbins); then (nrq, nrbins), bounds (2*nrq
    # doubles), names (6*nrq chars) and quant (2*nrbins*nrq doubles).
    # NOTE(review): assumes those fixed sizes -- confirm against the writer.
    intsize = struct.calcsize(inttype)
    fp.seek(6893 + 3*intsize)
    nrq, nrbins = read_record(fp, inttype)
    fp.seek(6925 + 5*intsize + 22*nrq + 16*nrbins*nrq)
    try:
        read_record(fp, 'd')
    except (KeyError, struct.error):
        # No well-formed record there: only v2 stores one at this offset.
        version = 1
    else:
        version = 2
    fp.seek(0)
    return version, inttype


def importvegas(filename="", fp=None, inttype='i'):
    """Read a McMule vegas state file and return its contents as a dict.

    Pass either ``filename`` or an already-open binary file object ``fp``.
    A file opened here from ``filename`` is closed again, even on error.
    A caller-supplied ``fp`` is never closed.  ``inttype`` is the struct
    code of the integer records; for files with a McMule header it is
    replaced by the one the header specifies (see ``guess_version``).

    Returned keys:
      'SHA', 'iteration', 'time', 'msg' -- header fields ('time' is -1 and
          'msg' is "" for version-1 files);
      'value' -- array [si/swgt, sqrt(1/swgt)] built from the stored sums;
      'chi2a' -- (schi - si**2/swgt)/(iteration - 1), or -1 when
          iteration <= 1;
      one entry per named histogram (only when iteration > 1): an (N, 3)
          array with columns (bin centre, value, error).  Version >= 3
          files additionally have an underflow row at x = -inf and an
          overflow row at x = +inf.

    Raises ValueError if neither ``filename`` nor ``fp`` is given.
    Malformed records raise KeyError or struct.error from ``read_record``.
    """
    opened_here = fp is None and bool(filename)
    if opened_here:
        fp = open(filename, 'rb')
    if fp is None:
        raise ValueError("importvegas needs a filename or an open file object")
    try:
        return _read_vegas_stream(fp, inttype)
    finally:
        if opened_here:
            fp.close()


def _read_vegas_stream(fp, inttype):
    """Parse all records of an open vegas file into the result dict."""
    version, inttype = guess_version(fp, inttype)
    dic = {}

    sha = read_record(fp, 'c')
    it = read_record(fp, inttype)
    read_record(fp, inttype)   # ndo: read to advance, not returned
    si = read_record(fp, 'd')
    swgt = read_record(fp, 'd')
    schi = read_record(fp, 'd')
    read_record(fp, 'd')       # xi grid: read to advance, not returned
    read_record(fp, inttype)   # randy: read to advance, not returned

    if version > 2:
        nrq, nrbins, namelen = read_record(fp, inttype)
    else:
        # Older formats use fixed 6-character histogram names.
        nrq, nrbins = read_record(fp, inttype)
        namelen = 6

    bounds = np.reshape(read_record(fp, 'd'), (-1, nrq))
    raw_names = read_record(fp, 'c')
    # Remove NUL padding before stripping so blank padding next to it
    # is removed as well.
    names = [
        raw_names[namelen*i:namelen*(i + 1)].replace('\x00', '').strip()
        for i in range(len(raw_names) // namelen)
    ]
    quant = np.array(read_record(fp, 'd'))

    if version > 1:
        dic['time'] = read_record(fp, 'd')
        dic['msg'] = read_record(fp, 'c').replace('\x00', '').strip()
    else:
        dic['time'] = -1
        dic['msg'] = ""

    dic['SHA'] = sha
    dic['iteration'] = it
    dic['value'] = np.array([si/swgt, np.sqrt(1/swgt)])
    dic['chi2a'] = (schi - si**2/swgt)/(it - 1) if it > 1 else -1

    if it > 1:
        _add_vegas_histograms(dic, version, it, nrq, nrbins, bounds, names,
                              quant)
    return dic


def _add_vegas_histograms(dic, version, it, nrq, nrbins, bounds, names,
                          quant):
    """Decode the interleaved accumulators in ``quant`` into (x, y, err) tables."""
    # v3 stores an underflow and an overflow bin around the regular bins.
    has_overflow = version >= 3
    nrbinsuo = nrbins + 2 if has_overflow else nrbins

    def with_overflow(under, over, regular):
        if has_overflow:
            return np.concatenate(([under], regular, [over]))
        return regular

    # Row 0 of bounds holds the lower, row 1 the upper limit per quantity.
    delta = (bounds[1] - bounds[0])/nrbins

    for i in range(nrq):
        name = names[i]
        if not name or abs(delta[i]) < 1e-15:
            continue
        # Avoid clobbering an existing key (earlier histogram or metadata).
        while name in dic:
            name += '2'

        widths = with_overflow(1, 1, [delta[i]]*nrbins)
        x = with_overflow(
            -np.inf, np.inf,
            np.around(bounds[0, i] + delta[i]*(0.5 + np.arange(nrbins)), 10),
        )
        # First half of quant: per-bin sums, interleaved by quantity with
        # stride nrq; second half: the matching squared accumulations.
        sums = quant[i:nrq*nrbinsuo:nrq]
        squares = quant[nrq*nrbinsuo + i::nrq]
        y = sums/it/widths
        e = np.sqrt((squares - sums**2/it)/it/(it - 1))/widths
        dic[name] = np.column_stack((x, y, e))


def exportvegas(dic, filename="", fp=None):
    """Write a result dictionary as a McMule version-3 ('v3N') vegas file.

    Pass either ``filename`` or an already-open binary file object ``fp``.
    A file opened here is closed (and thus flushed) before returning, even
    on error.  A caller-supplied ``fp`` is left open.  ``dic`` itself is
    not modified.

    Required keys: 'value' (a pair y, e) and 'chi2a' (nested single-element
    lists/tuples are unwrapped).  Optional keys fall back to defaults:
    'SHA', 'iteration', 'ndo', 'xi', 'randy', 'runtime', 'msg'.  Every other
    key (see ``getplots``) is a histogram: an (N, 3) array of
    (bin centre, value, error) rows with equally spaced centres.  Under/
    overflow rows at x = -inf / +inf are optional; missing ones are added
    as zero rows.  All histograms must have the same number of regular bins.

    Raises ValueError if neither ``filename`` nor ``fp`` is given, if there
    are no histograms, or if the histograms' bin counts differ.
    """
    opened_here = fp is None and bool(filename)
    if opened_here:
        fp = open(filename, 'wb')
    if fp is None:
        raise ValueError("exportvegas needs a filename or an open file object")
    try:
        # Shallow copy so defaults are not written back into the caller's dict.
        _write_vegas_stream(dict(dic), fp)
    finally:
        if opened_here:
            fp.close()


def _write_vegas_stream(dic, fp):
    """Write the header, accumulators and histograms of ``dic`` to ``fp``."""
    # McMule magic header: a record of length 9 holding ' McMule  '.
    fp.write(b"\t\000\000\000 McMule  \t\000\000\000")

    defaults = {
        'SHA': '00000',
        'iteration': 2,
        'ndo': -1,
        'randy': -1,
        'msg': "Warning: Generated with Python",
    }
    for key, value in defaults.items():
        dic.setdefault(key, value)
    if 'xi' not in dic:
        dic['xi'] = -np.ones(17*50)
    if 'runtime' not in dic:
        # time.clock was removed in Python 3.8; process_time is its
        # CPU-time successor.  Fall back to clock on Python 2.
        # NOTE(review): importvegas stores this same record as 'time'.
        clock = getattr(time, 'process_time', None) or time.clock
        dic['runtime'] = clock()

    chi2a = dic['chi2a']
    while isinstance(chi2a, (list, tuple)):
        chi2a = chi2a[0]

    plots = getplots(dic)
    if not plots:
        raise ValueError("exportvegas: dictionary contains no histograms")

    it = dic['iteration']
    # float() guards against Python 2 integer division for integer input.
    y, e = (float(v) for v in dic['value'])

    write_record(fp, 'c', 'v3N       ')
    write_record(fp, 'c', dic['SHA'])
    write_record(fp, 'i', it)
    write_record(fp, 'i', dic['ndo'])
    write_record(fp, 'd', y/e**2)
    write_record(fp, 'd', 1/e**2)
    write_record(fp, 'd', chi2a*(it - 1) + y**2/e**2)
    write_record(fp, 'd', list(np.ravel(dic['xi'])))
    write_record(fp, 'i', dic['randy'])

    # Count only regular bins: rows at x = +-inf are under/overflow and are
    # accounted for separately by the "+ 2" below.
    first = np.asarray(dic[plots[0]], dtype=float)
    nrbins = int(np.isfinite(first[:, 0]).sum()) if len(first) else 0
    nrq = len(plots)
    namelen = max(6, max(len(name) for name in plots))
    write_record(fp, 'i', [nrq, nrbins, namelen])

    stride = nrbins + 2   # regular bins plus underflow and overflow
    quant = np.zeros(2*stride*nrq)
    minv = []
    maxv = []
    for i, name in enumerate(plots):
        pp = np.array(dic[name], dtype=float)   # copy; never mutate input
        if len(pp) == 0:
            minv.append(0)
            maxv.append(0)
            continue
        if pp[0, 0] != -np.inf:
            pp = np.concatenate(([[-np.inf, 0, 0]], pp))
        if pp[-1, 0] != np.inf:
            pp = np.concatenate((pp, [[np.inf, 0, 0]]))
        if len(pp) != stride:
            raise ValueError(
                "exportvegas: histogram %r has %d bins, expected %d"
                % (name, len(pp) - 2, nrbins)
            )

        width = pp[2, 0] - pp[1, 0]
        lower = pp[1, 0] - width/2
        minv.append(lower)
        maxv.append(lower + nrbins*width)

        widths = np.array([1] + [width]*nrbins + [1])
        sums = widths*it*pp[:, 1]
        quant[i:nrq*stride:nrq] = sums
        if np.all(pp[:, 2] == 0):
            quant[nrq*stride + i::nrq] = sums**2/it
        else:
            quant[nrq*stride + i::nrq] = widths**2*it*(
                pp[:, 2]**2*(it - 1) + pp[:, 1]**2
            )

    write_record(fp, 'd', minv + maxv)
    write_record(fp, 'c', "".join(name.ljust(namelen) for name in plots))
    write_record(fp, 'd', list(quant))

    write_record(fp, 'd', dic['runtime'])
    write_record(fp, 'c', dic['msg'])