Pool
Efficient parallel data collection for reinforcement learning environments.
Pitch

The Pool class streamlines data collection from multiple environments in reinforcement learning. It is built on Python's multiprocessing module. It shares data between worker processes and runs the environment interactions concurrently. Features include a configurable pool size, windowed data management, and optional randomized prioritization that keeps data evenly distributed across sub-pools.

Description

The Pool class provides an efficient, scalable solution for parallel data collection in reinforcement learning. It uses Python's multiprocessing module to interact with several environments concurrently. This makes it well suited to gathering the experience data needed to train reinforcement learning agents.

Key Features

  • Parallel Data Collection: Collects experience tuples of the form (state, action, next_state, reward, done) from multiple environments simultaneously, increasing data throughput.
  • Shared Memory Management: Employs multiprocessing.Manager to manage shared data structures, allowing seamless access and modification across multiple processes.
  • Configurable Pool Size: pool_size sets the maximum number of samples held across all sub-pools, so the pool can be sized to the task.
  • Windowed Data Management: window_size keeps only the most recent samples in each sub-pool. This maintains a sliding window of experience and bounds memory use.
  • Randomized Prioritization (Optional): New samples are assigned to sub-pools at random, weighted by the inverse of each sub-pool's length. Smaller sub-pools therefore fill faster, and data stays balanced across processes.
  • Clearing Frequency (Optional): clearing_freq periodically discards the oldest samples, preventing the pool from growing indefinitely.
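The inverse-length prioritization described above can be sketched as follows. This is a minimal illustration, not the Pool source. The helper name choose_sub_pool is hypothetical. Adding 1 to each length, so that an empty sub-pool does not cause a division by zero, is an assumption of the sketch.

```python
import numpy as np


def choose_sub_pool(lengths, rng):
    """Hypothetical helper: pick a sub-pool index with probability proportional to 1 / (length + 1).

    The +1 is an assumption of this sketch; it keeps empty sub-pools selectable
    without dividing by zero.
    """
    lengths = np.asarray(lengths, dtype=float)
    weights = 1.0 / (lengths + 1.0)   # shorter sub-pools get larger weights
    probs = weights / weights.sum()   # normalize into a probability distribution
    return int(rng.choice(len(lengths), p=probs))


rng = np.random.default_rng(0)
counts = [0, 0, 0]
for _ in range(3000):
    counts[choose_sub_pool([0, 10, 100], rng)] += 1
print(counts)  # the empty sub-pool is chosen far more often than the full ones
```

Shorter sub-pools receive proportionally more of the incoming samples, which pushes the sub-pool sizes toward each other over time.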

Class Definition

import numpy as np
import multiprocessing as mp
import math

class Pool:
    def __init__(self, env, processes, pool_size, window_size=None, clearing_freq=None, window_size_=None, random=True):
        # Initialization for the Pool object (implementation omitted here)
        ...

Constructor Parameters

  • env (List[gym.Env]): List of environment instances for each process.
  • processes (int): Number of parallel processes for data collection.
  • pool_size (int): Maximum number of samples across all sub-pools combined.
  • window_size (int, optional): Maximum number of samples kept in each sub-pool; when it is exceeded, the oldest samples are truncated.
  • clearing_freq (int, optional): How often older data is cleared; each clearing operation removes the window_size_ oldest samples.
  • window_size_ (int, optional): Number of oldest samples removed by each clearing operation.
  • random (bool, optional, default True): If True, each sample is assigned to a randomly chosen sub-pool, weighted toward smaller sub-pools so that sub-pool sizes stay similar across processes.
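The following is a minimal sketch of how window_size, clearing_freq, and window_size_ could act on a single sub-pool. It assumes each sub-pool's share of pool_size is math.ceil(pool_size / processes) and that periodic clearing runs before the size cap. The helper truncate, its arguments, and that ordering are hypothetical; they are not taken from the Pool source.

```python
import math


def truncate(sub_pool, store_count, capacity, window_size=None,
             clearing_freq=None, window_size_=None):
    """Hypothetical helper applying the documented trimming rules to one sub-pool.

    sub_pool is a list ordered oldest first; store_count is how many samples
    this sub-pool has stored so far.
    """
    # Periodic clearing: every clearing_freq stores, drop the window_size_ oldest samples.
    if clearing_freq and window_size_ and store_count % clearing_freq == 0:
        sub_pool = sub_pool[window_size_:]
    # Size cap: the sub-pool's share of pool_size, tightened by window_size if given.
    limit = capacity if window_size is None else min(capacity, window_size)
    if len(sub_pool) > limit:
        sub_pool = sub_pool[len(sub_pool) - limit:]  # keep only the newest `limit` samples
    return sub_pool


pool_size, processes = 50, 2
capacity = math.ceil(pool_size / processes)  # assumed per-sub-pool share: 25

print(truncate(list(range(30)), store_count=30, capacity=capacity, window_size=5))
# [25, 26, 27, 28, 29]
print(truncate(list(range(10)), store_count=4, capacity=capacity,
               clearing_freq=4, window_size_=3))
# [3, 4, 5, 6, 7, 8, 9]
```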

Methods Overview

pool(self, s, a, next_s, r, done, index=None)

Adds an experience tuple to the appropriate shared lists and applies the configured truncation rules (window_size, clearing_freq, window_size_).
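Any implementation of pool() has to keep the five fields aligned. The sketch below stores a sub-pool as five parallel Python lists and trims all of them by the same amount. Using plain lists is an assumption; in Pool the lists are shared through multiprocessing.Manager. add_transition and the capacity argument are hypothetical names.

```python
import numpy as np

FIELDS = ("state", "action", "next_state", "reward", "done")


def add_transition(sub_pool, s, a, next_s, r, done, capacity):
    """Hypothetical stand-in for pool(): append one transition to five parallel lists."""
    # Append each field to its own list; entry i of every list belongs to transition i.
    for name, value in zip(FIELDS, (s, a, next_s, r, done)):
        sub_pool[name].append(value)
    # Over capacity: drop the same number of oldest entries from all five lists.
    excess = len(sub_pool["state"]) - capacity
    if excess > 0:
        for name in FIELDS:
            del sub_pool[name][:excess]


sub = {name: [] for name in FIELDS}
for t in range(4):
    add_transition(sub, np.full(2, float(t)), t % 2, np.full(2, float(t + 1)),
                   float(t), t == 3, capacity=3)
print(sub["action"], sub["reward"], sub["done"])
# [1, 0, 1] [1.0, 2.0, 3.0] [False, False, True]
```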

store_in_parallel(self, p, lock_list)

Runs in each child process. It repeatedly interacts with the environment assigned to process p and stores the resulting experience tuples, using the locks in lock_list to synchronize access to the shared lists.
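The per-process loop can be sketched as below. The toy environment RandomWalkEnv, its reset()/step() return values, the random policy, and the n_steps bound are all hypothetical; the documented method runs continuously. The one idea carried over from the description is that every write to the shared data happens while holding a lock.

```python
import multiprocessing as mp
import numpy as np


class RandomWalkEnv:
    """Hypothetical toy env: reset() -> state, step(action) -> (next_state, reward, done)."""

    def __init__(self, episode_len=5):
        self.episode_len = episode_len
        self.t = 0
        self.state = np.zeros(2)

    def reset(self):
        self.t = 0
        self.state = np.zeros(2)
        return self.state

    def step(self, action):
        self.t += 1
        # Build a new array instead of updating in place, so stored states stay intact.
        self.state = self.state + (1.0 if action else -1.0)
        return self.state, float(action), self.t >= self.episode_len


def collect_steps(env, n_steps, sub_pool, lock, rng):
    """Hypothetical bounded version of a per-process collection loop."""
    s = env.reset()
    for _ in range(n_steps):
        a = int(rng.integers(0, 2))           # random policy (assumption of this sketch)
        next_s, r, done = env.step(a)
        with lock:                             # serialize writes to the shared sub-pool
            sub_pool.append((s, a, next_s, r, done))
        s = env.reset() if done else next_s    # start a new episode after a terminal step


lock = mp.Lock()
transitions = []
collect_steps(RandomWalkEnv(episode_len=5), 12, transitions, lock, np.random.default_rng(0))
print(len(transitions), sum(t[4] for t in transitions))  # 12 2
```

Creating a new state array on every step, rather than updating one in place, matters here. Otherwise the stored state and next_state of a transition would refer to the same array, which is modified later.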

store(self)

Starts parallel data collection by creating one child process per environment and managing those processes until they finish.
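The next sketch shows the kind of orchestration store() performs. It uses a Manager-backed list per sub-pool, one lock per sub-pool (mirroring the lock_list argument), and one child process per sub-pool. The worker only appends placeholder (process index, step) tuples, so the focus stays on process management. The names worker and store, their arguments, and the bounded step count are assumptions.

```python
import multiprocessing as mp

# Prefer fork where available so the sketch also runs without a __main__ guard;
# on Windows only spawn exists, and the guard below is then required.
CTX = mp.get_context("fork" if "fork" in mp.get_all_start_methods() else None)


def worker(p, n_steps, sub_pools, lock_list):
    """Placeholder worker: appends (process index, step) tuples to its own sub-pool."""
    for step in range(n_steps):
        with lock_list[p]:
            sub_pools[p].append((p, step))


def store(processes, n_steps):
    """Start one process per sub-pool, wait for all of them, and return plain lists."""
    with CTX.Manager() as manager:
        sub_pools = [manager.list() for _ in range(processes)]  # shared, one per process
        lock_list = [CTX.Lock() for _ in range(processes)]       # one lock per sub-pool
        workers = [CTX.Process(target=worker, args=(p, n_steps, sub_pools, lock_list))
                   for p in range(processes)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        # Copy the data out of the proxies before the manager shuts down.
        return [sp[:] for sp in sub_pools]


if __name__ == "__main__":
    print(store(processes=2, n_steps=3))
    # [[(0, 0), (0, 1), (0, 2)], [(1, 0), (1, 1), (1, 2)]]
```

Because each process writes only to its own sub-pool, the order within a sub-pool is deterministic, while the processes themselves run concurrently.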

get_pool(self)

Retrieves the aggregated data from all sub-pools, returning combined NumPy arrays of states, actions, next states, rewards, and done flags.
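Merging the sub-pools could look like the following. The sketch assumes each sub-pool holds one NumPy array per field, with the first axis indexing transitions, and that a process with no data yet is represented by None. Both the layout and the get_pool helper here are illustrative, not the Pool implementation.

```python
import numpy as np

FIELDS = ("state", "action", "next_state", "reward", "done")


def get_pool(sub_pools):
    """Hypothetical helper: concatenate per-field arrays from every non-empty sub-pool."""
    filled = [sp for sp in sub_pools if sp is not None]
    if not filled:
        raise ValueError("no data collected yet")
    # One array per field, stacked along the transition axis.
    return tuple(np.concatenate([sp[f] for sp in filled], axis=0) for f in FIELDS)


sub_pools = [
    {"state": np.zeros((2, 4)), "action": np.array([0, 1]), "next_state": np.ones((2, 4)),
     "reward": np.array([1.0, 0.5]), "done": np.array([False, True])},
    None,  # a process that has not stored anything yet
    {"state": np.ones((1, 4)), "action": np.array([1]), "next_state": np.full((1, 4), 2.0),
     "reward": np.array([0.0]), "done": np.array([False])},
]
state, action, next_state, reward, done = get_pool(sub_pools)
print(state.shape, action.tolist(), done.tolist())  # (3, 4) [0, 1, 1] [False, True, False]
```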

Usage Example

import numpy as np
import multiprocessing as mp
import math
# Assumes the Pool class from the Class Definition section is defined in this module

# Dummy environment for demonstration
class DummyEnv:
    def __init__(self):
        self.state = 0
        self.steps = 0

    def reset(self):
        self.state = np.random.rand(4)
        self.steps = 0
        return self.state, 0  # Return state and dummy action

    def step(self, action):
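        # Create a new array rather than updating in place, so a previously returned state is not overwritten
        self.state = self.state + np.random.rand(4) * 0.1
        reward = np.sum(self.state)
        self.steps += 1
        done = self.steps >= 10  # Episode ends after 10 steps
        return np.random.randint(0, 2), self.state, reward, done  # (action, next_state, reward, done)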

if __name__ == "__main__":
    num_processes = 2
    envs = [DummyEnv() for _ in range(num_processes)]
    total_pool_size = 50

    pool_manager = Pool(envs, num_processes, total_pool_size, window_size=5, random=True)

    print("Starting data collection...")
    pool_manager.store()

    state, action, next_state, reward, done = pool_manager.get_pool()
    print(f"Collected States Shape: {state.shape}")
    # Further data analysis can follow

The Pool class is a practical tool for reinforcement learning work that needs fast, parallel experience collection for agent training.
