这是使用文件返回一些参数的函数
def load_network_params(agent_name: str, env_name: str,
network_root_folder: str = 'jax-models') -> flax.core.FrozenDict:
filePathString=network_root_folder+r"/"+agent_name+r"/"+env_name+r"/2"
fileNameStr=r"ckpt.199"
fileString=os.path.join(filePathString,fileNameStr)
with open(fileString, 'rb') as file:
# file processing irrelevant to error
})
return network_params
这是随后使用该功能的方式:
from dataset import *
generate_dataset(r"dqn", r"Breakout", 10_000, network_root_folder=r"C:/Users/jk5g19/Documents/Year3IP/scripts/jax")
编辑:基本上 generate_dataset 调用 load_network_parameter,我试图减少共享的代码量,这样人们就不会感到困惑。这是最小可行产品。
def generate_dataset(agent_name: str, env_name: str, dataset_size: int, num_envs: int = 20, epsilon: float = 0.1,
network_root_folder: str = 'jax-models') -> Tuple[onp.ndarray, onp.ndarray, onp.ndarray,
onp.ndarray, int]:
num_actions = gym.make(f'{env_name}NoFrameskip-v0').action_space.n
images_obs_dataset = onp.zeros((dataset_size, 84, 84, 4))
ram_obs_dataset = onp.zeros((dataset_size, 128))
q_values_dataset = onp.zeros((dataset_size, num_actions))
action_dataset = onp.zeros(dataset_size)
network_def, network_args = get_network_def(agent_name, num_actions)
network_params = load_network_params(agent_name, env_name, network_root_folder=network_root_folder)
return images_obs_dataset, ram_obs_dataset, q_values_dataset, action_dataset, episodes_run
当我输出文件字符串并将其复制并粘贴到文件资源管理器中时,可以访问所需的文件。我使用原始字符串作为路径,并尝试使用双反斜杠和正斜杠。
我还在路径中添加了一个 test.txt 文件,文件路径不起作用,因此它排除了导致 open() 问题的文件类型。