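Does running multiple ML-Agents environments depend on the number of threads??
ようさん (@ayousanz), October 11, 2020
I wonder if anyone has looked into this.
I searched a little but couldn't find anything.
I looked into it a bit but couldn't find an answer, so I measured it myself. These results come from my own environment, so they may well differ on other setups.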
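Conclusion
In this measurement, it depended on the number of cores: on this 8-core/16-thread CPU, training time did not improve meaningfully beyond 8 environments.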
Environment
- Ryzen 7 2700X Eight-Core Processor (8 cores / 16 threads)
- Memory: 64 GB
- Unity 2019.4.12f1
- ML-Agents 1.0.5
- Anaconda 4.8.5
- Python 3.7.9
- TensorFlow 2.3.1
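Game used for the measurement
I ran the measurements on the game shown in the image below.
The start and goal positions are fixed.
Agent settings
Hyperparameter settings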
behaviors:
  YousanSideGame:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      buffer_size: 10240
      learning_rate: 3.0e-2
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 64
      num_layers: 2
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    checkpoint_interval: 500000
    max_steps: 200000
    time_horizon: 512
    summary_freq: 10000
    threaded: true
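For reference, the number of parallel environments is set with the `--num-envs` option of `mlagents-learn`, which takes the config file above as its first argument. The sketch below only builds the command lines for the four environment counts compared in this post. The config path, the build path, and the run-id naming are hypothetical placeholders, not necessarily what was used for the actual measurement.

```python
# Sketch only: builds (does not run) mlagents-learn commands for each env count.
# CONFIG_PATH, ENV_PATH and the run-id scheme are hypothetical placeholders.
CONFIG_PATH = "config/YousanSideGame.yaml"
ENV_PATH = "Build/YousanSideGame"
ENV_COUNTS = (4, 8, 16, 32)  # the environment counts compared in this post


def build_command(config_path, env_path, num_envs, run_id_prefix="bench"):
    """Return the argument list for one training run with `num_envs` environments."""
    return [
        "mlagents-learn",
        config_path,
        f"--env={env_path}",          # built Unity executable (needed for num_envs > 1)
        f"--num-envs={num_envs}",     # number of environment instances run in parallel
        f"--run-id={run_id_prefix}-{num_envs}env",
    ]


for n in ENV_COUNTS:
    print(" ".join(build_command(CONFIG_PATH, ENV_PATH, n)))
```

Giving each run its own run ID keeps the TensorBoard curves for the different environment counts separate, which makes them easier to compare afterwards.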
Agent scripts
PlayerController.cs
using UniRx;
using UniRx.Triggers;
using UnityEngine;

public class PlayerController : MonoBehaviour
{
    [SerializeField] private float speed = 0.3f;
    private Rigidbody2D _rigidbody2D;
    [SerializeField] private float jumpSpeed = 300.0f;
    [SerializeField] private bool isJump = true;
    [SerializeField] private GameObject startPoint;

    // Start is called before the first frame update
    void Start()
    {
        _rigidbody2D = GetComponent<Rigidbody2D>();
        // this.UpdateAsObservable().Subscribe(_ =>
        // {
        //
        //     float horizontalInput = Input.GetAxis("Horizontal");
        //     float verticalInput = Input.GetAxis("Vertical");
        //
        //     if (horizontalInput > 0)
        //     {
        //         RightMove(horizontalInput,verticalInput);
        //     }else if (horizontalInput < 0)
        //     {
        //         LeftMove(horizontalInput,verticalInput);
        //     }
        // });
        //
        // this.UpdateAsObservable().Where(_ => Input.GetKey(KeyCode.Space) && isJump).Subscribe(_ =>
        // {
        //     Jump();
        // });

        this.UpdateAsObservable().Where(_ => Input.GetKeyDown(KeyCode.R)).Subscribe(_ =>
        {
            Reset();
        });
    }

    public void RightMove(float hInput)
    {
        // Debug.Log("right move");
        transform.localScale = new Vector3(1, 1, 1);
        Vector2 input = new Vector2(hInput, 0f);
        _rigidbody2D.velocity = input.normalized * speed;
    }

    public void LeftMove(float hInput)
    {
        // Debug.Log("left move");
        transform.localScale = new Vector3(-1, 1, 1);
        Vector2 input = new Vector2(hInput, 0f);
        _rigidbody2D.velocity = input.normalized * speed;
    }

    public void Jump()
    {
        _rigidbody2D.velocity = Vector2.up * jumpSpeed;
        isJump = false;
    }

    private void OnCollisionEnter2D(Collision2D other)
    {
        if (other.gameObject.CompareTag("Ground"))
        {
            isJump = true;
        }
    }

    public void Reset()
    {
        transform.localPosition = startPoint.transform.localPosition;
        _rigidbody2D.velocity = Vector2.zero;
        _rigidbody2D.AddForce(Vector2.zero);
    }

    public bool GetIsJump()
    {
        return isJump;
    }
}
PlayerAgent
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;

public class PlayerAgent : Agent
{
    public PlayerController playerController;
    private Rigidbody2D _rigidbody2D;
    public Transform endPoint;

    public override void Initialize()
    {
        _rigidbody2D = GetComponent<Rigidbody2D>();
    }

    public override void OnEpisodeBegin()
    {
        playerController.Reset();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // goal position: 2 values
        sensor.AddObservation(endPoint.transform.localPosition.x);
        sensor.AddObservation(endPoint.transform.localPosition.y);

        // Agent position: 2 values
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.y);

        // Agent velocity: 2 values
        sensor.AddObservation(_rigidbody2D.velocity.x);
        sensor.AddObservation(_rigidbody2D.velocity.y);
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // Small per-step penalty
        AddReward(-0.0001f);
        float h = vectorAction[0];
        // if(vectorAction[1] == 1f && playerController.GetIsJump()) playerController.Jump();

        if (0f < h) playerController.RightMove(h);
        else if (h < 0f) playerController.LeftMove(h);
    }

    public override void Heuristic(float[] actionsOut)
    {
        actionsOut[0] = Input.GetAxis("Horizontal");
        // actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
    }

    private void OnTriggerEnter2D(Collider2D other)
    {
        if (other.gameObject.CompareTag("EndPoint"))
        {
            Debug.Log("is goal");
            SetReward(1.0f);
            EndEpisode();
        }
        else if (other.gameObject.CompareTag("Finish"))
        {
            Debug.Log("is game over");
            SetReward(0.0f);
            EndEpisode();
        }
    }
}
Results
The comparison is at step 190000 (the first screenshot was botched, so step 200000 was not captured).
- 4 env : Time Elapsed 240.355s
- 8 env : Time Elapsed 217.626s
- 16 env : Time Elapsed 222.897s
- 32 env : Time Elapsed 215.245s
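To make the differences easier to read, the times above can be expressed as a percentage reduction relative to the 4-environment run. This is just a small arithmetic sketch over the four measured values; the helper name `relative_change` is my own, and a single run per setting says nothing about run-to-run variance.

```python
# Elapsed times in seconds at step 190000, copied from the list above.
elapsed = {4: 240.355, 8: 217.626, 16: 222.897, 32: 215.245}


def relative_change(baseline, value):
    """Percentage reduction of `value` relative to `baseline` (positive = faster)."""
    return (baseline - value) / baseline * 100.0


baseline = elapsed[4]
for num_envs, seconds in sorted(elapsed.items()):
    change = relative_change(baseline, seconds)
    print(f"{num_envs:>2} env: {seconds:.3f}s ({change:+.1f}% vs 4 env)")
```

Going from 4 to 8 environments cuts the time by roughly 9.5%, while 16 and 32 environments land at about 7.3% and 10.4%, so there is little further gain past the 8 physical cores.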
Log screenshots
Below are the result screenshots for each number of environments.
TensorBoard screenshots
Cumulative Reward
Episode Length
Policy