본문 바로가기
유니티/ML-Agents

[ Unity ] ML-Agent 를 활용한 게임 개발 I - 기록용

by Hexs 2023. 11. 16.
반응형

GrowingAITest.yaml 내용.

해당 값들을 변경해주며 최적의 학습 데이터값을 찾아야합니다.

behaviors:
    GrowingAITest:
        trainer_type: ppo
        hyperparameters:
            batch_size: 512
            buffer_size: 131072
            learning_rate: 0.0003
            beta: 0.005
            epsilon: 0.2
            lambd: 0.95
            num_epoch: 3
            learning_rate_schedule: linear
        network_settings:
            normalize: false
            hidden_units: 128
            num_layers: 2
            vis_encode_type: simple
        reward_signals:
            extrinsic:
                gamma: 0.99
                strength: 1.0
        keep_checkpoints: 5
        max_steps: 3000000
        time_horizon: 64
        summary_freq: 50000

 


 

코드

틀 정도만 제작한 코드입니다.

시간 제한은 30초. 30초 이내로 모든 Objects를 충돌할시 추가적인 Reward가 지급됩니다.

 

또한 각각의 Objects 들은 Reward 값이 모두 다르며.

무지개 색 순서대로 Reward 값이 증가합니다.

Red = 0.5

Purple = 1.7 ... 등

 

using System;
using System.Linq;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;
//using System;
using Random = UnityEngine.Random;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
public class CreateModel : Agent
{
    [Header(" [ Int ]")]
    [SerializeField] int cmSpawnRange;
    [Space(20f)]
    [Header(" [ Float ]")]
    public float cmForce;
    public float cmRotateForce;
    public float cmCurLimitTime;
    public float cmLimitTime;
    [Space(20f)]
    [Header(" [ Bool ]")]
    [SerializeField] List<bool> cmTargerCheck = new List<bool>();
    [Space(20f)]
    [Header(" [ GameObject ]")]
    [SerializeField] GameObject cmLimitTimeText;

    [Space(20f)]
    [Header(" [ Other ]")]
    public List<Transform> cmTarget = new List<Transform>();

    [SerializeField] LayerMask cmObjectLayer;

    [SerializeField] Rigidbody cmRb;
    public override void Initialize()
    {
        cmRb = GetComponent<Rigidbody>();
    }
    // Start is called before the first frame update
    void Start()
    {

    }

    // Update is called once per frame
    void Update()
    {

    }

    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var action = act[0];
        switch (action)
        {
            case 1:
                dirToGo = transform.forward * 1f;
                break;
            case 2:
                dirToGo = transform.forward * -1f;
                break;
            case 3:
                rotateDir = transform.up * 1f;
                break;
            case 4:
                rotateDir = transform.up * -1f;
                break;
        }
        transform.Rotate(rotateDir, Time.deltaTime * cmRotateForce);
        cmRb.AddForce(dirToGo * cmForce, ForceMode.VelocityChange);
    }

    public override void OnEpisodeBegin() // 에피소드 시작시.
    {
        for (int i = 0; i < cmTargerCheck.Count; i++)
        {
            try
            {
                cmTargerCheck[i] = false;
            }
            catch (System.Exception)
            {
                Debug.Log("Out Of Index -- cmTargetCheck");
                throw;
            }
        }

        cmRb.angularVelocity = Vector3.zero;
        cmRb.velocity = Vector3.zero;
        this.transform.localPosition = new Vector3(0, 0.5f, 0);
        this.transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
        cmCurLimitTime = cmLimitTime;

        for (int i = 0; i < cmTarget.Count; i++)
        {
            int num1 = Random.Range(0, 2);
            if (num1 == 0)
                num1 = -1;
            int num2 = Random.Range(0, 2);
            if (num2 == 0)
                num2 = -1;
            try
            {
                cmTarget[i].gameObject.SetActive(true);
                cmTarget[i].localPosition = new Vector3(Random.value * cmSpawnRange * num1, 0.5f, Random.value * cmSpawnRange * num2);
            }
            catch (System.Exception)
            {
                Debug.Log("Out Of Index -- cmTarget");
                throw;
            }

        }

        base.OnEpisodeBegin();
    }
    /// <summary>
    /// 강화학습 프로그램에게 관측정보를 전달
    /// </summary>
    /// <param name="sensor"></param>
    public override void CollectObservations(VectorSensor sensor)
    {

    }

    public override void OnActionReceived(ActionBuffers actionBuffers) // 에피소드 코드.
    {
        // Actions, size = 2
        MoveAgent(actionBuffers.DiscreteActions);

        cmCurLimitTime -= Time.deltaTime;
        cmLimitTimeText.GetComponent<TextMeshPro>().text = cmCurLimitTime.ToString("F2");

        if (cmCurLimitTime <= 0) // 시간초과 인경우
        {
            AddReward(-2f);
            EndEpisode(); // 에피소드 종료
        }

        AddReward(-1f / MaxStep);


    }

    private void OnCollisionEnter(Collision col)
    {
        if (col.gameObject.layer.Equals(LayerMask.NameToLayer("Objects")))
        {
            AddReward(col.gameObject.GetComponent<Objects>().oReward);
            col.gameObject.SetActive(false);
            cmTargerCheck[col.gameObject.GetComponent<Objects>().oIndex] = true;
        }
        else if (col.gameObject.CompareTag("Walls"))
        {
            AddReward(-2f);
            EndEpisode();
        }
        else if (col.gameObject.CompareTag("EndObjects"))
        {
            bool endcheck = true;
            foreach (bool check in cmTargerCheck)
            {
                if (check == false)
                {
                    endcheck = false;
                    AddReward(-1f);
                    EndEpisode();
                }
            }
            if (endcheck)
            {
                AddReward(cmCurLimitTime / cmLimitTime);
                EndEpisode();
            }
        }
    }
}

 

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class Objects : MonoBehaviour
{
    public int oIndex;
    public float oReward;
    // Start is called before the first frame update
    void Start()
    {

    }

    // Update is called once per frame
    void Update()
    {

    }
}

 

Agent에 사용한 Ray Perception 센서는 총 3개 입니다.

1. 벽을 제외한 모든 오브젝트를 체크.

2. 넓은 범위를 사용하여 벽과 EndObject 체크

3. 좁은 범위를 사용하여 벽을 제외한 모든 오브젝트 체크.

 

 

 

 

동일한 환경 64개 를 동시에 학습시키는데

35분정도 소요되었습니다.

 

방치형 게임을 제작할 생각이고

많은 테스트 공간과 모델이 필요할것으로 생각됩니다.

 

 

반응형