Unity ML-Agents: 複数環境実行時の処理速度はCPUのコア数とスレッド数のどちらに依存するか

ちょっと調べてみたのですが結局わからなかったので,自分で測ってみました.なお,結果は私の環境でのものなので,他の環境では違う結果になる可能性が高いと思われます.

結論

今回の環境では,処理速度はスレッド数ではなくコア数に依存しました.実行環境数をコア数と同じ8まで増やすと速くなりましたが,それ以上(16・32)に増やしてもほとんど変わりませんでした.

環境

  • Ryzen 7 2700X Eight-Core Processor (8 cores / 16 threads)
  • memory 64GB
  • Unity2019.4.12f1
  • ML-Agents 1.0.5
  • Anaconda 4.8.5
  • Python 3.7.9
  • tensorflow 2.3.1

測定するゲーム

以下の画像のゲームで測定を行いました. f:id:ayousanz:20201011202326p:plain

スタート位置・ゴール位置は固定です

Agent 設定

f:id:ayousanz:20201011202442p:plain

ハイパーパラメータ設定

# ML-Agents trainer configuration: a single PPO behavior named "YousanSideGame".
behaviors:
  # NOTE(review): this key must match the Behavior Name on the agent's Behavior Parameters — confirm in the scene.
  YousanSideGame:
    trainer_type: ppo
    hyperparameters:
      batch_size: 2048
      # 10240 = 5 x batch_size
      buffer_size: 10240
      # NOTE(review): 3.0e-2 is far above the range suggested in the ML-Agents docs
      # (default 3.0e-4, typical 1e-5 to 1e-3); confirm this value is intentional.
      learning_rate: 3.0e-2
      beta: 0.005
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: true
      hidden_units: 64
      num_layers: 2
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    # NOTE(review): checkpoint_interval (500000) is larger than max_steps (200000), so presumably
    # no intermediate checkpoint is written before training stops — confirm this is intended.
    checkpoint_interval: 500000
    max_steps: 200000
    time_horizon: 512
    summary_freq: 10000
    threaded: true

Agentのスクリプト

PlayerController.cs

using UniRx;
using UniRx.Triggers;
using UnityEngine;

/// <summary>
/// Side-scroller player movement on a <see cref="Rigidbody2D"/>: horizontal moves, jumping,
/// ground detection, and teleporting back to a start point.
/// Movement is driven externally through <see cref="RightMove"/>, <see cref="LeftMove"/> and <see cref="Jump"/>.
/// </summary>
public class PlayerController : MonoBehaviour
{
    /// <summary>Horizontal speed applied by <see cref="RightMove"/> / <see cref="LeftMove"/>.</summary>
    [SerializeField] private float speed = 0.3f;

    private Rigidbody2D _rigidbody2D;

    /// <summary>Vertical velocity assigned by <see cref="Jump"/>.</summary>
    [SerializeField] private float jumpSpeed = 300.0f;

    /// <summary>True while a jump is allowed; cleared by <see cref="Jump"/>, set again on touching a "Ground" collider.</summary>
    [SerializeField] private bool isJump = true;

    /// <summary>Object whose local position the player is moved back to by <see cref="Reset"/>.</summary>
    [SerializeField] private GameObject startPoint;

    // Cache the body in Awake (which runs before any Start) so the public movement/reset
    // methods are usable even if another component calls them before this Start has run.
    private void Awake()
    {
        _rigidbody2D = GetComponent<Rigidbody2D>();
    }

    private void Start()
    {
        // Debug shortcut: pressing R puts the player back at the start point.
        this.UpdateAsObservable()
            .Where(_ => Input.GetKeyDown(KeyCode.R))
            .Subscribe(_ => Reset());
    }

    /// <summary>Faces right and moves right at <see cref="speed"/>.</summary>
    /// <param name="hInput">Horizontal input; only its sign is used (expected &gt; 0).</param>
    public void RightMove(float hInput)
    {
        Move(hInput, 1f);
    }

    /// <summary>Faces left and moves left at <see cref="speed"/>.</summary>
    /// <param name="hInput">Horizontal input; only its sign is used (expected &lt; 0).</param>
    public void LeftMove(float hInput)
    {
        Move(hInput, -1f);
    }

    // Shared body of RightMove/LeftMove: flip the sprite via localScale.x, then set the
    // horizontal velocity to sign(hInput) * speed.
    private void Move(float hInput, float facing)
    {
        transform.localScale = new Vector3(facing, 1, 1);

        // Normalizing (hInput, 0) yields (±1, 0), or (0, 0) for a near-zero input,
        // so the horizontal speed is constant regardless of input magnitude.
        float horizontal = new Vector2(hInput, 0f).normalized.x * speed;

        // Keep the current vertical velocity: overwriting it with 0 (as a plain
        // `velocity = input * speed` would) cancels gravity and jumps on every move.
        _rigidbody2D.velocity = new Vector2(horizontal, _rigidbody2D.velocity.y);
    }

    /// <summary>Launches the player upward at <see cref="jumpSpeed"/> and disables further jumps until grounded.</summary>
    public void Jump()
    {
        // Replace only the vertical component so horizontal momentum is preserved.
        _rigidbody2D.velocity = new Vector2(_rigidbody2D.velocity.x, jumpSpeed);
        isJump = false;
    }

    private void OnCollisionEnter2D(Collision2D other)
    {
        // Touching a "Ground"-tagged collider re-enables jumping.
        if (other.gameObject.CompareTag("Ground"))
        {
            isJump = true;
        }
    }

    /// <summary>Moves the player to <see cref="startPoint"/>'s local position and stops it.</summary>
    /// <remarks>
    /// Unity also invokes a method named <c>Reset</c> in the Editor when the component is added
    /// or reset from the Inspector. At that point the start point and rigidbody are not set up,
    /// so the call is ignored outside Play mode.
    /// </remarks>
    public void Reset()
    {
        if (!Application.isPlaying)
        {
            return;
        }

        transform.localPosition = startPoint.transform.localPosition;
        _rigidbody2D.velocity = Vector2.zero;
        // A zero force does not change velocity; kept from the original.
        // NOTE(review): confirm nothing relies on this call before removing it.
        _rigidbody2D.AddForce(Vector2.zero);
    }

    /// <summary>Returns whether a jump is currently allowed.</summary>
    public bool GetIsJump()
    {
        return isJump;
    }
}

PlayerAgent

using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent that steers a <see cref="PlayerController"/> toward <see cref="endPoint"/>.
/// Observations: 6 floats. Actions: one continuous value (horizontal input).
/// </summary>
public class PlayerAgent : Agent
{
    /// <summary>Movement component that receives the agent's actions.</summary>
    public PlayerController playerController;

    private Rigidbody2D _body;

    /// <summary>Goal transform whose local position is observed every step.</summary>
    public Transform endPoint;

    public override void Initialize()
    {
        // Cache the body used for the velocity observation.
        _body = GetComponent<Rigidbody2D>();
    }

    public override void OnEpisodeBegin()
    {
        // Every episode starts from the controller's start point with zero velocity.
        playerController.Reset();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // 6 floats, in this order:
        // goal local position (x, y), agent local position (x, y), agent velocity (x, y).
        sensor.AddObservation((Vector2)endPoint.localPosition);
        sensor.AddObservation((Vector2)transform.localPosition);
        sensor.AddObservation(_body.velocity);
    }

    public override void OnActionReceived(float[] vectorAction)
    {
        // Small per-step penalty, presumably to favor reaching the goal quickly.
        AddReward(-0.0001f);

        var horizontal = vectorAction[0];
        if (horizontal < 0f)
        {
            playerController.LeftMove(horizontal);
        }
        else if (horizontal > 0f)
        {
            playerController.RightMove(horizontal);
        }
        // horizontal == 0 (or NaN): no movement this step.
        // A jump action (vectorAction[1]) is currently disabled.
    }

    public override void Heuristic(float[] actionsOut)
    {
        // Manual control: the "Horizontal" input axis becomes the single continuous action.
        actionsOut[0] = Input.GetAxis("Horizontal");
    }

    private void OnTriggerEnter2D(Collider2D other)
    {
        var touched = other.gameObject;

        // Goal reached: full reward, end the episode.
        if (touched.CompareTag("EndPoint"))
        {
            Debug.Log("is goal");
            SetReward(1.0f);
            EndEpisode();
            return;
        }

        if (!touched.CompareTag("Finish"))
        {
            return;
        }

        // Failure trigger: reward for this step is overwritten with 0, and the episode ends.
        Debug.Log("is game over");
        SetReward(0.0f);
        EndEpisode();
    }
}

結果

step数190000時点での比較です(1回目のスクリーンショットの撮影に失敗し,200000ステップ時点の値が写っていなかったため).

  • 4 env : Time Elapsed 240.355s
  • 8 env : Time Elapsed 217.626s
  • 16 env : Time Elapsed 222.897s
  • 32 env : Time Elapsed 215.245s

ログスクショ

以下,それぞれの実行環境数の結果スクショです

f:id:ayousanz:20201011203201j:plain
実行環境数:4
f:id:ayousanz:20201011203213j:plain
実行環境数:8
f:id:ayousanz:20201011203303j:plain
実行環境数:16
f:id:ayousanz:20201011203318j:plain
実行環境数:32

TensorBoardスクショ

Cumulative Reward

f:id:ayousanz:20201011203814p:plain

Episode Length

f:id:ayousanz:20201011203841p:plain

Policy

f:id:ayousanz:20201011203905p:plain