Unity ML agentを学習するうえでUnity ml-agentのサンプルを改変したのでまとめました
対象者
- Unity ml-agentsの内容を確認して,自作ゲームなどに取り入れていきたい方
- Unityの基本はわかっている方
環境
- Unity 2019.4.12f1
- ML-Agents 1.0.5
- Anaconda 4.8.5
- Python 3.7.9
環境構築は,省きます
やったこと
Unityさんのこちらのドキュメントを少しづついじっています
以下のものはこちらで公開しています
3つのターゲットから得点の高いターゲットだけを取得する
詳細
set-up
3つのターゲットから一番報酬の高い赤に向かうものです. それぞれのターゲットに異なる報酬を設定しています.
報酬設定
- 赤:+1.0
- 青:+0.7
- 緑:+0.5
Observation space
- ターゲットのposition: 3*3 = 9
- 自分自身のposition 3
- 自分の速度 (x,z) 2 合計 14
Action space
- 上下,左右の2つ
デモ
コード
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;

/// <summary>
/// Rolling-ball agent that learns to reach the highest-value of three
/// targets: red (+1.0), blue (+0.7), green (+0.5).
/// Observation size 14: 3 target positions (3*3) + own position (3)
/// + x/z velocity (2). Continuous actions: [0] = force along x,
/// [1] = force along z.
/// </summary>
public class RollerAgent : Agent
{
    // Distance below which the agent counts as having touched a target
    // (approx. agent radius + target radius — TODO confirm against the
    // actual collider sizes in the scene).
    const float ReachDistance = 1.42f;

    Rigidbody _rBody;
    public Transform targetRed;
    public Transform targetBlue;
    public Transform targetGreen;
    public float speed = 10;

    void Start()
    {
        _rBody = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        if (this.transform.localPosition.y < 0)
        {
            // If the agent fell off the platform, zero its momentum and
            // put it back at the spawn point.
            this._rBody.angularVelocity = Vector3.zero;
            this._rBody.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Move each target to a fresh random spot on the platform.
        targetRed.localPosition = RandomSpot();
        targetBlue.localPosition = RandomSpot();
        targetGreen.localPosition = RandomSpot();
    }

    // Random point within the 8x8 platform area, at target height.
    static Vector3 RandomSpot()
    {
        return new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and agent positions. Order is fixed (red, green, blue,
        // self) — a trained model depends on observation order.
        sensor.AddObservation(targetRed.localPosition);
        sensor.AddObservation(targetGreen.localPosition);
        sensor.AddObservation(targetBlue.localPosition);
        sensor.AddObservation(this.transform.localPosition);

        // Agent velocity in the movement plane (x/z only).
        sensor.AddObservation(_rBody.velocity.x);
        sensor.AddObservation(_rBody.velocity.z);
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // Actions, size = 2: force along x and z.
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        _rBody.AddForce(controlSignal * speed);

        // Rewards
        float distanceToTargetRed = Vector3.Distance(this.transform.localPosition, targetRed.localPosition);
        float distanceToTargetBlue = Vector3.Distance(this.transform.localPosition, targetBlue.localPosition);
        float distanceToTargetGreen = Vector3.Distance(this.transform.localPosition, targetGreen.localPosition);

        // BUG FIX: the original used three independent `if`s, so an agent
        // within reach of two targets at once had its reward overwritten
        // by the lower-valued target and EndEpisode() was called more than
        // once per step. Check in descending reward order and stop at the
        // first hit.
        if (distanceToTargetRed < ReachDistance)
        {
            SetReward(1.0f);
            EndEpisode();
        }
        else if (distanceToTargetBlue < ReachDistance)
        {
            SetReward(0.7f);
            EndEpisode();
        }
        else if (distanceToTargetGreen < ReachDistance)
        {
            SetReward(0.5f);
            EndEpisode();
        }
        else if (this.transform.localPosition.y < 0)
        {
            // Fell off the platform: end the episode with no reward.
            EndEpisode();
        }
    }

    public override void Heuristic(float[] actionsOut)
    {
        // Manual keyboard control for testing in the editor.
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
}
一つ下の床にあるターゲットに向かう
詳細
set-up
床が二つあり,より下位にある床のターゲットを取得する
報酬設定
- 青:+1.0
Observation space
- ターゲットのposition: 3
- 自分自身のposition 3
- 自分の速度 (x,z) 2 合計 8
Action space
- 上下,左右の2つ
デモ
コード
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;

/// <summary>
/// Agent that must leave its starting platform and reach a single target
/// spawned on a lower floor (reward +1.0).
/// Observation size 8: target position (3) + own position (3)
/// + x/z velocity (2). Continuous actions: [0] = force along x,
/// [1] = force along z.
/// </summary>
public class DownBall : Agent
{
    Rigidbody _body;
    public Transform target;
    public float speed = 10;

    void Start()
    {
        _body = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        if (this.transform.localPosition.y < 0)
        {
            // The agent descended (or fell) last episode: kill all
            // momentum and return it to the spawn point on the upper floor.
            this._body.angularVelocity = Vector3.zero;
            this._body.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Respawn the target somewhere on the lower floor (y = -3.5,
        // x offset shifted to the lower platform's area).
        this.target.localPosition = new Vector3(
            Random.value * 8 - 12,
            -3.5f,
            Random.value * 8 - 4);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Positions of the target and the agent itself.
        sensor.AddObservation(target.localPosition);
        sensor.AddObservation(this.transform.localPosition);

        // Agent velocity in the movement plane (x/z only).
        sensor.AddObservation(_body.velocity.x);
        sensor.AddObservation(_body.velocity.z);
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // Two continuous actions drive the ball along x and z.
        var push = new Vector3(vectorAction[0], 0, vectorAction[1]);
        _body.AddForce(push * speed);

        // Touching the target ends the episode with the full reward.
        float gap = Vector3.Distance(this.transform.localPosition, target.localPosition);
        if (gap < 1.42f)
        {
            SetReward(1.0f);
            EndEpisode();
        }

        // Below the lower floor means the agent fell off the world.
        if (this.transform.localPosition.y < -4f)
        {
            EndEpisode();
        }
    }

    public override void Heuristic(float[] actionsOut)
    {
        // Manual keyboard control for testing in the editor.
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
}
2つのターゲットを取得する
詳細
set-up
2つのターゲットがそれぞれ高さの違う床に設置されている. 緑のターゲットはランダムで設置される (両方とるように設定)
報酬設定
- 紫:+1.0
- 緑:+0.7
Observation space
- ターゲットのposition: 3*2 = 6
- 自分自身のposition 3
- 自分の速度 (x,z) 2 合計 11
Action space
- 上下,左右の2つ
デモ
コード
using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

/// <summary>
/// Agent that collects two targets placed on floors of different heights:
/// a randomly-placed green target (+0.7, collected via collision, then
/// deactivated) and a goal (+1.0, ends the episode).
/// Observations: green target position (3) + goal position (3)
/// + own position (3) + x/z velocity (2). Continuous actions:
/// [0] = force along x, [1] = force along z.
/// </summary>
public class TwoTarget : Agent
{
    Rigidbody _body;
    public Transform targetGreen;
    public Transform goal;
    public float speed = 10;

    void Start()
    {
        _body = GetComponent<Rigidbody>();
    }

    public override void OnEpisodeBegin()
    {
        if (this.transform.localPosition.y < 0)
        {
            // The agent left the starting platform last episode: kill all
            // momentum and return it to the spawn point.
            this._body.angularVelocity = Vector3.zero;
            this._body.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        ResetGreenTarget();
    }

    // Move the green target to a fresh random spot on the lower floor and
    // re-enable it (it is deactivated when collected mid-episode).
    void ResetGreenTarget()
    {
        targetGreen.localPosition = new Vector3(
            Random.value * 8 - 4 - 7,
            -1.5f,
            Random.value * 8 - 4 - 7);
        targetGreen.gameObject.SetActive(true);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Positions of both targets and the agent itself.
        sensor.AddObservation(targetGreen.localPosition);
        sensor.AddObservation(goal.localPosition);
        sensor.AddObservation(this.transform.localPosition);

        // Agent velocity in the movement plane (x/z only).
        sensor.AddObservation(_body.velocity.x);
        sensor.AddObservation(_body.velocity.z);
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // Two continuous actions drive the ball along x and z.
        var push = new Vector3(vectorAction[0], 0, vectorAction[1]);
        _body.AddForce(push * speed);

        // Falling below the lowest floor ends the episode unrewarded.
        if (this.transform.localPosition.y < -4f)
        {
            EndEpisode();
        }
    }

    // Rewards are granted on physical contact, identified by tag.
    private void OnCollisionEnter(Collision other)
    {
        if (other.gameObject.CompareTag("Target/Green"))
        {
            // Partial reward; hide the target so it cannot be re-collected.
            AddReward(0.7f);
            other.gameObject.SetActive(false);
        }

        if (other.gameObject.CompareTag("Goal"))
        {
            // Final reward; reaching the goal ends the episode.
            AddReward(1.0f);
            EndEpisode();
        }
    }

    public override void Heuristic(float[] actionsOut)
    {
        // Manual keyboard control for testing in the editor.
        actionsOut[0] = Input.GetAxis("Horizontal");
        actionsOut[1] = Input.GetAxis("Vertical");
    }
}