Skip to content

Commit

Permalink
[one-cmds] Revise one-import-onnx with extension (#13925)
Browse files Browse the repository at this point in the history
This will revise one-import-onnx with extension as
- run default conversion
- skip default if force_ext is given
- skip extension if disable_ext is given
- run extension if default fails

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
Co-authored-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
seanshpark and jinevening committed Sep 5, 2024
1 parent 5c34723 commit 487afbd
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 56 deletions.
139 changes: 84 additions & 55 deletions compiler/one-cmds/one-import-onnx
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,7 @@ def _convert(args):
logfile_path = os.path.realpath(args.output_path) + '.log'

# get import onnx extension path
ext_path = None
if not _disable_ext(args):
ext_path = _check_ext()
ext_path = _check_ext()

with open(logfile_path, 'wb') as f, tempfile.TemporaryDirectory() as tmpdir:
# save intermediate
Expand All @@ -322,61 +320,92 @@ def _convert(args):
os.path.splitext(basename)[0] + '~.onnx')
onnx.save(onnx_model, fixed_path)

run_default_import = True
ext_alt_onnx_path = None
if ext_path:
# save onnx_model to temporary alt file
basename = os.path.basename(getattr(args, 'input_path'))
alt_path = os.path.join(tmpdir, os.path.splitext(basename)[0] + '-alt.onnx')
onnx.save(onnx_model, alt_path)

# call extension with options
ext_cmd = [ext_path]
if oneutils.is_valid_attr(args, 'unroll_rnn'):
ext_cmd.append('--unroll_rnn')
if oneutils.is_valid_attr(args, 'unroll_lstm'):
ext_cmd.append('--unroll_lstm')
if oneutils.is_valid_attr(args, 'experimental_disable_batchmatmul_unfold'):
ext_cmd.append('--experimental_disable_batchmatmul_unfold')
if oneutils.is_valid_attr(args, 'save_intermediate'):
ext_cmd.append('--save_intermediate')
if oneutils.is_valid_attr(args, 'keep_io_order'):
ext_cmd.append('--keep_io_order')
ext_cmd.append(alt_path)
ext_cmd.append(getattr(args, 'output_path'))
oneutils.run(ext_cmd, logfile=f)
return

tf_savedmodel = onnx_tf.backend.prepare(onnx_model)

savedmodel_name = os.path.splitext(os.path.basename(
args.output_path))[0] + '.savedmodel'
savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
tf_savedmodel.export_graph(savedmodel_output_path)

# make a command to convert from tf to tflite
tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
tf2tfliteV2_output_name = os.path.splitext(os.path.basename(
args.output_path))[0] + '.tflite'
tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)

tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(args, tf2tfliteV2_path,
savedmodel_output_path,
tf2tfliteV2_output_path)

f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())

# convert tf to tflite
oneutils.run(tf2tfliteV2_cmd, logfile=f)

# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(tflite2circle_path,
tf2tfliteV2_output_path,
getattr(args, 'output_path'))

f.write((' '.join(tflite2circle_cmd) + '\n').encode())

# convert tflite to circle
oneutils.run(tflite2circle_cmd, err_prefix="tflite2circle", logfile=f)
ext_alt_onnx_path = os.path.join(tmpdir,
os.path.splitext(basename)[0] + '-alt.onnx')
onnx.save(onnx_model, ext_alt_onnx_path)

if _force_ext(args):
run_default_import = False

res_conv = -1
if run_default_import:
if _force_ext(args):
print(
"onnx-import-onnx: 'force_ext' is True, "
"but onnx-import-onnx-ext is not installed. "
"onnx-tf is used.",
flush=True)
# TODO split these to small functions
try:
tf_savedmodel = onnx_tf.backend.prepare(onnx_model)

savedmodel_name = os.path.splitext(os.path.basename(
args.output_path))[0] + '.savedmodel'
savedmodel_output_path = os.path.join(tmpdir, savedmodel_name)
tf_savedmodel.export_graph(savedmodel_output_path)

# make a command to convert from tf to tflite
tf2tfliteV2_path = os.path.join(dir_path, 'tf2tfliteV2.py')
tf2tfliteV2_output_name = os.path.splitext(
os.path.basename(args.output_path))[0] + '.tflite'
tf2tfliteV2_output_path = os.path.join(tmpdir, tf2tfliteV2_output_name)

tf2tfliteV2_cmd = _make_cmd.make_tf2tfliteV2_cmd(
args, tf2tfliteV2_path, savedmodel_output_path,
tf2tfliteV2_output_path)

f.write((' '.join(tf2tfliteV2_cmd) + '\n').encode())

# convert tf to tflite
res_conv = oneutils.run_ret(tf2tfliteV2_cmd, logfile=f)

if res_conv == 0:
# make a command to convert from tflite to circle
tflite2circle_path = os.path.join(dir_path, 'tflite2circle')
tflite2circle_cmd = _make_cmd.make_tflite2circle_cmd(
tflite2circle_path, tf2tfliteV2_output_path,
getattr(args, 'output_path'))

f.write((' '.join(tflite2circle_cmd) + '\n').encode())

# convert tflite to circle
res_conv = oneutils.run_ret(tflite2circle_cmd,
err_prefix="tflite2circle",
logfile=f)
except:
res_conv = -1

# if default conversion fails, try with one-import-onnx-ext if available
if ext_path and not _disable_ext(args):
if res_conv != 0:
if run_default_import:
print(
"onnx-import-onnx: Conversion with onnx-tf failed. "
"Fallback to use one-import-onnx-ext",
flush=True)
# call extension with options
ext_cmd = [ext_path]
if oneutils.is_valid_attr(args, 'unroll_rnn'):
ext_cmd.append('--unroll_rnn')
if oneutils.is_valid_attr(args, 'unroll_lstm'):
ext_cmd.append('--unroll_lstm')
if oneutils.is_valid_attr(args,
'experimental_disable_batchmatmul_unfold'):
ext_cmd.append('--experimental_disable_batchmatmul_unfold')
if oneutils.is_valid_attr(args, 'save_intermediate'):
ext_cmd.append('--save_intermediate')
if oneutils.is_valid_attr(args, 'keep_io_order'):
ext_cmd.append('--keep_io_order')
ext_cmd.append(ext_alt_onnx_path)
ext_cmd.append(getattr(args, 'output_path'))
res_conv = oneutils.run_ret(ext_cmd, logfile=f)

sys.exit(res_conv)


def main():
Expand Down
3 changes: 2 additions & 1 deletion compiler/one-cmds/tests/one-import-onnx_ext_001.test
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

# test for one-import-onnx to invoke extension
# default should execute and not one-import-onnx-ext

filename_ext="$(basename -- $0)"
filename="${filename_ext%.*}"
Expand Down Expand Up @@ -43,7 +44,7 @@ one-import-onnx \
--input_path ${inputfile} \
--output_path ${outputfile} > ${logfile} 2>&1

if ! grep -q "one-import-onnx-ext dummy output!!!" "${logfile}"; then
if grep -q "one-import-onnx-ext dummy output!!!" "${logfile}"; then
trap_err_onexit
fi

Expand Down

0 comments on commit 487afbd

Please sign in to comment.