--I created a decorator that reads parameters from a file and passes them to a function as keyword arguments, and I'd like to share how to use it.
--The motivation is to make it easier to load parameters for models such as deep learning (DL) models.
--Writing code that threads every parameter through from the main function is tedious...
--omegaconf is convenient, so I'd like people who don't know it yet to hear about it.
--If you know of any other useful tools, please let me know!
--Preparation
--Install omegaconf
%%bash
pip install omegaconf
--Preparing the decorator
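import functools
from typing import Any, Callable

from omegaconf import OmegaConf


def add_args(params_file: str, as_default: bool = False) -> Callable:
    """Return a decorator that passes the parameters in params_file as keyword arguments."""
    def _decorator(f: Callable) -> Callable:
        @functools.wraps(f)  # keep f's name and docstring on the wrapper
        def _wrapper(*args: Any, **kwargs: Any) -> Any:
            # Reload the file on every call and convert it to a plain dict;
            # this resolves ${...} interpolations and lets caller values of any type be merged in.
            cfg_params = OmegaConf.to_container(OmegaConf.load(params_file), resolve=True)
            if as_default:
                # File values are defaults; keyword arguments from the caller win.
                kwargs = {**cfg_params, **kwargs}
            else:
                # File values take priority over keyword arguments from the caller.
                kwargs = {**kwargs, **cfg_params}
            return f(*args, **kwargs)
        return _wrapper
    return _decorator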
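The only difference between the two modes is which side wins when the same key appears both in the file and in the call. A minimal stdlib-only sketch of that rule, using a hypothetical helper merge_params (not part of omegaconf or of the decorator itself) and plain dicts standing in for the file contents:

```python
from typing import Any, Dict


def merge_params(file_params: Dict[str, Any], call_kwargs: Dict[str, Any],
                 as_default: bool = False) -> Dict[str, Any]:
    """Hypothetical helper mirroring the two branches of add_args."""
    if as_default:
        # File values are defaults: the dict unpacked last (the call's kwargs) wins.
        return {**file_params, **call_kwargs}
    # File values take priority over the caller's kwargs.
    return {**call_kwargs, **file_params}


file_params = {"n_encoder_layer": 3, "n_heads": 4}
call_kwargs = {"n_encoder_layer": 128, "a": 0.25}

print(merge_params(file_params, call_kwargs, as_default=True))
# {'n_encoder_layer': 128, 'n_heads': 4, 'a': 0.25}
print(merge_params(file_params, call_kwargs, as_default=False))
# {'n_encoder_layer': 3, 'a': 0.25, 'n_heads': 4}
```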
--Prepare a parameter file to read (yaml or json)
--omegaconf supports YAML and JSON
%%bash
cat <<__YML__ > params.yml
n_encoder_layer: 3
n_decoder_layer: 5
n_heads: 4
n_embedding: 16
__YML__
:
echo "===== [ params.yml ] ====="
cat params.yml
echo "====="
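Since omegaconf also reads JSON, the same parameters could be kept in a JSON file instead. A sketch that writes one with Python's standard json module (the file name params.json is an assumption); the decorator would then be applied as @add_args("params.json"):

```python
import json
from pathlib import Path

# Same values as params.yml above; "params.json" is an assumed file name.
params = {
    "n_encoder_layer": 3,
    "n_decoder_layer": 5,
    "n_heads": 4,
    "n_embedding": 16,
}

path = Path("params.json")
path.write_text(json.dumps(params, indent=2))

# Print the file to check its contents.
print("===== [ params.json ] =====")
print(path.read_text())
print("=====")
```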
--Call
@add_args("params.yml")
def use_params(a, b, n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
    assert a == 0.25
    assert b == "world"
    assert n_encoder_layer == 3
    assert n_decoder_layer == 5
    assert n_heads == 4
    assert n_embedding == 16


use_params(a=0.25, b="world")
print("OK")
Here, only a and b are specified when calling use_params(); the other four parameters are filled in from params.yml by the decorator.
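For comparison, without the decorator every parameter has to be passed explicitly at the call site. A sketch of that equivalent call, using a hypothetical undecorated function use_params_explicit and a plain dict literal standing in for the contents of params.yml (no file is read here):

```python
def use_params_explicit(a, b, n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
    # Hypothetical undecorated counterpart of use_params.
    return (a, b, n_encoder_layer, n_decoder_layer, n_heads, n_embedding)


# Stand-in for the values stored in params.yml.
params = {"n_encoder_layer": 3, "n_decoder_layer": 5, "n_heads": 4, "n_embedding": 16}

# Without the decorator, the file's parameters must be threaded through by hand.
result = use_params_explicit(a=0.25, b="world", **params)
print(result)  # (0.25, 'world', 3, 5, 4, 16)
```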
You can also pass as_default=True to the decorator, as shown below, so that the params.yml settings act as defaults that the program can overwrite. (With as_default=False, the decorator's default, the values in the configuration file take priority over the arguments passed by the program.)
@add_args("params.yml", as_default=True)
def use_params(n_encoder_layer, n_decoder_layer, n_heads, n_embedding):
    assert n_encoder_layer == 128  # notice !!
    assert n_decoder_layer == 5
    assert n_heads == 4
    assert n_embedding == 16


use_params(n_encoder_layer=128)
print("OK")
--Other
--You can also use it to decorate a class's __init__, so please give it a try.
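Decorating __init__ works because self simply passes through *args while the file's parameters arrive as keyword arguments. A self-contained sketch; so that it runs without omegaconf, it uses a hypothetical stand-in add_json_args that reads a JSON file with the standard json module instead of OmegaConf.load, plus a hypothetical Model class and temporary file:

```python
import functools
import json
import tempfile
from pathlib import Path


def add_json_args(params_file, as_default=False):
    """Hypothetical stand-in for add_args that reads JSON via the json module."""
    def _decorator(f):
        @functools.wraps(f)
        def _wrapper(*args, **kwargs):
            file_params = json.loads(Path(params_file).read_text())
            if as_default:
                kwargs = {**file_params, **kwargs}  # caller's kwargs win
            else:
                kwargs = {**kwargs, **file_params}  # file values win
            return f(*args, **kwargs)
        return _wrapper
    return _decorator


# Write a temporary parameter file for the example.
params_path = Path(tempfile.mkdtemp()) / "model_params.json"
params_path.write_text(json.dumps({"n_heads": 4, "n_embedding": 16}))


class Model:
    @add_json_args(str(params_path), as_default=True)
    def __init__(self, n_heads, n_embedding):
        # self arrives through *args; the parameters arrive through **kwargs.
        self.n_heads = n_heads
        self.n_embedding = n_embedding


model = Model(n_heads=8)
print(model.n_heads, model.n_embedding)  # 8 16
```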
--In omegaconf, values in the configuration file can refer to environment variables and to other values in the same file through variable interpolation.
--For more information on omegaconf, see here.
--Copying the same code into every project is a bit tedious, so I'd like to make it installable with pip.