import numpy as np

from slic.core.acquisition import SFAcquisition
from slic.core.acquisition.sfacquisition import BSChannels, transpose_dicts, print_response
from slic.core.task import DAQTask


class CTAAcquisition(SFAcquisition):
    """SFAcquisition variant whose acquisitions are driven by a CTA sequencer.

    ``acquire`` (re)starts the CTA sequencer and asks it for the pulse ID at
    which it started. It then retrieves data for ``n_pulses`` consecutive pulse
    IDs from there on. If ``n_block_size`` is set, the retrieval is split into
    several requests of at most ``n_block_size`` pulse IDs each.
    """

    def __init__(self, cta, *args, n_block_size=None, **kwargs):
        """Create the acquisition.

        Parameters
        ----------
        cta
            Sequencer object. It must provide ``stop()``, ``run()`` and
            ``get_start_pid()``.
        *args, **kwargs
            Forwarded unchanged to ``SFAcquisition.__init__``.
        n_block_size : int or None
            Maximum number of pulse IDs per retrieval request. ``None``
            retrieves all pulse IDs in a single request.
        """
        self.cta = cta
        self.n_block_size = n_block_size
        # Zero-argument super() resolves to SFAcquisition; the bare `super`
        # type's __init__ would reject a CTAAcquisition instance.
        super().__init__(*args, **kwargs)


    def acquire(self, filename, data_base_dir=None, detectors=None, channels=None, pvs=None, scan_info=None, n_pulses=100, n_repeat=1, is_scan_step=False, wait=True):
        """Run the CTA sequencer and retrieve data for the pulses it produced.

        Parameters
        ----------
        filename : str
            Output file name. A falsy value or ``"/dev/null"`` skips the
            retrieval and returns None. The run number is still
            advanced or reused in that case.
        data_base_dir
            Accepted for signature compatibility; not used by this method.
        detectors, channels, pvs : list or None
            What to record. ``None`` selects ``self.default_detectors`` /
            ``self.default_channels`` / ``self.default_pvs`` respectively.
        scan_info
            Passed on to ``client.set_config``.
        n_pulses : int
            Number of consecutive pulse IDs to retrieve, starting at the CTA
            start pulse ID.
        n_repeat : int
            Only 1 is supported.
        is_scan_step : bool
            If True, keep the client's current run number instead of
            advancing to the next one.
        wait : bool
            If True, wait for the task to finish. A KeyboardInterrupt during
            the wait is caught and reported.

        Returns
        -------
        DAQTask or None
            The task wrapping the acquisition. Its function returns the
            filenames collected from all retrieval requests. Returns None if
            the retrieval was skipped.

        Raises
        ------
        NotImplementedError
            If ``n_repeat != 1``.
        """
        if n_repeat != 1:
            raise NotImplementedError("Repetitions are not implemented") #TODO

        if not is_scan_step:
            run_number = self.client.next_run()
            print(f"Advanced run number to {run_number}.")
        else:
            run_number = self.client.run_number
            print(f"Continuing run number {run_number}.")

        if not filename or filename == "/dev/null":
            print("Skipping retrieval since no filename was given.")
            return

        if detectors is None:
            print("No detectors specified, using default detector list.")
            detectors = self.default_detectors

        if pvs is None:
            print("No PVs specified, using default PV list.")
            pvs = self.default_pvs

        if channels is None:
            print("No channels specified, using default channel list.")
            channels = self.default_channels

        # validate the requested BS channels before configuring anything
        bschs = BSChannels(*channels)
        bschs.check()

        client = self.client
        client.set_config(n_pulses, filename, detectors=detectors, channels=channels, pvs=pvs, scan_info=scan_info)

        def _acquire():
            # restart the sequencer so that the start pulse ID refers to this run
            self.cta.stop()
            self.cta.run()

            start_pid = self.cta.get_start_pid()
            print("CTA start pid:", start_pid)

            stop_pid = start_pid + n_pulses
            pids = np.arange(start_pid, stop_pid)

            pids_blocks = split(pids, self.n_block_size)

            # Collect the filenames of every block, not just the last one.
            # Assumes transpose_dicts yields an iterable per key; confirm
            # against slic's sfacquisition.
            filenames = []
            for pb in pids_blocks:
                res = self.retrieve(filename, pb, run_number=run_number)

                res = transpose_dicts(res)
                filenames.extend(res.pop("filenames"))
                print_response(res)

            return filenames

        def stopper():
            client.stop()
            self.cta.stop()

        task = DAQTask(_acquire, stopper=stopper, filename=filename, hold=False)
        self.current_task = task

        if wait:
            try:
                task.wait()
            except KeyboardInterrupt:
                print("Stopped current DAQ task:")

        return task





def split(a, block_size):
    """Split ``a`` into consecutive chunks of at most ``block_size`` elements.

    Only the last chunk may be shorter than ``block_size``. An empty ``a``
    yields a single empty chunk.

    Parameters
    ----------
    a : array_like
        Sequence to split, e.g. an array of pulse IDs.
    block_size : int or None
        Maximum chunk length. ``None`` means "do not split".

    Returns
    -------
    list
        ``[a]`` if ``block_size`` is None. Otherwise, the list of
        sub-arrays produced by ``np.array_split``.

    Raises
    ------
    ValueError
        If ``block_size`` is not positive.
    """
    if block_size is None:
        return [a]
    if block_size <= 0:
        # np.arange would raise ZeroDivisionError for 0 and silently return no
        # split points for negative values
        raise ValueError(f"block_size must be positive, got {block_size!r}")
    # split points must not start at 0, otherwise the first chunk would be empty
    indices = np.arange(block_size, len(a), block_size)
    return np.array_split(a, indices)





if __name__ == "__main__":
    # Example usage; requires access to the CTA sequencer and the DAQ backend.
    from slic.devices.timing.events import CTASequencer

    # number of pulses to acquire (same as the default of CTAAcquisition.acquire)
    n_pulses = 100

    cta = CTASequencer("SAT-CCTA-ESE")
    daq = CTAAcquisition(cta, "maloja", "p19509", default_channels=["SAT-CVME-TIFALL5:EvtSet"], append_user_tag_to_data_dir=True)

    # presumably makes the sequencer run for n_pulses pulses; confirm against CTASequencer
    cta.cfg.repetitions = n_pulses # etc. etc.
    # daq.acquire("test", n_pulses=n_pulses)