Defining our own environment with the reinforcement learning library gym

gym==0.21.0

stable-baselines3==1.4.1a1
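
The code below targets the legacy (pre-0.26) Gym API that these versions share: reset() returns only the observation and step() returns a 4-tuple (observation, reward, done, info), which is what stable-baselines3 1.4.x expects. A minimal sketch of that contract, using the built-in CartPole-v1 environment purely for illustration:

import gym

env = gym.make("CartPole-v1")
obs = env.reset()  # legacy API: reset() returns only the observation
obs, reward, done, info = env.step(env.action_space.sample())  # legacy 4-tuple step
env.close()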

import gym
from gym import spaces
import pygame
import numpy as np
from stable_baselines3 import PPO


class GridWorldEnv(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, size=5):
        self.size = size  # size of the square grid
        self.window_size = 512  # size of the PyGame window

        # Observations are dictionaries recording the agent's and the target's location.
        # Each location is encoded as an element of {0, ..., size-1}^2 (cf. MultiDiscrete([size, size])),
        # i.e. a 1-D array of two coordinates, each between 0 and size-1.
        self.observation_space = spaces.Dict(
            {
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )
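        # A sampled observation therefore looks like, e.g.:
        # {"agent": array([1, 3]), "target": array([4, 0])}  (illustrative values)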

        self.action_space = spaces.Discrete(4)
        self._action_to_direction = {
            0: np.array([0, 1]),   # move right
            1: np.array([-1, 0]),  # move up
            2: np.array([0, -1]),  # move left
            3: np.array([1, 0]),   # move down
        }
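        # For example, taking action 0 moves an agent at [2, 3] to [2, 4];
        # np.clip in step() keeps the result inside the grid.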

        # Helper space used to sample random (discrete) grid positions
        self.random_space = spaces.Box(0, size - 1, shape=(2,), dtype=int)

        self._agent_location = self.random_space.sample()  # the agent's location
        self._target_location = self.random_space.sample()  # the target's location
        self.count = 0  # number of steps taken in the current episode

        self.window = None
        self.clock = None

    # Return the current observation
    def _get_obs(self):
        return {"agent": self._agent_location, "target": self._target_location}

    # Return auxiliary information: the Manhattan distance between the agent and the target
    def _get_info(self):
        return {"distance": np.linalg.norm(self._agent_location - self._target_location, ord=1)}

    def reset(self, seed=None, return_info=False, options=None):
        # `seed` and `options` are accepted for API compatibility but are not used here
        self.count = 0

        # Choose the agent's location uniformly at random
        self._agent_location = self.random_space.sample()

        # Resample the target's location until it does not coincide with the agent's location
        self._target_location = self.random_space.sample()
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.random_space.sample()

        observation = self._get_obs()
        info = self._get_info()
        return (observation, info) if return_info else observation

    def step(self, action):
        self.count += 1
        # Map the action (an element of {0, 1, 2, 3}) to a movement direction
        direction = self._action_to_direction[action]

        # Use `np.clip` to make sure the agent does not leave the grid
        self._agent_location = np.clip(
            self._agent_location + direction, 0, self.size - 1
        )

        # An episode is done when the agent has reached the target
        done = np.array_equal(self._agent_location, self._target_location)
        reward = 1 if done else 0  # binary sparse reward
        observation = self._get_obs()
        info = self._get_info()

        if self.count > 200:
            # Truncate overly long episodes with a small penalty
            done = True
            reward = -1

        return observation, reward, done, info

    def render(self, mode="human"):
        if self.window is None and mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_square_size = (
            self.window_size / self.size
        )  # the size of a single grid square in pixels

        # First we draw the target
        pygame.draw.rect(
            canvas,
            (255, 0, 0),
            pygame.Rect(
                pix_square_size * self._target_location,
                (pix_square_size, pix_square_size),
            ),
        )
        # Then draw the agent
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._agent_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )

        # Finally, add some gridlines
        for x in range(self.size + 1):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )

        if mode == "human":
            # Copy our drawing from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that rendering happens at the predefined frame rate.
            # The next line automatically adds a delay to keep the frame rate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))
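

# Optional: a minimal smoke test (an illustrative helper, not used by learn()/check()
# below) that drives the environment with random actions to confirm the reset/step
# contract and the sparse reward before training with stable-baselines3.
def smoke_test(episodes=2):
    env = GridWorldEnv(5)
    for _ in range(episodes):
        obs = env.reset()
        done = False
        while not done:
            action = env.action_space.sample()  # random action from Discrete(4)
            obs, reward, done, info = env.step(action)
        print("episode finished, reward:", reward, "distance:", info["distance"])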


def learn():
    env = GridWorldEnv(10)
    tensorboard_log_path = "./tensorboard_log"
    model = PPO("MultiInputPolicy",
                env,
                verbose=1,
                tensorboard_log=tensorboard_log_path)
    model.learn(total_timesteps=5_0000)
    model.save("GridWorldEnv")


def check():
    env = GridWorldEnv(10)
    model = PPO.load("GridWorldEnv", env=env)
    done = False

    obs = env.reset()
    cnt = 0
    while not done:
        cnt += 1
        action, state = model.predict(obs)
        obs, reward, done, info = env.step(action)
        print(cnt, action, obs, state)
        env.render()


if __name__ == '__main__':
    check()
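
Note that learn() must be run once before check(), since check() loads the model file saved by model.save("GridWorldEnv"). Before training, it can also be worth validating the custom environment with stable-baselines3's environment checker; a minimal sketch, assuming the versions pinned above:

from stable_baselines3.common.env_checker import check_env

check_env(GridWorldEnv(10), warn=True)  # warns if the spaces or the reset/step contract look off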