Unity ML Agentの基本の改変サンプル

Unity ML-Agentsを学習するうえで,Unity ML-Agentsのサンプルを改変したのでまとめました

対象者

  • Unity ml-agentsの内容を確認して,自作ゲームなどに取り入れていきたい方
  • Unityの基本はわかっている方

環境

  • Unity 2019.4.12f1
  • ML-Agents 1.0.5
  • Anaconda 4.8.5
  • Python 3.7.9

環境構築は,省きます

やったこと

Unityさんのこちらのドキュメントを少しずついじっています f:id:ayousanz:20201009210642p:plain

以下のものはこちらで公開しています

github.com

3つのターゲットから得点の高いターゲットだけを取得する

詳細

set-up

3つのターゲットから一番報酬の高い赤に向かうものです. それぞれのターゲットに異なる報酬を設定しています.

報酬設定

  • 赤:+1.0
  • 青:+0.7
  • 緑:+0.5

Observation space

  • ターゲットのposition: 3*3 = 9
  • 自分自身のposition 3
  • 自分の速度 (x,z) 2 合計 14

Action space

  • 上下,左右の2つ

デモ

f:id:ayousanz:20201009163931g:plain

コード

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;

/// <summary>
/// Rolling-ball agent that is rewarded for reaching one of three targets.
/// Each target carries a different reward (red 1.0, blue 0.7, green 0.5),
/// and reaching any of them ends the episode.
/// </summary>
public class RollerAgent : Agent
{
    // Distance (in local space) below which a target counts as reached.
    const float ReachDistance = 1.42f;

    Rigidbody _rBody;
    public Transform targetRed;
    public Transform targetBlue;
    public Transform targetGreen;
    public float speed = 10;

    void Start()
    {
        _rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Resets the agent if it dropped below y = 0 and moves all three
    /// targets to new random positions.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        if (transform.localPosition.y < 0)
        {
            // If the Agent fell, zero its momentum and put it back at the start
            _rBody.angularVelocity = Vector3.zero;
            _rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Move the targets to new spots
        targetRed.localPosition = RandomTargetPosition();
        targetBlue.localPosition = RandomTargetPosition();
        targetGreen.localPosition = RandomTargetPosition();
    }

    /// <summary>
    /// Writes 14 floats: three target positions (red, green, blue), the
    /// agent position, and the agent velocity along x and z.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and Agent positions
        sensor.AddObservation(targetRed.localPosition);
        sensor.AddObservation(targetGreen.localPosition);
        sensor.AddObservation(targetBlue.localPosition);
        sensor.AddObservation(transform.localPosition);

        // Agent velocity
        sensor.AddObservation(_rBody.velocity.x);
        sensor.AddObservation(_rBody.velocity.z);
    }

    /// <summary>
    /// Applies the two continuous actions as a force on the x/z plane, then
    /// ends the episode when a target is reached or the agent drops below y = 0.
    /// </summary>
    /// <param name="vectorAction">Actions, size = 2: [0] = x force, [1] = z force.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        // Actions, size = 2
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        _rBody.AddForce(controlSignal * speed);

        // Rewards
        Vector3 agentPosition = transform.localPosition;
        float distanceToTargetRed = Vector3.Distance(agentPosition, targetRed.localPosition);
        float distanceToTargetBlue = Vector3.Distance(agentPosition, targetBlue.localPosition);
        float distanceToTargetGreen = Vector3.Distance(agentPosition, targetGreen.localPosition);

        // Reached target. Only one target may end the episode per step: the
        // distances above were measured before EndEpisode() re-randomises the
        // targets, so checking the remaining targets afterwards would end the
        // freshly started episode again using stale values. Targets are tested
        // from the highest reward to the lowest.
        if (distanceToTargetRed < ReachDistance)
        {
            SetReward(1.0f);
            EndEpisode();
            return;
        }
        if (distanceToTargetBlue < ReachDistance)
        {
            SetReward(0.7f);
            EndEpisode();
            return;
        }
        if (distanceToTargetGreen < ReachDistance)
        {
            SetReward(0.5f);
            EndEpisode();
            return;
        }

        // Fell off platform
        if (transform.localPosition.y < 0)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Manual control: maps the "Horizontal" and "Vertical" input axes to the two actions.
    /// </summary>
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }

    // Returns a random local position with x and z in [-4, 4] and y = 0.5.
    static Vector3 RandomTargetPosition()
    {
        // x is drawn before z so the random sequence matches the original inline code.
        float x = Random.value * 8 - 4;
        float z = Random.value * 8 - 4;
        return new Vector3(x, 0.5f, z);
    }
}

一つ下の床にあるターゲットに向かう

詳細

set-up

床が二つあり,より下位にある床のターゲットを取得する

報酬設定

  • 青:+1.0

Observation space

  • ターゲットのposition: 3
  • 自分自身のposition 3
  • 自分の速度 (x,z) 2 合計 8

Action space

  • 上下,左右の2つ

デモ

f:id:ayousanz:20201009201159g:plain

コード

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;

/// <summary>
/// Rolling-ball agent with a single target that is placed at y = -3.5.
/// Reaching the target gives a reward of 1.0 and ends the episode.
/// </summary>
public class DownBall : Agent
{
    Rigidbody _rBody;
    public Transform target;
    public float speed = 10;

    void Start() => _rBody = GetComponent<Rigidbody>();

    /// <summary>
    /// Puts the agent back at its start position when it is below y = 0 and
    /// picks a new random position for the target.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        bool belowStartLevel = transform.localPosition.y < 0;
        if (belowStartLevel)
        {
            // Clear any leftover momentum before re-placing the agent
            _rBody.angularVelocity = Vector3.zero;
            _rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // New target spot: x in [-12, -4], z in [-4, 4], fixed height
        float targetX = Random.value * 8 - 12;
        float targetZ = Random.value * 8 - 4;
        target.localPosition = new Vector3(targetX, -3.5f, targetZ);
    }

    /// <summary>
    /// Writes 8 floats: target position, agent position, and the agent
    /// velocity along x and z.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Positions of the target and of the agent itself
        sensor.AddObservation(target.localPosition);
        sensor.AddObservation(transform.localPosition);

        // Planar velocity of the agent
        sensor.AddObservation(_rBody.velocity.x);
        sensor.AddObservation(_rBody.velocity.z);
    }

    /// <summary>
    /// Pushes the ball with the two received actions and ends the episode
    /// when the target is reached or the agent drops below y = -4.
    /// </summary>
    /// <param name="vectorAction">Two values: [0] = x force, [1] = z force.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        // Two actions drive the force on the x/z plane
        var push = new Vector3(vectorAction[0], 0f, vectorAction[1]);
        _rBody.AddForce(push * speed);

        // Reward for getting close enough to the target
        bool reachedTarget =
            Vector3.Distance(transform.localPosition, target.localPosition) < 1.42f;
        if (reachedTarget)
        {
            SetReward(1.0f);
            EndEpisode();
        }

        // Dropped below the lowest allowed height
        if (transform.localPosition.y < -4f)
        {
            EndEpisode();
        }
    }

    /// <summary>
    /// Manual control: forwards the "Horizontal" and "Vertical" input axes as actions.
    /// </summary>
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
}

2つのターゲットを取得する

詳細

set-up

2つのターゲットがそれぞれ高さの違う床に設置されている. 緑のターゲットはランダムで設置される (両方とるように設定)

報酬設定

  • 紫:+1.0
  • 緑:+0.7

Observation space

  • ターゲットのposition: 3*2 = 6
  • 自分自身のposition 3
  • 自分の速度 (x,z) 2 合計 11

Action space

  • 上下,左右の2つ

デモ

f:id:ayousanz:20201009205411g:plain

コード

using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

/// <summary>
/// Rolling-ball agent that can pick up a green target (+0.7, the target is
/// deactivated) and finish the episode by touching the goal (+1.0).
/// </summary>
public class TwoTarget : Agent
{
    Rigidbody _rBody;
    public Transform targetGreen;
    public Transform goal;
    public float speed = 10;

    void Start() => _rBody = GetComponent<Rigidbody>();

    /// <summary>
    /// Resets the agent when it is below y = 0, then re-positions and
    /// re-enables the green target.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        bool belowStartLevel = transform.localPosition.y < 0;
        if (belowStartLevel)
        {
            // Clear any leftover momentum before re-placing the agent
            _rBody.angularVelocity = Vector3.zero;
            _rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        RecreateTarget();
        ActiveTarget();
    }

    /// <summary>
    /// Writes 11 floats: green target position, goal position, agent
    /// position, and the agent velocity along x and z.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Positions: green target, goal, then the agent itself
        sensor.AddObservation(targetGreen.localPosition);
        sensor.AddObservation(goal.localPosition);
        sensor.AddObservation(transform.localPosition);

        // Planar velocity of the agent
        sensor.AddObservation(_rBody.velocity.x);
        sensor.AddObservation(_rBody.velocity.z);
    }

    /// <summary>
    /// Pushes the ball with the two received actions and ends the episode
    /// when the agent drops below y = -4.
    /// </summary>
    /// <param name="vectorAction">Two values: [0] = x force, [1] = z force.</param>
    public override void OnActionReceived(float[] vectorAction)
    {
        // Two actions drive the force on the x/z plane
        var push = new Vector3(vectorAction[0], 0f, vectorAction[1]);
        _rBody.AddForce(push * speed);

        // Dropped below the lowest allowed height
        if (transform.localPosition.y < -4f)
        {
            EndEpisode();
        }
    }

    // Rewards are handed out on physical contact, based on the tag of the object hit.
    private void OnCollisionEnter(Collision other)
    {
        GameObject hit = other.gameObject;

        if (hit.CompareTag("Target/Green"))
        {
            // Hide the collected target; OnEpisodeBegin re-enables it.
            AddReward(0.7f);
            hit.SetActive(false);
        }

        if (hit.CompareTag("Goal"))
        {
            AddReward(1.0f);
            EndEpisode();
        }
    }

    /// <summary>
    /// Manual control: forwards the "Horizontal" and "Vertical" input axes as actions.
    /// </summary>
    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }

    // Moves the green target to a random spot at y = -1.5 (x drawn before z).
    void RecreateTarget()
    {
        float x = Random.value * 8 - 4 - 7;
        float z = Random.value * 8 - 4 - 7;
        targetGreen.localPosition = new Vector3(x, -1.5f, z);
    }

    // Makes the green target visible and collidable again.
    void ActiveTarget() => targetGreen.gameObject.SetActive(true);
}