1

I've been working on a pair of legs that self-balance. If his 'waist' goes below a certain y-position value (falling over/tripping), the area is supposed to reset and also deduct points from his reward-score. I'm awfully new to machine learning, so go easy on me! Why is the agent not resetting when he falls over?

Legs trainer report Agents in inspector




Code to Agent (Updated):

    using MLAgents;
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using UnityEngine;

    using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent that controls a pair of legs through eight discrete action
/// branches (one per joint) and is rewarded for keeping its waist high.
/// The episode ends once the waist drops to y &lt;= -3.
/// </summary>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    // Body parts restored on reset, with the pose each one had at initialization.
    public GameObject[] bodyParts = new GameObject[9];
    public Vector3[] posStart = new Vector3[9];
    public Vector3[] eulerStart = new Vector3[9];

    /// <summary>
    /// Looks up the parent area and records the starting position and rotation
    /// of every body part so that <see cref="AgentReset"/> can restore them.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        bodyParts = new GameObject[]{waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL};

        // Size the pose arrays here instead of trusting the field initializers:
        // public arrays are serialized by Unity, so the Inspector may have stored
        // a different length, which would make the loop below index out of range.
        posStart = new Vector3[bodyParts.Length];
        eulerStart = new Vector3[bodyParts.Length];

        for (int i = 0; i < bodyParts.Length; i++) {
            posStart[i] = bodyParts[i].transform.position;
            eulerStart[i] = bodyParts[i].transform.eulerAngles;
        }
    }

    /// <summary>
    /// Puts every body part back at its recorded starting pose and zeroes the
    /// linear and angular velocity of its Rigidbody.
    /// </summary>
    public override void AgentReset() {
        for (int i = 0; i < bodyParts.Length; i++) {
            GameObject part = bodyParts[i];
            part.transform.position = posStart[i];
            part.transform.eulerAngles = eulerStart[i];

            // Fetch the Rigidbody once per part rather than once per property.
            Rigidbody body = part.GetComponent<Rigidbody>();
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Applies one action per joint, then hands out the per-step reward and
    /// ends the episode if the waist has fallen too far.
    /// </summary>
    /// <param name="vectorAction">
    /// Action values; indices 0-7 are read, in the order buttR, buttL, thighR,
    /// thighL, legR, legL, footR, footL.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        float waistHeight = waist.transform.position.y;

        // Small reward while the waist is above -1, small penalty otherwise.
        if (waistHeight > -1) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        // Fallen over: apply the extra penalty, then end the episode.
        if (waistHeight <= -3) {
            print("reset!");
            AddReward(-.1f);
            Done();
        }
    }

    /// <summary>
    /// Reports one local Euler angle per body part (x for the butt joints,
    /// y for the rest) followed by the waist position.
    /// </summary>
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);
        AddVectorObs(waist.transform.position);
    }

    /// <summary>
    /// Maps an action value to a rotation step: 1 gives -1, 2 gives +1,
    /// and anything else gives 0 (no rotation).
    /// </summary>
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    /// <summary>Rotates a joint about its y axis by the step selected by the action.</summary>
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }
}




Code to Area:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

/// <summary>
/// Training area that holds the balance agents found among its children and a
/// reference to the scene's <see cref="BalanceAcademy"/>.
/// </summary>
public class BalancingArea : Area
{
    /// <summary>All <c>BalanceAgent</c> components found under this area at Awake.</summary>
    public List<BalanceAgent> BalanceAgent { get; private set; }

    /// <summary>The academy located with <c>FindObjectOfType</c> at Awake.</summary>
    public BalanceAcademy BalanceAcademy { get; private set; }

    /// <summary>Object whose x/z position is used as the agent's reset location.</summary>
    public GameObject area;

    // Empty Start()/Update() callbacks were removed: Unity still invokes empty
    // message methods, so they only added per-frame overhead.

    private void Awake() {
        BalanceAgent = transform.GetComponentsInChildren<BalanceAgent>().ToList();   // Grabs all agents in the area
        BalanceAcademy = FindObjectOfType<BalanceAcademy>();                         // Grabs the balance academy
    }

    /// <summary>
    /// Moves the agent to the area's x/z position at height 0 and clears its rotation.
    /// </summary>
    /// <param name="agent">Agent whose transform is reset.</param>
    public void ResetAgentPosition(BalanceAgent agent) {
        Vector3 center = area.transform.position;
        agent.transform.position = new Vector3(center.x, 0, center.z);
        agent.transform.eulerAngles = Vector3.zero;
    }
}




Code to BalanceAcademy:

using MLAgents;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Academy for the balancing environment. It declares no members of its own,
/// so all of its behavior comes from the <c>Academy</c> base class.
/// </summary>
public class BalanceAcademy : Academy
{

}



Command used to run trainer:

mlagents-learn config/trainer_config.yaml --run-id=balancetest09 --train
4

1 回答 1

1

从有关创建新环境的文档中:

初始化和重置代理

当代理到达其目标时,它会将自己标记为完成,并且其代理重置功能会将目标移动到随机位置。此外,如果代理滚下平台,重置功能会将其放回地板上。

要移动目标游戏对象,我们需要一个对其变换(Transform)的引用(它存储游戏对象在 3D 世界中的位置、方向和比例)。要获取此引用,请向 RollerAgent 类添加一个 Transform 类型的公共字段。Unity 中组件的公共字段会显示在 Inspector 窗口中,允许您在 Unity 编辑器中选择将哪个 GameObject 用作目标。

要重置代理的速度(以及稍后施加力来移动代理),我们需要对刚体组件的引用。刚体是 Unity 用于物理模拟的主要元素。(有关 Unity 物理的完整文档,请参阅 Physics。)由于 Rigidbody 组件与我们的 Agent 脚本位于同一 GameObject 上,因此获取此引用的最佳方法是使用GameObject.GetComponent<T>(),我们可以在脚本的Start()方法中调用它。

到目前为止,我们的 RollerAgent 脚本如下所示:

using System.Collections.Generic;
using UnityEngine;
using MLAgents;

/// <summary>
/// Sample agent from the documentation: on reset it recovers from a fall and
/// relocates its target to a random spot.
/// </summary>
public class RollerAgent : Agent
{
    public Transform Target;

    Rigidbody rBody;

    void Start()
    {
        // The Rigidbody sits on the same GameObject as this script.
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// If the agent is below y = 0, stops its motion and puts it back at
    /// (0, 0.5, 0); in every case moves the target to a new random x/z.
    /// </summary>
    public override void AgentReset()
    {
        bool hasFallen = transform.position.y < 0;
        if (hasFallen)
        {
            // If the Agent fell, zero its momentum
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.position = new Vector3(0, 0.5f, 0);
        }

        // Move the target to a new spot (x is drawn before z, each in [-4, 4]).
        float targetX = Random.value * 8 - 4;
        float targetZ = Random.value * 8 - 4;
        Target.position = new Vector3(targetX, 0.5f, targetZ);
    }
}

因此,您应该重写 AgentReset 方法,以便重置代理各关节的位置。作为起步,您可以在 InitializeAgent 中记录每个关节的旋转和位置,然后在 AgentReset 中恢复它们。此外,还要将各刚体的速度和角速度归零。

我在文档或示例中没有看到任何在 Update 中调用 Done 的内容,因此可能建议(甚至要求)在 AgentAction 中调用它才能按预期运行。不妨将所有内容从 Update 移到 AgentAction 中。

此外,您可能希望在特征向量中使用具有 3 个分量 (xyz) 的 transform.localEulerAngles,而不是具有 4 个分量 (xyzw) 的 transform.localRotation。否则,您不应省略 localRotation 的 w 分量。

总而言之,它可能看起来像这样:

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent that controls a pair of legs through eight discrete action
/// branches (one per joint) and is rewarded for keeping its waist high.
/// The episode ends once the waist drops to y &lt;= -3, after which
/// <see cref="AgentReset"/> restores the initial pose.
/// </summary>
public class BalanceAgent : Agent
{
    private BalancingArea area;
    public GameObject waist;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;
    public GameObject goal;

    // Parallel lists: entry i of each list belongs to the same body part.
    private List<GameObject> gameObjectsToReset;
    private List<Rigidbody> rigidbodiesToReset;
    private List<Vector3> initEulers;
    private List<Vector3> initPositions;

    /// <summary>
    /// Looks up the parent area and records, for every body part, its
    /// Rigidbody and the position and rotation it starts with.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();

        gameObjectsToReset = new List<GameObject> {
            waist, buttR, buttL, thighR, thighL, legR, legL, footR, footL
        };
        rigidbodiesToReset = new List<Rigidbody>(gameObjectsToReset.Count);
        initEulers = new List<Vector3>(gameObjectsToReset.Count);
        initPositions = new List<Vector3>(gameObjectsToReset.Count);

        foreach (GameObject part in gameObjectsToReset) {
            rigidbodiesToReset.Add(part.GetComponent<Rigidbody>());
            initEulers.Add(part.transform.eulerAngles);
            initPositions.Add(part.transform.position);
        }
    }

    /// <summary>
    /// Restores every body part to its recorded starting pose and zeroes the
    /// linear and angular velocity of its Rigidbody.
    /// </summary>
    public override void AgentReset() {
        for (int i = 0; i < gameObjectsToReset.Count; i++) {
            Transform partTransform = gameObjectsToReset[i].transform;
            partTransform.position = initPositions[i];
            partTransform.eulerAngles = initEulers[i];

            Rigidbody body = rigidbodiesToReset[i];
            body.velocity = Vector3.zero;
            body.angularVelocity = Vector3.zero;
        }
    }

    /// <summary>
    /// Applies one action per joint, then hands out the per-step reward and
    /// ends the episode if the waist has fallen too far.
    /// </summary>
    /// <param name="vectorAction">
    /// Action values; indices 0-7 are read, in the order buttR, buttL, thighR,
    /// thighL, legR, legL, footR, footL.
    /// </param>
    public override void AgentAction(float[] vectorAction) {
        RotateJoint(buttR, vectorAction[0]);
        RotateJoint(buttL, vectorAction[1]);
        RotateJoint(thighR, vectorAction[2]);
        RotateJoint(thighL, vectorAction[3]);
        RotateJoint(legR, vectorAction[4]);
        RotateJoint(legL, vectorAction[5]);
        RotateJoint(footR, vectorAction[6]);
        RotateJoint(footL, vectorAction[7]);

        float waistHeight = waist.transform.position.y;

        // Small reward while the waist is above -1.3, small penalty otherwise.
        if (waistHeight > -1.3) {
            AddReward(.1f);
        }
        else {
            AddReward(-.02f);
        }

        if (waistHeight <= -3) {
            // Apply the fall penalty before flagging the episode as done, so it
            // is attributed to the ending episode however Done() is implemented.
            AddReward(-.1f);
            Done();
        }
    }

    /// <summary>
    /// Reports one local Euler angle per body part (x for the butt joints,
    /// y for the rest), the waist Rigidbody's freezeRotation flag, and the
    /// waist position.
    /// </summary>
    public override void CollectObservations() {
        AddVectorObs(waist.transform.localEulerAngles.y);
        AddVectorObs(buttR.transform.localEulerAngles.x);
        AddVectorObs(buttL.transform.localEulerAngles.x);
        AddVectorObs(thighR.transform.localEulerAngles.y);
        AddVectorObs(thighL.transform.localEulerAngles.y);
        AddVectorObs(legR.transform.localEulerAngles.y);
        AddVectorObs(legL.transform.localEulerAngles.y);
        AddVectorObs(footR.transform.localEulerAngles.y);
        AddVectorObs(footL.transform.localEulerAngles.y);

        AddVectorObs(waist.GetComponent<Rigidbody>().freezeRotation);

        AddVectorObs(waist.transform.position);
    }

    /// <summary>
    /// Maps an action value to a rotation step: 1 gives -1, 2 gives +1,
    /// and anything else gives 0 (no rotation).
    /// </summary>
    private static int ActionToDirection(float action) {
        switch ((int)action) {
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return 0;
        }
    }

    /// <summary>Rotates a joint about its y axis by the step selected by the action.</summary>
    private static void RotateJoint(GameObject joint, float action) {
        joint.transform.Rotate(0, ActionToDirection(action), 0);
    }
}

最后,请确保将 BalanceAgent 的 Max Step 设置得足够大,以便观察代理是否会失败;作为起点,可以设为 500 或 1000。

<code>Max Step</code> 在检查器中是可编辑的

于 2019-12-12T18:30:29.793 回答