8000 modify save and load to 1.7 api for rrpn by cjt222 · Pull Request #4310 · PaddlePaddle/models · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

modify save and load to 1.7 api for rrpn #4310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 19, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 41 additions & 38 deletions PaddleCV/rrpn/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@
logger = logging.getLogger(__name__)


def _load_state(path):
    """Load a saved program state from ``path``, ignoring optimizer state.

    If a ``<path>.pdopt`` optimizer file exists next to the parameters,
    only ``<path>.pdparams`` is copied into a scratch directory and the
    state is loaded from there, so ``load_program_state`` never reads the
    optimizer state.

    Args:
        path (str): Checkpoint path prefix (no ``.pdparams`` extension).

    Returns:
        dict: Mapping of parameter names to their saved values, as
        returned by ``fluid.io.load_program_state``.
    """
    if not os.path.exists(path + '.pdopt'):
        return fluid.io.load_program_state(path)
    # XXX hack: stage only the params file in a temp dir so the optimizer
    # state sitting beside it is invisible to load_program_state.
    tmp = tempfile.mkdtemp()
    try:
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        return fluid.io.load_program_state(dst)
    finally:
        # Clean up the scratch dir even if copying/loading raises,
        # otherwise every failed load leaks a temp directory.
        shutil.rmtree(tmp)


def load_params(exe, prog, path):
"""
Load model from the given path.
Expand Down Expand Up @@ -64,7 +77,7 @@ def save(exe, prog, path):
if os.path.isdir(path):
shutil.rmtree(path)
logger.info('Save model to {}.'.format(path))
fluid.io.save_persistables(exe, path, prog)
fluid.save(prog, path)


def load_and_fusebn(exe, prog, path):
Expand All @@ -81,15 +94,6 @@ def load_and_fusebn(exe, prog, path):
if not os.path.exists(path):
raise ValueError("Model path {} does not exists.".format(path))

def _if_exist(var):
b = os.path.exists(os.path.join(path, var.name))

if b:
logger.debug('load weight {}'.format(var.name))
return b

all_vars = list(filter(_if_exist, prog.list_vars()))

# Since the program uses affine-channel, there is no running mean and var
# in the program, here append running mean and var.
# NOTE, the params of batch norm should be like:
Expand All @@ -101,15 +105,25 @@ def _if_exist(var):
mean_variances = set()
bn_vars = []

bn_in_path = True
state = None
if os.path.exists(path + '.pdparams'):
state = _load_state(path)

inner_prog = fluid.Program()
inner_start_prog = fluid.Program()
inner_block = inner_prog.global_block()
with fluid.program_guard(inner_prog, inner_start_prog):
def check_mean_and_bias(prefix):
m = prefix + 'mean'
v = prefix + 'variance'
if state:
return v in state and m in state
else:
return (os.path.exists(os.path.join(path, m)) and
os.path.exists(os.path.join(path, v)))

has_mean_bias = True

with fluid.program_guard(prog, fluid.Program()):
for block in prog.blocks:
ops = list(block.ops)
if not bn_in_path:
if not has_mean_bias:
break
for op in ops:
if op.type == 'affine_channel':
Expand All @@ -119,50 +133,39 @@ def _if_exist(var):
prefix = scale_name[:-5]
mean_name = prefix + 'mean'
variance_name = prefix + 'variance'

if not os.path.exists(os.path.join(path, mean_name)):
bn_in_path = False
break
if not os.path.exists(os.path.join(path, variance_name)):
bn_in_path = False
if not check_mean_and_bias(prefix):
has_mean_bias = False
break

bias = block.var(bias_name)

mean_vb = inner_block.create_var(
mean_vb = block.create_var(
name=mean_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
variance_vb = inner_block.create_var(
dtype=bias.dtype)
variance_vb = block.create_var(
name=variance_name,
type=bias.type,
shape=bias.shape,
dtype=bias.dtype,
persistable=True)
dtype=bias.dtype)

mean_variances.add(mean_vb)
mean_variances.add(variance_vb)

bn_vars.append(
[scale_name, bias_name, mean_name, variance_name])

if not bn_in_path:
fluid.io.load_vars(exe, path, prog, vars=all_vars)
if state:
fluid.io.set_program_state(prog, state)
else:
load_params(exe, prog, path)
if not has_mean_bias:
logger.warning(
"There is no paramters of batch norm in model {}. "
"Skip to fuse batch norm. And load paramters done.".format(path))
return

# load running mean and running variance on cpu place into global scope.
place = fluid.CPUPlace()
exe_cpu = fluid.Executor(place)
fluid.io.load_vars(exe_cpu, path, vars=[v for v in mean_variances])

# load params on real place into global scope.
fluid.io.load_vars(exe, path, prog, vars=all_vars)

eps = 1e-5
for names in bn_vars:
scale_name, bias_name, mean_name, var_name = names
Expand Down
0