
I would like to use Python data types (both built-in ones and ones imported from libraries such as NumPy, TensorFlow, etc.) as arguments in my Hydra configuration. Something like:

# config.yaml

arg1: np.float32
arg2: tf.float16

I'm currently doing this instead:

# config.yaml

arg1: 'float32'
arg2: 'float16'

# my Python code
import numpy as np
import tensorflow as tf
# ...
DTYPES_LOOKUP = {
  'float32': np.float32,
  'float16': tf.float16
}
arg1 = DTYPES_LOOKUP[config.arg1]
arg2 = DTYPES_LOOKUP[config.arg2]

Is there a more idiomatic (Hydra-native) or elegant solution? Thanks!

1 Answer

Does the hydra.utils.get_class function solve this problem for you?

# config.yaml

arg1: numpy.float32  # note: use "numpy" here, not "np"
arg2: tensorflow.float16

# Python code
...
from hydra.utils import get_class
arg1 = get_class(config.arg1)
arg2 = get_class(config.arg2)
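To see why the note about "numpy" versus "np" matters, here is a rough, stdlib-only approximation of a get_class-style lookup. It is not Hydra's implementation: get_class_sketch is a hypothetical name, and I'm assuming the lookup imports the module part of the dotted path by its real name and then fetches the last component with getattr.

```python
import importlib


def get_class_sketch(path: str) -> type:
    # Hypothetical helper approximating a get_class-style lookup;
    # this is NOT Hydra's actual implementation.
    module_path, _, attr = path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path like 'pkg.Name', got {path!r}")
    # import_module needs the real, importable module name ("numpy").
    # "np" is only a local alias created by `import numpy as np`,
    # so importing "np" raises ModuleNotFoundError.
    module = importlib.import_module(module_path)
    obj = getattr(module, attr)
    # Reject anything that is not a class, mirroring the
    # isinstance(cls, type) check that get_class performs (see Update 2).
    if not isinstance(obj, type):
        raise ValueError(f"{path!r} resolved to {obj!r}, which is not a class")
    return obj
```

With this sketch, get_class_sketch("collections.OrderedDict") returns the class, get_class_sketch("np.float32") fails at the import step, and a non-class value such as "math.pi" is rejected by the type check.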

Update 1: using a custom resolver

Based on miccio's comment below, here is a demonstration using an OmegaConf custom resolver to wrap the get_class function.

from omegaconf import OmegaConf
from hydra.utils import get_class

OmegaConf.register_new_resolver(name="get_cls", resolver=lambda cls: get_class(cls))

config = OmegaConf.create("""
# config.yaml

arg1: "${get_cls: numpy.float32}"
arg2: "${get_cls: tensorflow.float16}"
""")

arg1 = config.arg1
arg2 = config.arg2

Update 2:

It turns out that get_class("numpy.float32") succeeds but get_class("tensorflow.float16") raises a ValueError. The reason is that get_class checks that the returned value is indeed a class (using isinstance(cls, type)).

The function hydra.utils.get_method is slightly more permissive: it only checks that the resolved value is callable. This still does not work with tf.float16, which is neither a class nor callable:

>>> isinstance(tf.float16, type)
False
>>> callable(tf.float16)
False

A custom resolver wrapping the tensorflow.as_dtype function might be in order.

>>> tf.as_dtype("float16")
tf.float16
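A different route from tf.as_dtype is a permissive dotted-path lookup that skips the class/callable check entirely and returns whatever object the path names, so a dtype instance like tf.float16 would come back as-is. Below is a minimal stdlib-only sketch; locate_object is a hypothetical helper, not a Hydra API. It imports the longest importable module prefix and walks the rest of the path with getattr.

```python
import importlib
from typing import Any


def locate_object(path: str) -> Any:
    # Hypothetical helper (not part of Hydra): resolve a dotted path to ANY
    # object, without requiring it to be a class or a callable.
    parts = path.split(".")
    # Try the longest prefix first, e.g. "tensorflow.float16", then "tensorflow".
    for i in range(len(parts), 0, -1):
        module_name = ".".join(parts[:i])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            # Only skip when this prefix (or one of its parents) is not a module;
            # re-raise if a dependency imported *by* the target module is missing.
            if exc.name and not (
                module_name == exc.name or module_name.startswith(exc.name + ".")
            ):
                raise
            continue
        # Walk the remaining components as attributes of the imported module.
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"no importable module prefix in {path!r}")
```

It could then be registered in place of get_class, e.g. OmegaConf.register_new_resolver(name="get_obj", resolver=locate_object), and referenced as "${get_obj: tensorflow.float16}" (the resolver name "get_obj" is just an example). This assumes TensorFlow exposes float16 as a plain module attribute, which the tf.float16 examples above rely on.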
answered 2022-01-19T00:08:47.783