#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 19 17:10:32 2018

@author: lee
with batch normalization
"""
import sys
import time
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import tifffile as tiff
from tensorflow.python.ops import gen_image_ops
import argparse
#import tflearn
import makedata3D_train_single as Input
import makedata3D_test_single as InputT
from makedata3D_test_single import normalize
#import random
import tqdm
import math
import itertools
from tensorflow.compat.v1 import ConfigProto, InteractiveSession
import argparse




# Session configuration handed to tf.Session(config=config) in Unet.train.
# It is left at the ConfigProto defaults; the GPU-memory options below are
# kept for reference only and are disabled.
config =ConfigProto()
#config.gpu_options.per_process_gpu_memory_fraction = 0.9#0.4
#config.gpu_options.allow_growth=True
#session=InteractiveSession(config=config)

# Run modes (selected with --mode):
#   TR  : train
#   VL  : validation, with known ground truth
#   TS  : test, no ground truth
#   TSS : test for big data larger than 400M (32 bit)


# NOTE(review): a ``global`` statement at module level has no effect;
# ``mode`` is assigned further down from the parsed command-line arguments.
global mode

import argparse


def parse_arguments(argv=None):
    """Parse the command-line options that configure an RLN run.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``mode``, ``train_iter_num``,
        ``test_iter_num``, ``train_batch_size``, ``pre_train_batch_size``,
        ``test_batch_size``, ``num_input_channels``, ``num_output_channels``,
        ``normal_pmin`` and ``normal_pmax``.
    """
    # (flag, type, default, help) for every supported option.
    option_table = (
        ("--mode", str, "TS", "Mode description"),
        ("--train_iter_num", int, 10000, "Training iteration number"),
        ("--test_iter_num", int, 9, "Test iteration number"),
        ("--train_batch_size", int, 4, "Training batch size"),
        ("--pre_train_batch_size", int, 1, "Pre-training batch size"),
        ("--test_batch_size", int, 1, "Test batch size"),
        ("--num_input_channels", int, 1, "Number of input channels"),
        ("--num_output_channels", int, 1, "Number of output channels"),
        ("--normal_pmin", float, 0.01, "Normal pmin"),
        ("--normal_pmax", float, 99.5, "Normal pmax"),
    )
    parser = argparse.ArgumentParser(description="RLN")
    for flag, value_type, default, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args(argv)


args = parse_arguments()

# Copy the parsed command-line options into module-level settings that the
# rest of the script reads directly.
mode = args.mode
train_iter_num = args.train_iter_num
test_iter_num = args.test_iter_num
train_batch_size = args.train_batch_size
pre_train_batch_size = args.pre_train_batch_size
test_batch_size = args.test_batch_size
num_input_channels = args.num_input_channels
num_output_channels = args.num_output_channels
normal_pmin = args.normal_pmin
normal_pmax = args.normal_pmax

# Fixed locations (not configurable from the command line).
data_dir = '/data/'
train_model_path='/data/train/model_rl/'
train_output='/data/train/output_rl/'
test_output='/data/test/output_rl/'
log_dir = "/data/logs/"
image_dim=3


# Size figure used by Unet.test_stitch when choosing the tile dimensions.
crop_data_size=320

# Make sure the output directories exist. exist_ok=True avoids the
# check-then-create race of a separate os.path.exists() test.
for _output_dir in (train_model_path, train_output, test_output):
    os.makedirs(_output_dir, exist_ok=True)
# Default for the ``eps`` argument of Unet.batch_norm (10e-5 == 1e-4).
EPS = 10e-5


# Two element-wise ops with hand-written (custom) gradients
@tf.custom_gradient
def tensor_div(a,x1):
    """Divide ``a`` by ``x1 + 0.001`` using a hand-written backward pass.

    Forward:  ``tf.div_no_nan(a, x1 + 0.001)``.
    Backward: ``a`` receives the upstream gradient unchanged and ``x1``
    receives ``-quotient**2 * dy``. These are the values the custom gradient
    returns; they are not the analytic derivatives of the forward expression.

    Returns:
        The quotient tensor (the gradient function is consumed by
        ``tf.custom_gradient``).
    """
    quotient = tf.div_no_nan(a, x1 + 0.001)

    def grad(upstream):
        grad_a = upstream
        grad_x1 = -tf.square(quotient) * upstream
        return grad_a, grad_x1

    return quotient, grad
        
@tf.custom_gradient
def tensor_mul(a,x):
    """Multiply ``a`` and ``x`` element-wise using a hand-written backward pass.

    Backward: ``a`` receives the upstream gradient unchanged and ``x``
    receives ``dy * a``.
    """
    product = tf.multiply(a, x)

    def grad(upstream):
        grad_a = upstream
        grad_x = upstream * (a)
        return grad_a, grad_x

    return product, grad


def chan_ave(x):
    """Re-weight the channels of ``x`` with a learned per-channel gate.

    The mean of ``x`` over axes 0-3 (dimensions kept) is passed through a
    linear dense layer with as many units as axis 4 has entries, and the
    result is multiplied back onto ``x``.
    """
    channel_means = tf.reduce_mean(x, axis=[0, 1, 2, 3], keep_dims=True)
    print(channel_means)
    channel_count = channel_means.shape[4]
    gate = tf.layers.dense(inputs=channel_means, units=channel_count, activation=None)
    return gate * x


#@tf.custom_gradient
def s_sigmoid(x):
    """Activation used by the network; currently plain softplus of ``x``."""
    activated = tf.nn.softplus(x)
    return activated

#@tf.custom_gradient
def s_sigmoid1(x):
    """Second activation alias; currently plain softplus of ``x``."""
    activated = tf.nn.softplus(x)
    return activated

class Unet:
    def __init__(self):
        """Declare the attributes of the network with empty placeholder values.

        No graph is built here; ``set_up_unet`` assigns the attributes that
        the train/validation/test methods use.
        """
        # Inputs and their reshaped versions.
        self.input_image = dict()
        self.ground_truth = dict()
        self.cast_image = None
        self.cast_ground_truth = None
        self.is_traing = None
        self.m = None

        # Losses, optimizer step and prediction-related tensors.
        self.loss = self.loss_square = self.loss_all = self.train_step = None
        self.prediction = self.correct_prediction = self.accuracy = None

        # Per-layer bookkeeping; ``w`` holds the convolution kernels by key.
        self.result_conv = dict()
        self.result_relu = dict()
        self.result_from_contract_layer = dict()
        self.w = dict()
        self.a = self.b = None

        # Metric intermediates.
        self.sub_diff = self.mse_square = self.mse = None
        self.mean_prediction = self.mean_gt = None
        self.sigma_prediction = self.sigma_gt = self.sigma_cross = None
        self.SSIM_1 = self.SSIM_2 = self.SSIM_3 = self.SSIM_4 = self.SSIM = None

        # These two start out as one-element lists, not as None.
        self.learning_rate = [None]
        self.prediction_min = [None]
        self.summary = None
 
    def tf_fspecial_gauss(self,size, sigma1):
        """Build a peak-normalised isotropic 3-D Gaussian as a TF tensor.

        Args:
            size: sequence whose first three entries are the kernel extents
                along the three spatial axes.
            sigma1: standard deviation of the Gaussian.

        Returns:
            float32 tensor of shape ``(size[0], size[1], size[2], 1, 1)``
            divided by its maximum.
        """
        extents = tuple(slice(-n//2 + 1, n//2 + 1) for n in (size[0], size[1], size[2]))
        coordinate_grids = np.mgrid[extents]

        # Append two singleton axes to each grid and turn it into a constant.
        x, y, z = [
            tf.constant(grid[..., np.newaxis, np.newaxis], dtype=tf.float32)
            for grid in coordinate_grids
        ]

        squared_radius = x**2 + y**2+z**2
        g = tf.exp(-(squared_radius/(2.0*sigma1**2)))
        return g/tf.reduce_max(g)

        
    def init_w(self, shape, name,stddev=1.0):  # a stddev of 0.5 or 1 both work
        """Create a variable initialised from a zero-mean truncated normal.

        Args:
            shape: shape of the variable.
            name: name given to the ``tf.Variable``.
            stddev: standard deviation of the truncated normal.

        Returns:
            The new ``tf.Variable`` (created inside the ``init_w`` name scope).
        """
        with tf.name_scope('init_w'):
            initial = tf.random.truncated_normal(
                shape=shape, mean=0.0, stddev=stddev, dtype=tf.float32)
            return tf.Variable(initial_value=initial, name=name)

    def gaussian_ker(self,shape,name,stddev=2):
        """Create a conv kernel variable initialised with Gaussians of four widths.

        Four Gaussians with sigma ``stddev`` times 2, 0.5, 1 and 1.5 (in that
        channel order) are concatenated on the last axis, tiled ``shape[3]``
        times along the input-channel axis and scaled by a random factor in
        [0.7, 1) drawn per input channel. The last axis therefore always has
        four entries; ``shape[4]`` is not consulted.

        Args:
            shape: kernel shape; entries 0-2 give the spatial extent and
                entry 3 the number of input channels.
            name: name given to the ``tf.Variable``.
            stddev: base standard deviation.
        """
        sigma_factors = (0.5, 1, 1.5, 2)
        narrow, base, broad, broadest = [
            self.tf_fspecial_gauss(shape, stddev*factor) for factor in sigma_factors
        ]
        channel_gain = tf.random_uniform(
            [1,1,1,shape[3],1], minval=0.7, maxval=1, dtype=tf.float32)
        stacked = tf.concat([broadest, narrow, base, broad], -1)
        tiled = tf.tile(stacked, [1,1,1,shape[3],1])
        return tf.Variable(initial_value=tiled*channel_gain, name=name)

    def gaussian_ker1(self,shape,name,stddev=1):
        """Create a kernel variable by stacking single kernels along axis -2.

        ``self.gaussian_ker_single(shape, stddev)`` is evaluated once per
        input channel (at least once) and the results are concatenated on
        the second-to-last axis.

        NOTE(review): ``gaussian_ker_single`` is not defined in the part of
        the class visible here, so calling this method presumably raises
        AttributeError unless it is provided elsewhere - confirm.
        """
        copies = max(shape[3], 1)
        pieces = [self.gaussian_ker_single(shape, stddev) for _ in range(copies)]
        stacked = pieces[0] if copies == 1 else tf.concat(pieces, -2)
        print(stacked.shape)
        return tf.Variable(initial_value=stacked, name=name)

    def srelu(self,x):
        """ReLU applied to ``x/alpha + beta`` with trainable scalars alpha, beta.

        Both scalars are new variables initialised to 1.0 on every call.
        """
        alpha = tf.Variable(tf.constant(1.0), name='alpha', trainable=True)
        beta = tf.Variable(tf.constant(1.0), name='beta', trainable=True)
        shifted = x/alpha+beta
        return tf.nn.relu(shifted)

    def leaky_relu(self,x,name='leaky_relu'):
        """Apply the network's activation (``s_sigmoid``) to ``x``.

        Despite the name this is not a leaky ReLU: the result is
        ``s_sigmoid(x)``. ``name`` is accepted for interface compatibility
        and is not used.

        The previous body also built two tensors that were never used
        (a softplus and ``sigmoid(x/2)*2``); they only added dead ops to the
        graph and have been removed. Original author's note on the sigmoid
        variant: a scale of 3 was not as good as 2, and 1 was not as good as
        2 either.
        """
        return s_sigmoid(x)
    ############### batch normalization  ###################
    @staticmethod
    def batch_norm(x,is_training, eps=EPS, name='BatchNorm3d'):
        """Batch-normalise ``x`` with ``tf.layers.batch_normalization``.

        Args:
            x: tensor to normalise.
            is_training: forwarded as the layer's ``training`` argument.
            eps: accepted but not forwarded to the layer.
            name: accepted but not forwarded to the layer.
        """
        normalized = tf.layers.batch_normalization(inputs=x, training=is_training)
        return normalized
  
      
    @staticmethod
    def copy_and_crop_and_merge(result_from_downsampling, result_from_upsampling):
        """Concatenate the two tensors along the last axis (no cropping is done)."""
        merged = tf.concat(
            values=[result_from_downsampling, result_from_upsampling],
            axis=-1,
        )
        return merged

    def resize3D(self,x,shape):
        """Add each element to its predecessor along axes 1, 2 and 3 in turn.

        For each of the three axes the tensor is concatenated with a slab of
        zeros once in front and once behind, the two results are summed, and
        the final tensor is cropped back to ``shape``.

        Args:
            x: 5-D tensor whose dimensions match ``shape``.
            shape: the five dimensions ``(N, D, H, W, C)`` of ``x``.

        Returns:
            Tensor sliced to ``[:, :D, :H, :W, :]``.
        """
        N, D, H, W, C = (tf.cast(shape[i], tf.int32) for i in range(5))

        # Zero slabs sized for the tensor as it grows by one along each axis.
        slab_d = tf.zeros([N,1,H,W,C])
        slab_h = tf.zeros([N,D+1,1,W,C])
        slab_w = tf.zeros([N,D+1,H+1,1,C])

        summed = x
        for axis, slab in ((1, slab_d), (2, slab_h), (3, slab_w)):
            shifted = tf.concat([slab, summed], axis=axis)
            padded = tf.concat([summed, slab], axis=axis)
            summed = shifted + padded
        return summed[:,:D,:H,:W,:]


    def get_SSIM(self,gt_label, dl_op,max_val=1):
        """Compute a single global SSIM value between two tensors.

        Means, variances and the covariance are taken over all elements of
        the tensors (no sliding window).

        Args:
            gt_label: first tensor.
            dl_op: second tensor, same shape as ``gt_label``.
            max_val: dynamic range of the data; scales the two stabilising
                constants.

        Returns:
            Scalar tensor with the SSIM value.
        """
        # Stabilising constants C1 = 1e-4 * L^2 and C2 = 9e-4 * L^2. The
        # previous code wrote ``9e-4**max_val*max_val`` (power instead of
        # product) in three places, which is only correct for max_val == 1.
        c1 = 1e-4*max_val*max_val
        c2 = 9e-4*max_val*max_val

        mean_prediction = tf.reduce_mean(dl_op)
        mean_gt = tf.reduce_mean(gt_label)
        centered_prediction = tf.subtract(dl_op, mean_prediction)
        centered_gt = tf.subtract(gt_label, mean_gt)
        sigma_prediction = tf.reduce_mean(tf.square(centered_prediction))
        sigma_gt = tf.reduce_mean(tf.square(centered_gt))
        sigma_cross = tf.reduce_mean(tf.multiply(centered_prediction, centered_gt))

        SSIM_1 = 2*tf.multiply(mean_prediction, mean_gt)+c1
        SSIM_2 = 2*sigma_cross+c2
        SSIM_3 = tf.square(mean_prediction)+tf.square(mean_gt)+c1
        SSIM_4 = sigma_prediction+sigma_gt+c2
        SSIM = tf.div(tf.multiply(SSIM_1, SSIM_2), tf.multiply(SSIM_3, SSIM_4))
        return SSIM
	
    def get_mse(self,gt_label, dl_op):
        """Return the mean squared difference of the two tensors plus 0.0001."""
        squared_error = tf.square(dl_op-gt_label)
        return tf.reduce_mean(squared_error)+0.0001

    def get_mae(self,gt_label,dl_op):
        """Return the mean absolute difference of the two tensors."""
        absolute_error = tf.abs(dl_op-gt_label)
        return tf.reduce_mean(absolute_error)
        
        
    
    def set_up_unet(self, batch_size):
        """Build the whole graph: inputs, network, losses, optimizer, summaries.

        Args:
            batch_size: leading dimension used when reshaping the fed input
                and ground-truth arrays, and the batch dimension of the
                transposed-convolution output shape.

        Side effects (attributes assigned on ``self``):
            shape, input_image, ground_truth, is_traing: placeholders.
            cast_image, cast_ground_truth: the inputs reshaped to
                ``[batch_size, shape[0], shape[1], shape[2], channels]``.
            w: convolution kernel variables, keyed by integer.
            prediction / prediction_log / first: mean over the channel axis
                of the last activation (all three are the same computation).
            e, e2: intermediate tensors used by the auxiliary metrics.
            mse, SSIM, SSIM1, SSIM2, mse2, loss: metric and loss tensors.
            global_step, learning_rate, train_step, summary.

        NOTE(review): the variables created by the batch-norm layers are
        presumably numbered in creation order, so reordering statements here
        would likely change the names stored in checkpoints - confirm before
        restructuring.
        """
        # input
        with tf.name_scope('input'):
            # Placeholders carry no static shape; self.shape is fed with the
            # three spatial dimensions of the current volume.
            self.shape=tf.placeholder(dtype=tf.int32)
            self.input_image = tf.placeholder(dtype=tf.float32)
            self.ground_truth = tf.placeholder(dtype=tf.float32)
            self.cast_image = tf.reshape(
                    tensor=self.input_image,
                    shape=[batch_size, self.shape[0],self.shape[1],self.shape[2],num_input_channels]
                    )
            
            self.cast_ground_truth = tf.reshape(
                    tensor=self.ground_truth,
                    shape=[batch_size, self.shape[0],self.shape[1],self.shape[2],num_output_channels]
                    )

            # Blend of ground truth (80%) and input (20%); target of the
            # auxiliary loss computed in the 'SSIM2' scope below.
            self.cast_ground_truth1=self.cast_ground_truth*0.8+0.2*self.cast_image
            self.is_traing = tf.placeholder(tf.bool)
            # Input batch normalisation is disabled; the raw input is used.
            normed_batch=self.cast_image
#            normed_batch=self.batch_norm(x=self.cast_image, is_training=self.is_traing, name='n_0')

        # layer 1
        # Branch working on a 2x average-pooled copy of the input.
        with tf.name_scope('estimation'):
            # conv_1
            # `m` is not used further down in this method.
            m=tf.reduce_max(normed_batch)
            normed_batch_t_down=tf.nn.avg_pool3d(normed_batch,[1,2,2,2,1],strides=[1,2,2,2,1],padding='VALID',data_format='NDHWC')
            # Activated, 4x channel-tiled copy of the pooled input; added as a
            # skip term to result_prelu_9 below.
            normed_batch_t_down_m=s_sigmoid(tf.tile(input=normed_batch_t_down,multiples=[1,1,1,1,4]))
            self.w[3] = self.gaussian_ker(shape=[3,3,3,1,4], name='e_1')#9,9,9
            result_conv_3 = tf.nn.conv3d(
                    input=normed_batch_t_down, filter=self.w[3],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_1')#+normed_batch_t_down_m
            result_prelu_3 =s_sigmoid1(self.batch_norm(x=result_conv_3, is_training=self.is_traing, name='eb_1'))

            self.w[9] = self.gaussian_ker(shape=[3,3,3,4,4], name='e_2')
            result_conv_9 = tf.nn.conv3d(
                    input=result_prelu_3, filter=self.w[9],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_2')#+normed_batch_t_down_m
            result_conv_9 =s_sigmoid1(self.batch_norm(x=result_conv_9, is_training=self.is_traing, name='eb_2'))#+normed_batch_t_down_m

            # Dense-style connection: concatenate the outputs of conv_1 and conv_2.
            result_conv_9_1=tf.concat([result_prelu_3,result_conv_9],-1)

            self.w[91] = self.gaussian_ker(shape=[3,3,3,8,4], name='e_3')
            result_conv_91 = tf.nn.conv3d(
                        input=result_conv_9_1, filter=self.w[91],
                        strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_3')#+normed_batch_t_down_m
            normed_batch_9 =self.batch_norm(x=result_conv_91, is_training=self.is_traing, name='eb_3')
            result_prelu_9 =s_sigmoid(normed_batch_9)+normed_batch_t_down_m

            # Collapse the channels to one and divide the pooled input by it
            # (custom-gradient division).
            ave_9=tf.reduce_mean(result_prelu_9,axis=4,keep_dims=True)

            temp_layer =tensor_div(normed_batch_t_down, ave_9)
            temp_layer=self.batch_norm(x=temp_layer, is_training=self.is_traing, name='t_1')

            # Second conv stack of this branch, randomly initialised kernels.
            self.w[15]= self.init_w(shape=[3,3,3,1,8], name='e_4')
            result_conv_15 = tf.nn.conv3d(
                input=temp_layer, filter=self.w[15],
                strides=[1, 1, 1, 1, 1], padding='SAME',name='conv_4')
            result_prelu_15 =s_sigmoid1(self.batch_norm(x=result_conv_15, is_training=self.is_traing, name='eb_4'))


            self.w[10] = self.init_w(shape=[3,3,3,8,8], name='e_5')#18,1,1
            result_conv_10 = tf.nn.conv3d(
                    input=result_prelu_15, filter=self.w[10],
                    strides=[1, 1, 1, 1, 1], padding='SAME',name='conv_5')
            result_prelu_10 =s_sigmoid1(self.batch_norm(x=result_conv_10, is_training=self.is_traing, name='eb_5'))

            result_prelu_10_1=tf.concat([result_prelu_10,result_prelu_15],-1)

            self.w[12] = self.init_w(shape=[3,3,3,16,8], name='e_6')#last 3,3,3,gaussian1.0
            result_conv_12 = tf.nn.conv3d(
                input=result_prelu_10_1, filter=self.w[12],
                strides=[1, 1, 1, 1, 1], padding='SAME',name='conv_6')
#            result_conv_12=result_conv_12+tf.ones_like(result_conv_12)
            normed_batch_12 =self.batch_norm(x=result_conv_12, is_training=self.is_traing, name='eb_6')
            result_prelu_12 =s_sigmoid1(normed_batch_12)


            # Upsample back to the input resolution with a stride-2
            # transposed convolution (8 -> 4 channels).
            # NOTE(review): looks like the spatial dimensions must be even
            # for this output_shape to match the pooled tensor - confirm.
            self.w[20] = self.init_w(shape=[2,2,2,4,8], name='e_7')
            result_conv_12_u = tf.nn.conv3d_transpose(
                value=result_prelu_12, filter=self.w[20],
                output_shape=[batch_size,self.shape[0],self.shape[1],self.shape[2],4],
                strides=[1, 2, 2, 2, 1], padding='VALID', name='conv_7')
#            result_conv_12_u=self.resize3D(result_conv_12_u,[batch_size,self.shape[0],self.shape[1],self.shape[2],2])
            result_conv_12_u=s_sigmoid1(self.batch_norm(x=result_conv_12_u, is_training=self.is_traing, name='eb_7'))
            self.w[205] = self.init_w(shape=[3,3,3,4,4], name='e_8')
            result_conv_12_u2 = tf.nn.conv3d(
                input=result_conv_12_u, filter=self.w[205],
                strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_8')
            normed_batch_12_u2 = self.batch_norm(x=result_conv_12_u2, is_training=self.is_traing, name='eb_8')
            result_prelu_12_u2 =s_sigmoid(normed_batch_12_u2)


            # Multiply the full-resolution input by the channel mean of the
            # upsampled features (custom-gradient multiplication).
            ave_result_prelu_12_u2=tf.reduce_mean(result_prelu_12_u2,axis=4,keep_dims=True)
            temp2=tensor_mul(normed_batch,ave_result_prelu_12_u2)
            result_prelu_12_1=s_sigmoid(self.batch_norm(x=temp2, is_training=self.is_traing, name='t_2'))





        # Branch working at full resolution; it combines its own features
        # with result_prelu_12_1 from the 'estimation' scope.
        with tf.name_scope('update'):
            normed_batch_m=s_sigmoid(tf.tile(input=normed_batch,multiples=[1,1,1,1,4]))
            self.w[1] = self.gaussian_ker(shape=[3,3,3, num_input_channels,4], name='u_1')
            result_conv_1 = tf.nn.conv3d(
                    input=normed_batch, filter=self.w[1],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_9')
            result_conv_1_b=s_sigmoid1(self.batch_norm(x=result_conv_1, is_training=self.is_traing, name='ub_1'))


            self.w[101] = self.gaussian_ker(shape=[3,3,3,4,4], name='u_2')
            result_conv_1_1 = tf.nn.conv3d(
                    input=result_conv_1_b, filter=self.w[101],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_10')#+normed_batch_m
            normed_batch_1=self.batch_norm(x=result_conv_1_1, is_training=self.is_traing, name='ub_2')
            result_prelu_1=s_sigmoid(normed_batch_1)+normed_batch_m


            # Divide the input by the channel mean of the features above.
            ave_1=tf.reduce_mean(result_prelu_1,axis=4,keep_dims=True)
            EST=tensor_div(normed_batch, ave_1)
            EST=self.batch_norm(x=EST, is_training=self.is_traing, name='t_3')


            self.w[2] = self.init_w(shape=[3,3,3,1,8], name='u_3')#+self.w[15]
            result_conv_2 = tf.nn.conv3d(
                    input=EST, filter=self.w[2],
                    strides=[1, 1, 1, 1, 1], padding='SAME',name='conv_11')
            result_conv_2_b=s_sigmoid1(self.batch_norm(x=result_conv_2, is_training=self.is_traing, name='ub_3'))


            self.w[201] = self.init_w(shape=[3,3,3,8,8], name='u_4')
            result_conv_2_1 = tf.nn.conv3d(
                    input=result_conv_2_b, filter=self.w[201],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_12')
            normed_batch_2 = self.batch_norm(x=result_conv_2_1, is_training=self.is_traing, name='u_4')
            act_2=s_sigmoid(normed_batch_2)+tf.ones_like(normed_batch_2)


            # Multiply the 'estimation' output by the channel mean of act_2.
            ave_2=tf.reduce_mean(act_2,axis=4,keep_dims=True)
            Estimation1=tensor_mul(result_prelu_12_1,ave_2)
            Estimation=s_sigmoid(self.batch_norm(x=Estimation1, is_training=self.is_traing, name='t_4'))
            # result_prelu_2 and Estimation_tile are not used further down.
            result_prelu_2 =Estimation
            result_prelu_2_1=Estimation

            
            Estimation_tile=tf.tile(input=Estimation,multiples=[1,1,1,1,8])
            # Refinement convolutions; these two are not batch-normalised.
            self.w[202] = self.init_w(shape=[3,3,3,1,8], name='u_5')
            result_conv_2_fine1 = tf.nn.conv3d(
                    input=result_prelu_2_1, filter=self.w[202],
                    strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_13')
            result_conv_2_fined=s_sigmoid1(result_conv_2_fine1)

            # 8 + 1 + 1 = 10 channels feed the next convolution.
            result_conv_2_fine1_c=tf.concat([result_conv_2_fined,result_prelu_2_1,result_prelu_12_1],-1)#if not ok, try result_prelu_12_1

            self.w[203] = self.init_w(shape=[3,3,3,10,8], name='u_6')
            result_conv_2_fine = tf.nn.conv3d(
                    input=result_conv_2_fine1_c, filter=self.w[203],
                    strides=[1, 1, 1, 1, 1], padding='SAME',name='conv_14')
            act_2_fine=s_sigmoid1(result_conv_2_fine)
            Merge=tf.concat([result_conv_2_fined,act_2_fine],-1)


            self.w[13] = self.init_w(shape=[3,3,3,16,8], name='u_7')
            result_conv_13 = tf.nn.conv3d(
                input=Merge, filter=self.w[13],
                strides=[1, 1, 1, 1, 1], padding='SAME', name='conv_15')#+result_conv_2_fined
            normed_batch_13 =self.batch_norm(x=result_conv_13, is_training=self.is_traing,name='ub_7')
            result_prelu_13 =s_sigmoid(normed_batch_13) #tf.nn.leaky_relu(normed_batch_13,name='relu_1')




            # Network output: mean over the 8 channels of the last activation.
            self.prediction=tf.reduce_mean(result_prelu_13,axis=4,keep_dims=True)#/tf.reduce_max(result_prelu_14)#*tf.reduce_max(normed_batch) #tf.multiply(result_prelu_7,result_prelu_1)
            self.prediction_log=self.prediction


            # Intermediate tensors exposed for the auxiliary metrics below.
            self.e=temp2
            self.e2=Estimation1
            
            # Same computation as self.prediction; fetched by the test methods.
            self.first=tf.reduce_mean(result_prelu_13,axis=4,keep_dims=True)

            # Collected but not used further down in this method.
            e_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='estimation')
            w_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='update')
            
        with tf.name_scope('MSE'):
            # The MAE term has weight 0.0, so this is get_mse only.
            self.mse =0.0*self.get_mae(self.prediction_log,self.cast_ground_truth)+1.0*self.get_mse(self.prediction_log,self.cast_ground_truth)
            tf.summary.scalar("mse",self.mse)
        with tf.name_scope('SSIM'):
            self.SSIM=self.get_SSIM(self.prediction_log,self.cast_ground_truth)
            tf.summary.scalar("SSIM", self.SSIM)

        with tf.name_scope('SSIM2'):
            # Despite its name, SSIM2 is an MSE between the intermediate
            # tensor self.e and the blended target; SSIM1 is not part of the loss.
            self.SSIM2=self.get_mse(self.e,self.cast_ground_truth1)
            self.SSIM1=self.get_SSIM(self.e2,self.cast_ground_truth)
            self.mse2=self.SSIM2#-tf.log((1+self.SSIM2)/2)#-tf.log((1+self.SSIM1)/2)
            tf.summary.scalar("mse2",self.mse2)
        with tf.name_scope('loss'):
#            k1=tf.cond(self.prediction_min<0, lambda:1.2, lambda:0.6)
            k1=0.0 #1.0 #1.0 
            # Total loss: 0.1 * auxiliary MSE + MSE - log((1 + SSIM) / 2).
            self.loss =0.1*self.mse2+1*self.mse-1.0*tf.log((1+self.SSIM)/2)#-k1*self.prediction_min
            tf.summary.scalar("loss",self.loss)
        # Gradient Descent
        with tf.name_scope('step'):
            # Variables under the 'step' scope are skipped when the
            # validation/test methods restore a checkpoint.
            self.global_step = tf.Variable(0, trainable=False)#0.015,500,0.95 #wide 0.01 200 0.95
            # Learning rate 0.015 decayed by 0.95 every 600 steps (continuous).
            self.learning_rate = tf.train.exponential_decay(0.015,self.global_step,600,0.95,staircase=False) #previous0.025,1000,0.95
            tf.summary.scalar("learning rate",self.learning_rate)
        #with tf.name_scope('Gradient_Descent'):
        # Run the collected UPDATE_OPS together with every optimizer step.
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.train_step = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss,global_step=self.global_step)

        self.summary = tf.summary.merge_all()
    def train(self):
        """Run the training loop and periodically checkpoint the model.

        For ``train_iter_num`` iterations a batch is fetched with
        ``Input.get_data_tiff`` and one optimizer step is run. Along the way:

        * every iteration the merged summary is written to an event
          directory named ``log_dir`` + current timestamp;
        * when ``k % 101 == 0`` the current metrics are printed;
        * when ``(k+1) % 648 == 0`` the accumulated loss is printed and reset;
        * when ``(k+1) % 500 == 0`` the prediction for the current batch is
          written to ``train_output`` as TIFF files;
        * when ``(k+1) % 100 == 0`` a checkpoint is saved under
          ``train_model_path``.

        ``set_up_unet`` must have been called first.
        """
        Input.filenames=[]
        checkpoint_path=os.path.join(train_model_path,'model.ckpt')
        all_parameters_saver = tf.train.Saver()

        #tf.reset_default_graph()
        with tf.Session(config=config) as sess:  # start a session
            sum_writer = tf.summary.FileWriter(log_dir + str(time.time()), sess.graph)
            try:
                sess.run(tf.global_variables_initializer())
                sess.run(tf.local_variables_initializer())
            #    all_parameters_saver.restore(sess=sess, save_path=checkpoint_path)
                sum_los = 0.0
                dtime=0.0  # cumulative time spent loading data
                time_start=time.time()
                for k in range(train_iter_num):
                    t1=time.time()
                    train_images,train_GT,label2out,shape=Input.get_data_tiff(data_dir,'train',train_batch_size,normal_pmin,normal_pmax,True)
                    t2=time.time()
                    dtime=dtime+t2-t1
                    feed={self.shape:shape,self.input_image:train_images, self.ground_truth: train_GT,self.is_traing: True}
                    summary,mse,ssim,lo,mse2,_= sess.run(
                        [self.summary,self.mse,self.SSIM,self.loss,self.mse2,self.train_step],
                        feed_dict=feed)
                    sum_writer.add_summary(summary, k)
                    sum_los += lo
                    if k % 101 == 0:
                        time_end=time.time()
                        used_time=time_end-time_start
                        print('dtime:%.6f'%dtime)
                        print('num %d, mse: %.6f, SSIM: %.6f, loss: %.6f,mse2: %.6f. runtime:%.6f ' % (k, mse, ssim, lo,mse2,used_time))
                    if (k+1)%648==0:#13
                        print('sum_lo: %.6f' %(sum_los))
                        sum_los = 0.0
                    if (k+1)%500==0:
                        # Dump the prediction for the batch that was just trained on.
                        image= sess.run([self.prediction], feed_dict=feed)
                        image1=np.array(image)
                        print(image1.shape)
                        # Drop the fetch-list axis and the channel axis.
                        reshape_image=image1[0,:,:,:,:,0]
                        print(label2out)
                        for v in range(train_batch_size):
                            single=reshape_image[v]
                            filenames_out=train_output+str(k)+'_'+label2out[0]+str(v)
                            tiff.imsave(filenames_out,single)
                    if (k+1)%100 == 0:
                        all_parameters_saver.save(sess=sess, save_path=checkpoint_path,global_step=k)
                        print('saving num %d' % (k))
                    sys.stdout.flush()
                print("Done training")
            finally:
                # Flush and release the event file even if training aborts.
                sum_writer.close()
        
        
    def valid(self):
        """Evaluate a saved model on data with ground truth and save predictions.

        Restores every variable except those under the ``step`` name scope
        from ``train_model_path``/model.ckpt, then for ``test_iter_num``
        batches from ``Input.get_data_tiff_VL`` runs the network, writes each
        prediction to ``test_output`` as ``rl_<label>`` and prints the MSE,
        SSIM, loss and per-batch runtime.

        ``set_up_unet`` must have been called first.
        """
        Input.filenames=[]
        checkpoint_path=os.path.join(train_model_path,'model.ckpt')
        variables=tf.contrib.framework.get_variables_to_restore()
        variables_to_restore1=[v for v in variables if (v.name.split('/')[0]!='step')]
        all_parameters_saver = tf.train.Saver(variables_to_restore1)
        with tf.Session() as sess:  # start a session
            # Initialise everything first so the variables that are not
            # restored still have a value.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=checkpoint_path)
            for m in range(test_iter_num):
                time_start=time.time()
                test_images,test_GT,label2out,shape=Input.get_data_tiff_VL(data_dir,'test',test_batch_size,normal_pmin,normal_pmax)
                image, mse, ssim, los = sess.run(
                    [self.prediction, self.mse, self.SSIM, self.loss],
                    feed_dict={self.shape:shape,self.input_image:test_images, self.ground_truth: test_GT,self.is_traing:False})
                image1=np.array(image)
                print(image1.shape)
                # Drop the channel axis before saving.
                reshape_image=image1[:,:,:,:,0]
                for v in range(test_batch_size):
                    single=reshape_image[v]
                    filenames_out=test_output+'rl_'+label2out[v]
                    tiff.imsave(filenames_out,single)
                time_end=time.time()
                print('num %d, mse: %.6f, SSIM: %.6f, loss: %.6f, runtime:%.6f ' % (m, mse, ssim, los,time_end-time_start))
                sys.stdout.flush()
        print('Done testing')

    def test1(self):
        """Run a saved model on test data without ground truth and save the output.

        Restores every variable except those under the ``step`` name scope
        from ``train_model_path``/model.ckpt, then for ``test_iter_num``
        batches from ``InputT.get_data_tiff`` runs the network and writes
        each sample to ``test_output`` as ``RLN_<label>`` in float16.

        Each sample is saved with shape ``(1, D, H, W, 1)``, which is what
        the previous code produced for a batch size of 1. The previous code
        indexed the fetch-list axis instead of the batch axis and therefore
        raised IndexError for any ``test_batch_size`` greater than 1.

        ``set_up_unet`` must have been called first.
        """
        Input.filenames=[]
        checkpoint_path=os.path.join(train_model_path,'model.ckpt')
        variables=tf.contrib.framework.get_variables_to_restore()
        variables_to_restore1=[v for v in variables if (v.name.split('/')[0]!='step')]
        all_parameters_saver = tf.train.Saver(variables_to_restore1)
        with tf.Session() as sess:  # start a session
            # Initialise everything first so the variables that are not
            # restored still have a value.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=checkpoint_path)
            for m in range(test_iter_num):
                time_start=time.time()
                # The returned maximum is not applied: the output stays in
                # normalised units.
                _max_v,test_images,label2out,shape=InputT.get_data_tiff(data_dir,'test',test_batch_size,normal_pmin,normal_pmax)
                print("starting on: ", label2out)
                image= sess.run(self.first,feed_dict={self.shape:shape,self.input_image:test_images,self.is_traing:False})
                # image has shape (batch, D, H, W, 1).
                batch_output=np.asarray(image).astype(np.float16)
                for v in range(test_batch_size):
                    # Keep a leading singleton axis so the saved shape matches
                    # what earlier runs with batch size 1 wrote.
                    single=batch_output[v:v+1]
                    filenames_out=test_output+'RLN_'+label2out[v]
                    tiff.imsave(filenames_out,single)
                time_end=time.time()
                print('num %d, runtime:%.6f ' % (m,time_end-time_start))
                sys.stdout.flush()
        print('Done testing')

    def test_stitch(self):
        """Process every TIFF under ``data_dir``/test/input/ in overlapping tiles.

        Each volume is split into tiles, every tile is normalised and run
        through the network separately, and the tile outputs are blended with
        linearly ramped weights in the overlap regions. The blended result is
        written to ``test_output`` as ``RLN_<filename>`` in float16. Timing
        totals for reading, tiling, processing and saving are printed at the
        end.

        Model variables (except those under the ``step`` name scope) are
        restored from ``train_model_path``/model.ckpt. ``set_up_unet`` must
        have been called first.

        NOTE(review): only the last tile of each chunk is fed to the network
        and only the first tile's maximum is used for rescaling, so this
        looks correct only for ``test_batch_size == 1`` - confirm.
        """
        filenames=[]
        input_path=os.path.join(data_dir,'test/input/')#_high_clear_tissue_stitch/')
        filenames=os.listdir(input_path)
        filenames.sort()
        folder_num = len(filenames)

        # set model
        train_dir=train_model_path 
        checkpoint_path=os.path.join(train_dir,'model.ckpt')
        variables=tf.contrib.framework.get_variables_to_restore()
        variables_to_restore1=[v for v in variables if (v.name.split('/')[0]!='step')]
        all_parameters_saver = tf.train.Saver(variables_to_restore1)
        with tf.Session() as sess:  # start a session
            time_start_init=time.time()
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            all_parameters_saver.restore(sess=sess, save_path=checkpoint_path)
            # Accumulated wall-clock time per phase, over all files.
            time_read_all=0
            time_crop_all=0
            time_process_all=0
            time_save_all=0

            for nn in range(0,folder_num):
                time_start=time.time()
                Input_file = input_path + filenames[nn]
                print(Input_file)
                label_out=filenames[nn]

                time_read_s=time.time()
                input_tif = tiff.imread(Input_file) #the input large data_size
                time_read_e=time.time()
                time_read_all=time_read_all+time_read_e-time_read_s

                output_image_shape = input_tif.shape
                input_image_shape = input_tif.shape

                time_crop_s=time.time()
                # No overlap along axis 0 for volumes with fewer than 40 planes.
                if input_image_shape[0]<40:
                    d=0
                else:
                    d=10
                overlap_shape = (d, 24, 24)
                # NOTE(review): each odd-size branch replaces the whole tuple
                # (the second overrides the first) and both change the axis-1
                # entry (23, then 32) - confirm these values are intended.
                if input_image_shape[1]%2!=0:
                    overlap_shape = (d, 23, 24)
                if input_image_shape[2]%2!=0:
                    overlap_shape = (d, 32, 24)
                # set the cropped size ~200M
                # Tile depth is the full depth (capped at 2000); the lateral
                # tile size is derived from crop_data_size and capped at 1998.
                crop_d=min(output_image_shape[0],2000)
                crop_w = min(math.floor(math.sqrt(crop_data_size / 4 * 1024 * 1024 / crop_d)),1998)
                crop_h = min(math.floor(math.sqrt(crop_data_size / 4 * 1024 * 1024 / crop_d)),1998)
                # Force the lateral tile sizes to be even (subtract 1 if odd).
                if crop_w % 2 !=0:
                    different=2-crop_w%2
                    crop_w=crop_w-different
                if crop_h % 2 !=0:
                    different=2-crop_h%2
                    crop_h = crop_h - different
                model_input_image_shape = (crop_d, crop_h, crop_w)
                print(model_input_image_shape)
                # Distance between the origins of neighbouring tiles.
                step_shape = tuple(m - o for m, o in zip(model_input_image_shape, overlap_shape))

                # Blending weights for one full tile: 1 in the interior and a
                # linear ramp across each overlap margin. The pad is one
                # element wider than the overlap and the outermost element is
                # sliced off again, which removes the ramp's end value.
                block_weight = np.ones(
                    [m - 2 * o for m, o
                     in zip(model_input_image_shape, overlap_shape)],dtype=np.float32)
                block_weight = np.pad(
                    block_weight,
                    [(o + 1, o + 1) for o in overlap_shape],
                    'linear_ramp')[(slice(1, -1),) * image_dim]


                # Weighted sum of tile outputs and the matching sum of weights.
                applied = np.zeros(
                    (*output_image_shape, num_output_channels), dtype=np.float32)
                sum_weight = np.zeros(output_image_shape, dtype=np.float32)
                # Number of tiles per axis; one extra tile is added when the
                # full steps plus the overlap do not reach the end of the axis.
                num_steps = tuple(
                    i // s + (((i//s)*s+o)<i)
                    for i, s, o in zip(input_image_shape, step_shape,overlap_shape))

                # Origin coordinates of every tile.
                blocks = list(itertools.product(
                    *[np.arange(n) * s for n, s in zip(num_steps, step_shape)]))

                print(blocks)

                time_crop_e=time.time()
                time_crop_all=time_crop_all+time_crop_e-time_crop_s

                time_process_s=time.time()
    
                for chunk_index in tqdm.trange(
                        0, len(blocks), test_batch_size, disable=False,
                        dynamic_ncols=True, ascii=tqdm.utils.IS_WIN):
                    rois = []
                    maxv=[]
                    minv=[]
                    for batch_index, tl in enumerate(blocks[chunk_index:chunk_index + test_batch_size]):
                        # tl: top-left corner coordinates, br: bottom-right corner coordinates
                        br = [min(t + m, i) for t, m, i in zip(tl, model_input_image_shape, input_image_shape)]
                        # r1 is the image region used for prediction;
                        # r2 is the same extent expressed relative to the tile origin.
                        r1, r2 = zip(*[(slice(s, e), slice(0, e - s)) for s, e in zip(tl, br)])
                        #print(r2)

                        m = input_tif [r1]
                        block_weight1=block_weight
                        #print(m.shape)
                        # Tiles at the volume border can be smaller than a full tile.
                        if model_input_image_shape != m.shape:
                            #reshape the weight block
                            block_weight1=block_weight[:m.shape[0],:m.shape[1],:m.shape[2]]

                            # expand the data
                            #pad_width = [(0, b - s) for b, s
                            #             in zip(model_input_image_shape, m.shape)]
                            #print(pad_width)
                            #m = np.pad(m, pad_width, 'reflect')


                        shape_in=m.shape
                       # print(shape_in)
                        min_v,max_v, normal_input_tif = normalize(m, normal_pmin, normal_pmax)
                        # `batch` and `shape_in` are overwritten on every
                        # iteration, so they hold the last tile of the chunk.
                        batch= normal_input_tif
                        rois.append((r1, r2))
                        maxv.append(max_v)
                        minv.append(min_v)

                    #time_start = time.time()
                    image = sess.run(self.first, feed_dict={self.shape: shape_in, self.input_image: batch, self.is_traing: False})
                    # Rescale with the first tile's maximum; minv is not used.
                    image = image*maxv[0]
                    #image[image<0]=0
                    #print('num %d, runtime:%.6f ' % (nn,time_end-time_start))
                    for batch_index in range(len(rois)):
                        for channel in range(num_output_channels):
                            image[batch_index, ..., channel] *= block_weight1

                        r1, r2 = [roi for roi in rois[batch_index]]
                        #print(applied[r1].shape,image[batch_index][r2].shape)

                        applied[r1] += image[batch_index][r2]
                        sum_weight[r1] += block_weight[r2]

                time_process_e=time.time()
                time_process_all=time_process_all+time_process_e-time_process_s
                time_save_s=time.time()
                # Normalise the weighted sum by the accumulated weights.
                for channel in range(num_output_channels):
                    applied[..., channel] /= sum_weight

                # Drop the channel axis for single-channel output.
                if applied.shape[-1] == 1:
                    applied = applied[..., 0]



                image1=applied#np.array(applied)
                single = image1.astype(np.float16)
                #print(single.shape)
                filenames_out = test_output + 'RLN_' + label_out
                tiff.imsave(filenames_out, single)
                time_save_e=time.time()
                time_save_all=time_save_all+time_save_e-time_save_s
                time_end=time.time()
                print('num %d, runtime_all:%.6f ' % (nn,time_end-time_start))
                sys.stdout.flush()
            print('total_time: %.6f'%(time.time()-time_start_init))
            print('time_read_all: %.6f,time_crop_all: %.6f,time_process_all: %.6f,time_save_all: %.6f' %(time_read_all,time_crop_all,time_process_all,time_save_all))
            sys.stdout.flush()
        print('Done testing')

        
        
if __name__ == "__main__":
    # Build the graph with the batch size that matches the selected mode,
    # then run the corresponding entry point. The modes are mutually
    # exclusive, so an elif chain is used.
    net = Unet()
    if mode == 'TR':
        net.set_up_unet(train_batch_size)
        net.train()
    elif mode == 'VL':
        net.set_up_unet(test_batch_size)
        net.valid()
    elif mode == 'TS':
        net.set_up_unet(test_batch_size)
        net.test1()
    elif mode == 'TSS':
        net.set_up_unet(test_batch_size)
        net.test_stitch()
    else:
        # Previously an unrecognised mode silently did nothing.
        sys.exit("Unknown --mode %r; expected one of TR, VL, TS, TSS" % (mode,))