Code indexing in gitaly is broken and leads to code not being visible to the user. We work on the issue with highest priority.

Skip to content
Snippets Groups Projects
Commit 54a706fa authored by tekin_g's avatar tekin_g
Browse files

improve inference script

parent 97719718
No related branches found
No related tags found
No related merge requests found
import os
import sys
import time
import numpy as np
import tensorflow as tf
import TF2ImageHelpers as images
from RLN_single_Model import RLN_model_simple, RLN_model
from RLN_single_Model import RLN_model
def get_model(path):
"""
Init a RLN model and restore latest weights from checkpoint directory.
:param path: Checkpoint Directory
:return: instance of Model, step number from training.
"""
print("Loading model...")
model = RLN_model(name="test")
i = tf.Variable(0, trainable=False, dtype=tf.int64)
ckpt = tf.train.Checkpoint(step=i, model=model)
manager = tf.train.CheckpointManager(ckpt, path, max_to_keep=20)
ckpt.restore(manager.latest_checkpoint)
print("restore successful")
return model, i
def inference(model, input):
"""
Inference using RLN model. Normalizes the input to the statistics that the model excepts and runs the model on input
:param model: Should be an instance of RLN model
:param input:should be a 3 (height,width,channel) or 4 (batch,height,width,channel) TensorFlow Tensor.
:return: 4 dimensional (batch,height,width,channel) tensor containing the predictions.
"""
shape = tf.shape(input)
if len(shape) < 3 or len(shape) > 4:
raise Exception("Input image shape wrong. Shape: {}".format(shape))
if len(shape) == 3:
input = tf.expand_dims(input, axis=0)
std = tf.math.reduce_std(input, axis=(1, 2, 3), keepdims=True)
mean = tf.math.reduce_mean(input, axis=(1, 2, 3), keepdims=True)
input = (input - mean) / std
x_i = model(input, training=False)
output = x_i[0] * std + mean
return output
if __name__ == "__main__":
import logging
import argparse
import os
import TF2ImageHelpers as images
import time
logging.basicConfig(filemode="a", encoding='utf-8', level=logging.DEBUG)
base_dir = "/data"
input_dir = ""
run_name = ""
train_model_path = ""
test_output = ""
print(test_output)
base_dir = "/data"
ground_truthdir = base_dir + '/train/ground_truth/'
input_dir = base_dir + "/test/input/"
arg_parser = argparse.ArgumentParser(prog="RLN Inference",
description="Used to run Inference on input pictures. Should support .bmp, "
".png, .npz file formats.")
timestr = time.strftime("%Y%m%d-%H%M%S")
run_name = "/{}/".format(sys.argv[1])
arg_parser.add_argument("--base_dir", default=base_dir,
help="Base Directory for input using the suggested file structure.")
arg_parser.add_argument("--run_name", default=None, help="Run name as used in the default file structure.")
train_model_path = base_dir + '/train/model_rl' + run_name
train_output = base_dir + '/train/output_rl' + run_name
test_output = base_dir + '/test/output_rl' + run_name
log_dir = base_dir + "/logs" + run_name
print(test_output)
arg_parser.add_argument("--checkpoint_dir", default=None, help="Override checkpoint directory.")
arg_parser.add_argument("--input_dir", default=None, help="Override input directory.")
arg_parser.add_argument("--output_dir", default=None, help="Override output directory.")
if not os.path.exists(train_model_path) or not os.path.exists(train_output) or not os.path.exists(
test_output) or not os.path.exists(log_dir):
raise Exception("missing locations")
args = arg_parser.parse_args()
base_dir = args.base_dir
import logging
if args.run_name is None or args.input_dir is None or args.output_dir is None:
raise Exception("You must specify either --run_name, or --input_dir, --output_dir, and --checkpoint_dir.")
logging.basicConfig(filemode="a", encoding='utf-8', level=logging.DEBUG)
if args.checkpoint_dir is not None and args.output_dir is not None and args.run_name is not None and args.input_dir is not None:
raise Exception("You must specify either --run_name, or --input_dir, --output_dir, and --checkpoint_dir.")
if args.run_name is not None:
input_dir = base_dir + "/test/input/"
run_name = "/{}/".format(args.run_name)
model = RLN_model(name="test")
train_model_path = base_dir + '/train/model_rl' + run_name
test_output = base_dir + '/test/output_rl' + run_name
i = tf.Variable(0, trainable=False, dtype=tf.int64)
if args.output_dir is not None and args.checkpoint_dir is not None and args.input_dir is not None:
input_dir = args.input_dir
test_output = args.output_dir
train_model_path = args.checkpoint_dir
ckpt = tf.train.Checkpoint(step=i, model=model)
manager = tf.train.CheckpointManager(ckpt, train_model_path, max_to_keep=20)
ckpt.restore(manager.latest_checkpoint)
logging.info("restore sucessfull")
a = os.listdir(input_dir)
a = [i for i in a if i.split('.')[-1] == 'npz']
a = sorted(a, key=lambda x: int(x.split(".")[0]))
logging.debug(a)
for path_i in a:
logging.debug(path_i)
x = images.read_file(input_dir + path_i)
x = tf.expand_dims(x, 0)
print(f"Input: {input_dir} \n Output: {test_output} \n Checkpoint: {train_model_path}")
st = time.time()
if not os.path.exists(train_model_path) or not os.path.exists(
test_output) or not os.path.exists(input_dir):
raise Exception("missing locations")
std = tf.math.reduce_std(x, axis=(1, 2, 3), keepdims=True)
mean = tf.math.reduce_mean(x, axis=(1, 2, 3), keepdims=True)
a = os.listdir(input_dir)
a = sorted(a, key=lambda x: int(x.split(".")[0]))
print("Input images: {}".format(a))
x = (x - mean) / std
x_i = model(x, training=False)
x_i = x_i[0] * std + mean
x_i = tf.clip_by_value(x_i, 0, 1, name="final_activation_clip")
for path_i in a:
logging.debug(path_i)
x = images.read_file(input_dir + path_i)
x = tf.expand_dims(x, 0)
model, i = get_model(train_model_path)
st = time.time()
x_i = inference(model, x)
et = time.time()
dt = np.array(et - st)
np.savez_compressed(test_output + path_i + ".npz", x=x, f_x=x_i, dt=dt)
et2 = time.time()
i.assign_add(1)
logging.info("Run: {} calc time: {} save time: {} memory: {}".format(i.numpy(), dt, et2 - et,
tf.config.experimental.get_memory_info(
'GPU:0')))
et = time.time()
dt = np.array(et - st)
np.savez_compressed(test_output + path_i + ".npz", x=x, f_x=x_i, dt=dt)
et2 = time.time()
i.assign_add(1)
logging.info("Run: {} calc time: {} save time: {} memory: {}".format(i.numpy(), dt, et2 - et,
tf.config.experimental.get_memory_info(
'GPU:0')))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment