avatar


8.自定义环境(以股票交易为例)

《强化学习浅谈及其Python实现》的第二章到第七章,我们都将强化学习模型和gym中的某个预制的环境进行交互。
如果我们想将强化学习应用在其他场景下,并且gym中没有与之对应的环境。
那么,我们需要自定义环境。
《强化学习浅谈及其Python实现》上再新增一篇文章《8.自定义环境(以股票交易为例)》,讨论如何自定义环境。

完整的代码已经PUSH到了我的Github上

https://github.com/KakaWanYifan/custom-env-demo

gym的最后一个版本是0.26.2,发布于2022年10月5日。之后由gymnasium代替。
本文会基于gymnasium进行讨论。
gymnasium的官网:https://gymnasium.farama.org/
官方也提供了一个自定义环境的教程:https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
但是官方教程的内容,需要我们把编写的环境打成一个Python的包,还需要gym的环境注册等,不是很方便。

代码结构

代码结构

gymnasium.core.Env

新建stock_env.py,定义一个类StockEnv,继承gymnasium.core.Env

1
2
3
import gymnasium as gym

class StockEnv(gym.core.Env):

我们可以点进Env,会看到如下内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Env(Generic[ObsType, ActType]):
r"""The main Gymnasium class for implementing Reinforcement Learning Agents environments.

The class encapsulates an environment with arbitrary behind-the-scenes dynamics through the :meth:`step` and :meth:`reset` functions.
An environment can be partially or fully observed by single agents. For multi-agent environments, see PettingZoo.

The main API methods that users of this class need to know are:

- :meth:`step` - Updates an environment with actions returning the next agent observation, the reward for taking that actions,
if the environment has terminated or truncated due to the latest action and information from the environment about the step, i.e. metrics, debug info.
- :meth:`reset` - Resets the environment to an initial state, required before calling step.
Returns the first agent observation for an episode and information, i.e. metrics, debug info.
- :meth:`render` - Renders the environments to help visualise what the agent see, examples modes are "human", "rgb_array", "ansi" for text.
- :meth:`close` - Closes the environment, important when external software is used, i.e. pygame for rendering, databases

Environments have additional attributes for users to understand the implementation

- :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space.
- :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space.
- :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode.
The default reward range is set to :math:`(-\infty,+\infty)`.
- :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make`
- :attr:`metadata` - The metadata of the environment, i.e. render modes, render fps
- :attr:`np_random` - The random number generator for the environment. This is automatically assigned during
``super().reset(seed=seed)`` and when assessing ``self.np_random``.

.. seealso:: For modifying or extending environments use the :py:class:`gymnasium.Wrapper` class
"""

该部分其实已经将如何自定义环境描述得很清楚了。

我们以一个股票交易的场景为例,举例子。

构造方法

我们在环境的构造方法中,读取股票数据,并初始化一些变量。
示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class StockEnv(gym.core.Env):

def __init__(self):
"""
构造方法
"""
# 读取股票数据
self.stock_data = pd.read_csv('custom_env/stock_data.csv').iloc[::-1]
# 初始时刻的index
self.stock_index = 0
# 仓位
# 初始时刻的仓位
# position_index:股票仓位的index,用以记录买入价
# position_stock:股票仓位
self.stock_position = {'position_index': 0, 'position_stock': 0}

成员方法

step

智能体每执行一个动作,环境需要返回:

  • 新的状态(observation (ObsType))
  • 奖励(reward (SupportsFloat))。

gymnasium.core.Env中,还需要返回如下值:

  • terminated (bool):结束(成功或正常),智能体完成目标情况下,或者游戏正常退出,结束。
  • truncated (bool):结束(失败),智能体在失败的情况下,结束。
  • info (dict):用于输出一些供调试的信息。

还有一个返回值,done (bool),从0.26的版本开始,该返回值已经被弃用,由terminated (bool)truncated (bool)代替。

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
self.stock_index = self.stock_index + 1

# 新的观测值
observation = self.stock_data.iloc[self.stock_index]

# 计算奖励
# 假如都以收盘价进行计算
# 原市值
pre_value = self.stock_data.iloc[self.stock_position['position_index']]['close'] * self.stock_position[
'position_stock']
# 更新股票仓位
position_index_val = self.stock_index
self.stock_position['position_index'] = position_index_val
position_stock_val = self.stock_position['position_stock'] + action
self.stock_position['position_stock'] = position_stock_val
# 新市值
new_value = self.stock_data.iloc[self.stock_position['position_index']]['close'] * self.stock_position['position_stock']
# 奖励
reward = new_value - pre_value

# terminated
terminated = False
if (self.stock_data.shape[0] - 1) == self.stock_index:
terminated = True

# truncated
# 可以设置爆仓的话,truncated为True,在该例子永远是False
truncated = False

# info
info = {}

return observation, reward, terminated, truncated, info

reset

当环境重置后,stock_index置为0,仓位也恢复到初始状态。

示例代码:

1
2
3
4
5
6
7
8
9
10
def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
# 初始时刻的index
self.stock_index = 0
# 仓位
self.stock_position = {'position_index': 0, 'position_stock': 0}

render

render,用于描述当前环境的状态。

在有些环境下,回通过图像描述状态,如下:
render

在这个例子中,我们简单的打印市值。

1
2
3
4
def render(self) -> RenderFrame | list[RenderFrame] | None:
value = self.stock_data.iloc[self.stock_position['position_index']]['close'] * self.stock_position[
'position_stock']
print('value ', value)

close

关闭环境的时候,调用close方法。
我们可以在这个方法中释放资源(如,断开数据库连接),或者进行其他技术操作和业务操作。

示例代码:

1
2
def close(self):
pass

成员变量

action_space

在真实场景中,股票交易环境的action_space,需要根据仓位和交易规则等进行确定。
例如:现有的现金最多可以购买多少股票?现有的仓位最多能卖多少股票?是否可以融资融券等?

在本文,作为一个DEMO,简单的将其定义为[5,5][-5,5]的一个闭区间。

示例代码:

1
2
# 动作空间
self.action_space = spaces.Discrete(n=11, start=-5)

observation_space

观察空间,其实就是这里的股票数据。
在本例中,没有定义该变量,直接把stock_data作为观察空间。

reward_range

奖励范围。
在本例中,不定义该成员变量

spec

环境规范,用于gymnasium.make时候,初始化环境的信息。
在本例中,不定义该成员变量

metadata

环境的元数据信息。
例如:

1
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

在本例中,不定义该成员变量

np_random

随机数种子。

1
2
# 随机数种子
self.np_random = 0

和环境交互

示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
from stock_env import StockEnv

env = StockEnv()
env.reset()
for i in range(1000):
env.render()
observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(observation)
print(reward)
if terminated or truncated:
break
env.close()
文章作者: Kaka Wan Yifan
文章链接: https://kakawanyifan.com/10508
版权声明: 本博客所有文章版权为文章作者所有,未经书面许可,任何机构和个人不得以任何形式转载、摘编或复制。

评论区