https://www.immersivelimit.com/tutorials/reinforcement-learning-penguins-part-1-unity-ml-agents
Reinforcement Learning Penguins (Part 1/4) | Unity ML-Agents — Immersive Limit
Unity Project Setup and Asset Import
www.immersivelimit.com
https://docs.unity3d.com/ScriptReference/TooltipAttribute.html
Unity - Scripting API: TooltipAttribute
Tooltip hovering over the class it was added to. In the following script a Tooltip is added. This provides information to the user about the range of values for the health variable. The suggested range is provided in the TooltipAttribute string. Note: Unit
docs.unity3d.com
Method GetCumulativeReward | ML Agents | 3.0.0-exp.1
GetCumulativeReward(): retrieves the episode reward for the Agent. Declaration: public float GetCumulativeReward(). Returns: float, the episode reward.
docs.unity3d.com
using System.Collections;
using System.Collections.Generic;
using UnityEditor;
using UnityEngine;
/// <summary>
/// Custom inspector for <see cref="PenguinArea"/>. Adds a debug button below the
/// default inspector that samples one random position and visualises it.
/// </summary>
[CustomEditor(typeof(PenguinArea))]
public class PenguinAreaEditor : Editor
{
    /// <summary>
    /// Draws the default inspector, then a "Choose Random Position" button. When the
    /// button is pressed, a position is sampled via
    /// <see cref="PenguinArea.ChooseRandomPosition"/>, logged, and a red debug line is
    /// drawn from the area's position to the sampled point for 10 seconds.
    /// </summary>
    public override void OnInspectorGUI()
    {
        PenguinArea area = (PenguinArea)target;
        base.OnInspectorGUI();

        bool clicked = GUILayout.Button("Choose Random Position");
        if (!clicked)
        {
            return;
        }

        Debug.Log("ChooseRandomPosition");

        // Sample in a narrow wedge (-15..15 degrees) between 1 and 2 units from the area.
        Vector3 origin = area.transform.position;
        Vector3 sampled = PenguinArea.ChooseRandomPosition(
            center: origin,
            minAngle: -15f,
            maxAngle: 15f,
            minRadius: 1f,
            maxRadius: 2f);

        Debug.Log(sampled);
        Debug.DrawLine(origin, sampled, Color.red, 10f);
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;
/// <summary>
/// Manages one penguin training area: places the agent and the baby, spawns and
/// removes fish, and shows the agent's cumulative reward on a text element.
/// </summary>
public class PenguinArea : MonoBehaviour
{
    [Tooltip("The agent inside the area")]
    public PenguinAgent penguinAgent;

    [Tooltip("The baby penguin inside the area")]
    public GameObject penguinBaby;

    [Tooltip("The TextMeshPro text that shows the cumulative reward of the agent")]
    public TextMeshPro cumulativeRewardText;

    [Tooltip("Prefab of a live fish")]
    public Fish fishPrefab;

    // Created eagerly so that FishRemaining / RemoveSpecificFish / ResetArea are safe
    // to call before Start() has run, and so Start() never discards a list that
    // already tracks spawned fish.
    private List<GameObject> fishList = new List<GameObject>();

    /// <summary>
    /// Number of fish currently tracked by this area.
    /// </summary>
    public int FishRemaining
    {
        get { return fishList.Count; }
    }

    /// <summary>
    /// Performs the initial reset of the area.
    /// </summary>
    private void Start()
    {
        ResetArea();
    }

    /// <summary>
    /// Refreshes the cumulative reward text every frame.
    /// </summary>
    private void Update()
    {
        // Skip instead of throwing every frame when the references are not assigned.
        if (cumulativeRewardText == null || penguinAgent == null)
        {
            return;
        }

        // Update the cumulative reward text
        cumulativeRewardText.text = penguinAgent.GetCumulativeReward().ToString("0.00");
    }

    /// <summary>
    /// Resets the area: removes all fish, re-places the agent and the baby,
    /// and spawns a fresh batch of fish.
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        PlaceBaby();
        SpawnFish(4, .5f);
    }

    /// <summary>
    /// Zeroes the agent's velocity and angular velocity, then moves it to a random
    /// position in the area with a random rotation about the Y axis.
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody body = penguinAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Zeroes the baby's velocity and angular velocity, then moves it to a random
    /// position within the -45..45 degree wedge, rotated 180 degrees about the Y axis.
    /// </summary>
    private void PlaceBaby()
    {
        Rigidbody body = penguinBaby.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
        penguinBaby.transform.position = ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f) + Vector3.up * .5f;
        penguinBaby.transform.rotation = Quaternion.Euler(0f, 180f, 0f);
    }

    /// <summary>
    /// Removes a fish (e.g. one that was eaten) from the tracking list and destroys it.
    /// </summary>
    /// <param name="fishObject">The fish GameObject to remove.</param>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>
    /// Picks a random position on the horizontal plane around a center point.
    /// </summary>
    /// <param name="center">Center point the position is measured from.</param>
    /// <param name="minAngle">Minimum angle in degrees around the Y axis, measured from Vector3.forward.</param>
    /// <param name="maxAngle">Maximum angle in degrees; if not greater than <paramref name="minAngle"/>, minAngle is used.</param>
    /// <param name="minRadius">Minimum distance from the center.</param>
    /// <param name="maxRadius">Maximum distance; if not greater than <paramref name="minRadius"/>, minRadius is used.</param>
    /// <returns>The chosen position.</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
        {
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);
        }

        if (maxAngle > minAngle)
        {
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);
        }

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
    }

    /// <summary>
    /// Instantiates <paramref name="count"/> fish from the prefab, places each at a
    /// random position and rotation, parents it to this area, tracks it in the list,
    /// and sets its swim speed.
    /// </summary>
    /// <param name="count">Number of fish to spawn.</param>
    /// <param name="fishSpeed">Swim speed assigned to each spawned fish.</param>
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int i = 0; i < count; i++)
        {
            // Spawn and place the fish
            GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
            fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Set the fish's parent to this area's transform
            fishObject.transform.SetParent(transform);

            // Keep track of the fish
            fishList.Add(fishObject);

            // Set the fish speed
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Destroys every tracked fish that still exists and empties the tracking list.
    /// </summary>
    private void RemoveAllFish()
    {
        for (int i = 0; i < fishList.Count; i++)
        {
            if (fishList[i] != null)
            {
                Destroy(fishList[i]);
            }
        }

        // Reuse the same list instead of allocating a new one on every reset.
        fishList.Clear();
    }

    /// <summary>
    /// Draws a 30-degree wire arc of radius 2 facing Vector3.forward at the area's position.
    /// </summary>
    private void OnDrawGizmos()
    {
        Vector3 center = this.transform.position;
        Vector3 dir = Vector3.forward;
        float angleRange = 30f;
        float radius = 2f;
        GizmosExtensions.DrawWireArc(center, dir, angleRange, radius);
    }
}
using UnityEngine;
/// <summary>
/// Extra gizmo drawing helpers. Not instantiable; all members are static.
/// </summary>
public class GizmosExtensions
{
    private GizmosExtensions() { }

    /// <summary>
    /// Draws a wire arc in the XZ plane using the given color.
    /// </summary>
    /// <remarks>
    /// Sets <c>Gizmos.color</c> to <paramref name="color"/> and does not restore the
    /// previous value. Delegates to the color-less overload so both overloads draw in
    /// the same plane; the start angle is derived from the XZ components of
    /// <paramref name="dir"/>, so the arc must be drawn in XZ as well.
    /// </remarks>
    /// <param name="position">Center of the arc.</param>
    /// <param name="dir">The direction from which the anglesRange is taken into account.</param>
    /// <param name="anglesRange">The angle range, in degrees.</param>
    /// <param name="radius">Radius of the arc.</param>
    /// <param name="color">Color the arc is drawn with.</param>
    /// <param name="maxSteps">How many steps to use to draw the arc.</param>
    public static void DrawWireArc(Vector3 position, Vector3 dir, float anglesRange, float radius, Color color, float maxSteps = 20)
    {
        Gizmos.color = color;
        DrawWireArc(position, dir, anglesRange, radius, maxSteps);
    }

    /// <summary>
    /// Draws a wire arc in the XZ plane, closed back to the center on both ends,
    /// using the current <c>Gizmos.color</c>.
    /// </summary>
    /// <param name="position">Center of the arc.</param>
    /// <param name="dir">The direction from which the anglesRange is taken into account.</param>
    /// <param name="anglesRange">The angle range, in degrees.</param>
    /// <param name="radius">Radius of the arc.</param>
    /// <param name="maxSteps">How many steps to use to draw the arc; values below 1 are treated as 1.</param>
    public static void DrawWireArc(Vector3 position, Vector3 dir, float anglesRange, float radius, float maxSteps = 20)
    {
        // At least one step so the step angle never divides by zero.
        var steps = Mathf.Max(1f, maxSteps);
        var stepAngles = anglesRange / steps;

        // Start half the range before the direction so the arc is centered on dir.
        var angle = GetAnglesFromDir(dir) - anglesRange / 2;

        // The first segment runs from the center to the start of the arc.
        var posA = position;
        for (var i = 0; i <= steps; i++)
        {
            var rad = Mathf.Deg2Rad * angle;
            var posB = position + new Vector3(radius * Mathf.Cos(rad), 0, radius * Mathf.Sin(rad));
            Gizmos.DrawLine(posA, posB);
            angle += stepAngles;
            posA = posB;
        }

        // Close the shape from the end of the arc back to the center.
        Gizmos.DrawLine(posA, position);
    }

    /// <summary>
    /// Returns the angle of <paramref name="dir"/> in the XZ plane, in degrees,
    /// measured from the +X axis toward +Z.
    /// </summary>
    private static float GetAnglesFromDir(Vector3 dir)
    {
        // Uses dir directly; computing (position + dir) - position would lose
        // precision for positions far from the origin.
        return Mathf.Rad2Deg * Mathf.Atan2(dir.z, dir.x);
    }
}
using System.Collections;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// ML-Agents agent for the penguin: moves and turns from two discrete action
/// branches, eats fish on collision, and feeds them to the baby for reward.
/// </summary>
public class PenguinAgent : Agent
{
    [Tooltip("How fast the agent moves forward")]
    public float moveSpeed = 5f;

    [Tooltip("How fast the agent turns")]
    public float turnSpeed = 180f;

    [Tooltip("Prefab of the heart that appears when the baby is fed")]
    public GameObject heartPrefab;

    [Tooltip("Prefab of the regurgitated fish that appears when the baby is fed")]
    public GameObject regurgitatedFishPrefab;

    private PenguinArea penguinArea;
    private new Rigidbody rigidbody;
    private GameObject baby;
    private bool isFull; // If true, penguin has a full stomach

    /// <summary>
    /// Caches the parent <see cref="PenguinArea"/>, its baby penguin, and this agent's Rigidbody.
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        baby = penguinArea.penguinBaby;
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Applies the chosen actions. Branch 0 is the forward amount; branch 1 selects
    /// turning (1 = left, 2 = right, anything else = no turn).
    /// </summary>
    /// <param name="actionBuffers">The actions to perform.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        ActionSegment<int> discreteActions = actionBuffers.DiscreteActions;

        // Convert the first action to forward movement
        float forwardAmount = discreteActions[0];

        // Convert the second action to turning left or right
        float turnAmount = 0f;
        switch (discreteActions[1])
        {
            case 1:
                turnAmount = -1f;
                break;
            case 2:
                turnAmount = 1f;
                break;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);

        // Transform.Rotate(Vector3) takes Euler angles in the transform's own space,
        // so yaw is expressed with Vector3.up rather than the world-space transform.up.
        transform.Rotate(Vector3.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (MaxStep > 0) AddReward(-1f / MaxStep);
    }

    /// <summary>
    /// Keyboard control: W moves forward, A turns left, D turns right.
    /// </summary>
    /// <param name="actionsOut">Buffer that receives the chosen actions.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;

        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }

        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Write through the segment indexer, which applies the segment's offset;
        // indexing the backing Array directly would ignore it.
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = forwardAction;
        discreteActions[1] = turnAction;
    }

    /// <summary>
    /// Empties the penguin's stomach and resets the area at the start of each episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }

    /// <summary>
    /// Collects the vector observations (8 values in total).
    /// </summary>
    /// <param name="sensor">The vector sensor to add observations to.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin has eaten a fish (1 float = 1 value)
        sensor.AddObservation(isFull);

        // Distance to the baby (1 float = 1 value)
        sensor.AddObservation(Vector3.Distance(baby.transform.position, transform.position));

        // Direction to baby (1 Vector3 = 3 values)
        sensor.AddObservation((baby.transform.position - transform.position).normalized);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 1 + 3 + 3 = 8 total values
    }

    /// <summary>
    /// Eats on collision with an object tagged "fish"; feeds on collision with one tagged "baby".
    /// </summary>
    /// <param name="collision">The collision info.</param>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            // Try to eat the fish
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag("baby"))
        {
            // Try to feed the baby
            RegurgitateFish();
        }
    }

    /// <summary>
    /// If the stomach is empty, removes the fish from the area and adds a reward of 1.
    /// </summary>
    /// <param name="fishObject">The fish to eat.</param>
    private void EatFish(GameObject fishObject)
    {
        if (isFull) return; // Can't eat another fish while full
        isFull = true;

        penguinArea.RemoveSpecificFish(fishObject);

        AddReward(1f);
    }

    /// <summary>
    /// If the stomach is full, spawns the regurgitated fish and heart at the baby
    /// (each destroyed after 4 seconds), adds a reward of 1, and ends the episode
    /// when no fish remain in the area.
    /// </summary>
    private void RegurgitateFish()
    {
        if (!isFull) return; // Nothing to regurgitate
        isFull = false;

        // Spawn regurgitated fish
        GameObject regurgitatedFish = Instantiate<GameObject>(regurgitatedFishPrefab);
        regurgitatedFish.transform.parent = transform.parent;
        regurgitatedFish.transform.position = baby.transform.position;
        Destroy(regurgitatedFish, 4f);

        // Spawn heart
        GameObject heart = Instantiate<GameObject>(heartPrefab);
        heart.transform.parent = transform.parent;
        heart.transform.position = baby.transform.position + Vector3.up;
        Destroy(heart, 4f);

        AddReward(1f);

        if (penguinArea.FishRemaining <= 0)
        {
            EndEpisode();
        }
    }
}
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// A fish that repeatedly picks a random destination around its parent's position
/// and swims toward it in a straight line at a randomised speed.
/// </summary>
public class Fish : MonoBehaviour
{
    [Tooltip("The swim speed")]
    public float fishSpeed;

    private float currentSpeed = 0f;
    private float retargetTime = -1f;
    private Vector3 destination;

    /// <summary>
    /// Swims once per physics step, but only while <see cref="fishSpeed"/> is positive.
    /// </summary>
    private void FixedUpdate()
    {
        if (!(fishSpeed > 0f))
        {
            return;
        }

        Swim();
    }

    /// <summary>
    /// Either starts a new leg (when the retarget time has been reached) or advances
    /// along the current one. A step that starts a new leg does not move the fish.
    /// </summary>
    private void Swim()
    {
        bool dueForNewLeg = Time.fixedTime >= retargetTime;
        if (dueForNewLeg)
        {
            BeginNewLeg();
            return;
        }

        AdvanceAlongLeg();
    }

    /// <summary>
    /// Chooses a new speed and destination, faces the destination, and schedules
    /// the next retarget for the estimated arrival time.
    /// </summary>
    private void BeginNewLeg()
    {
        // Speed varies between 50% and 150% of the configured swim speed
        float speedFactor = UnityEngine.Random.Range(.5f, 1.5f);
        currentSpeed = fishSpeed * speedFactor;

        // New random destination relative to the parent's position
        destination = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

        // Face the destination
        Vector3 toDestination = destination - transform.position;
        transform.rotation = Quaternion.LookRotation(toDestination, Vector3.up);

        // Estimated travel time decides when the next leg begins
        float travelTime = Vector3.Distance(transform.position, destination) / currentSpeed;
        retargetTime = Time.fixedTime + travelTime;
    }

    /// <summary>
    /// Moves one physics step forward; if that step would overshoot, snaps to the
    /// destination and makes the next step start a new leg.
    /// </summary>
    private void AdvanceAlongLeg()
    {
        Vector3 step = currentSpeed * transform.forward * Time.fixedDeltaTime;
        float remaining = Vector3.Distance(transform.position, destination);

        if (step.magnitude <= remaining)
        {
            transform.position += step;
            return;
        }

        // The step would pass the destination: land exactly on it instead
        transform.position = destination;
        retargetTime = Time.fixedTime;
    }
}