[WIP] Make adding new Policy Models flexible by engmubarak48 · Pull Request #327 · alexhernandezgarcia/gflownet · GitHub

[WIP] Make adding new Policy Models flexible #327

Open · wants to merge 31 commits into main

Commits (31)
- 98e4d7e "initial commit: split the base policy and the architectures" (engmubarak48, Jun 18, 2024)
- 1388756 "ignore git logs" (engmubarak48, Jun 21, 2024)
- 7c8a517 "refactor base policy class and move models into different file" (engmubarak48, Jun 21, 2024)
- 4ecd865 "handle when config none gracefully" (engmubarak48, Jun 21, 2024)
- d50aeea "black formatting" (engmubarak48, Jun 21, 2024)
- 1508e4b "further formatting with isort" (engmubarak48, Jun 21, 2024)
- 9cdf18e "formatting black + isort" (engmubarak48, Jun 28, 2024)
- 8af7f82 "added flatten flag and device movement handling" (engmubarak48, Jun 28, 2024)
- 9474150 "bug fix: use .cpu() before .numpy()" (engmubarak48, Jun 28, 2024)
- 83a4904 "added cnn policy, and flatten flag should be set to false when using …" (engmubarak48, Jun 28, 2024)
- e254cf2 "black formatting" (engmubarak48, Jun 28, 2024)
- d30c3f2 "smaller cnn config like kernel size etc" (engmubarak48, Jul 9, 2024)
- 2ac9af5 "minor refactor on parse_config" (engmubarak48, Jul 9, 2024)
- 7e02200 "move self.is_model to instantiate and add super().parse_config(config…" (engmubarak48, Jul 9, 2024)
- e6a14ef "Add docstring and typing to __init__ of policy base." (alexhernandezgarcia, Jul 10, 2024)
- 57b0b13 "Use kwargs instead of listing parameters explicitly" (alexhernandezgarcia, Jul 10, 2024)
- 178a08e "Policy MLP: docstring and typing." (alexhernandezgarcia, Jul 10, 2024)
- ace8a28 "Get rid of parse_config and include its content in __init__" (alexhernandezgarcia, Jul 10, 2024)
- 8e6f03d "Combine instantiate and make_* into a single method make_model()" (alexhernandezgarcia, Jul 10, 2024)
- 774c411 "Missing import" (alexhernandezgarcia, Jul 10, 2024)
- 9520315 "Fix config issue by implementing _get_config()" (alexhernandezgarcia, Jul 10, 2024)
- 910a948 "Docstring for base argument" (alexhernandezgarcia, Jul 10, 2024)
- c9ec03f "Merge pull request #335 from alexhernandezgarcia/ahg/293-flexible-pol…" (josephdviviano, Sep 18, 2024)
- 9fa9381 "remove env from the cnn policy" (engmubarak48, Sep 23, 2024)
- 2bf438a "init the cnn env's height and width in the policy" (engmubarak48, Sep 23, 2024)
- 9395e61 "add mlp to device" (engmubarak48, Sep 23, 2024)
- 595db80 "formatting" (engmubarak48, Sep 23, 2024)
- 6797bcc "debug: add to print environment information when running GitHub Actio…" (engmubarak48, Sep 23, 2024)
- 95258a0 "Add the specific versions of pymatgen and spglib" (engmubarak48, Sep 25, 2024)
- 61f1e1d "Merge branch 'main' into 293-flexible-policy-definition" (engmubarak48, Sep 25, 2024)
- b7f33d8 "revert version downgrade of pymatgen" (engmubarak48, Sep 25, 2024)
9 changes: 9 additions & 0 deletions .github/workflows/push_code_check.yml
@@ -27,6 +27,9 @@ jobs:
with:
python-version: "3.10"

- name: Display Python version
run: python -V

- name: Install code dependencies
run: bash ./prereq_ci.sh pip install --upgrade

@@ -36,6 +39,12 @@
- name: Install Pytest and Isort
run: pip install pytest isort

- name: Display pip list
run: pip list

- name: Display system information
run: uname -a

- name: Validate import format in main source code
run: isort --profile black ./gflownet/ --check-only

1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@ playground/
!docs/requirements-docs.txt
.DS_Store
docs/_build/
logs
2 changes: 2 additions & 0 deletions config/env/tetris.yaml
@@ -11,6 +11,8 @@ height: 20
pieces: ["I", "J", "L", "O", "S", "T", "Z"]
# Allowed rotations
rotations: [0, 90, 180, 270]
# Don't flatten if using CNN
flatten: True
# Other config
allow_redundant_rotations: False
allow_eos_before_full: False
16 changes: 16 additions & 0 deletions config/policy/cnn.yaml
@@ -0,0 +1,16 @@
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
n_layers: 2
channels: [16, 32]
kernel_sizes: [[3, 3], [2, 2]] # Each tuple represents (height, width)
strides: [[1, 1], [1, 1]] # Each tuple represents (vertical_stride, horizontal_stride)
checkpoint: null
reload_ckpt: False

backward:
shared_weights: True
checkpoint: null
reload_ckpt: False
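For context, a minimal sketch of how a config like this can be consumed. The construction mirrors the `CNNPolicy.__init__` signature shown later in this diff; the `Tetris` arguments and the direct construction (rather than the agent's own Hydra wiring, which is not part of this diff) are assumptions for illustration:

```python
from omegaconf import OmegaConf

from gflownet.envs.tetris import Tetris
from gflownet.policy.cnn import CNNPolicy

# Load the YAML above; `forward` holds the per-direction CNN settings.
cfg = OmegaConf.load("config/policy/cnn.yaml")
# flatten=False keeps states as (height, width) boards, as the CNN expects;
# the width/height values are assumptions matching the config defaults.
env = Tetris(width=10, height=20, flatten=False)
forward_policy = CNNPolicy(
    config=cfg.forward,  # n_layers, channels, kernel_sizes, strides
    env=env,
    device="cpu",
    float_precision=32,
)
```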
3 changes: 1 addition & 2 deletions config/policy/mlp.yaml
@@ -1,9 +1,8 @@
_target_: gflownet.policy.base.Policy
_target_: gflownet.policy.mlp.MLPPolicy

shared: null

forward:
type: mlp
n_hid: 128
n_layers: 2
checkpoint: null
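The `_target_` key is Hydra's hook for selecting which class to instantiate, so repointing it from the base `Policy` to `MLPPolicy` is what makes per-architecture policy modules pluggable. A minimal sketch of the resolution step, assuming the package is importable:

```python
from hydra.utils import get_class

# Hydra resolves the dotted path in `_target_` to the class object and then
# calls it with the remaining config entries as keyword arguments.
PolicyCls = get_class("gflownet.policy.mlp.MLPPolicy")
print(PolicyCls.__name__)  # MLPPolicy
```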
8 changes: 6 additions & 2 deletions gflownet/envs/tetris.py
@@ -75,6 +75,7 @@ def __init__(
height: int = 20,
pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
rotations: List = [0, 90, 180, 270],
flatten: bool = True,
Review comment (Collaborator, Author):
If we move the flattening from the environment to the policy, then we don't need this.

allow_redundant_rotations: bool = False,
allow_eos_before_full: bool = False,
**kwargs,
@@ -87,6 +88,7 @@
self.height = height
self.pieces = pieces
self.rotations = rotations
self.flatten = flatten
self.allow_redundant_rotations = allow_redundant_rotations
self.allow_eos_before_full = allow_eos_before_full
self.max_pieces_per_type = 100
@@ -307,7 +309,9 @@ def states2policy(
A tensor containing all the states in the batch.
"""
states = tint(states, device=self.device, int_type=self.int)
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
if self.flatten:
return self.states2proxy(states).flatten(start_dim=1).to(self.float)
return self.states2proxy(states).to(self.float)
Review comment on lines +312 to +314 (Collaborator, Author):

@alexhernandezgarcia This is a temporary solution to make the CNN policy work on the Tetris env, but normally the flattening should happen inside the model, not in the environment (see my other comments).

If you are okay with that, then I can update.

def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
"""
@@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
linewidth : int
The width of the separation between cells, in pixels.
"""
board = board.clone().numpy()
board = board.clone().cpu().numpy()
height = board.shape[0] * cellsize
width = board.shape[1] * cellsize
board_img = 128 * np.ones(
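To make the new `flatten` flag concrete, here is a standalone shape sketch of what `states2policy` returns in each case (board dimensions are assumptions matching the Tetris defaults):

```python
import torch

# Eight Tetris states as (batch, height, width) boards, e.g. 20x10.
batch = torch.zeros(8, 20, 10)

# flatten=True: (8, 200) vectors, what the MLP policy consumes.
print(batch.flatten(start_dim=1).shape)  # torch.Size([8, 200])

# flatten=False: (8, 20, 10) boards, what the CNN policy consumes; its
# __call__ then unsqueezes a channel dimension to (8, 1, 20, 10).
print(batch.shape)  # torch.Size([8, 20, 10])
```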
2 changes: 1 addition & 1 deletion gflownet/evaluator/base.py
@@ -278,7 +278,7 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):

if "corr_prob_traj_rewards" in metrics:
lp_metrics["corr_prob_traj_rewards"] = np.corrcoef(
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt.cpu().numpy()
)[0, 1]

if "var_logrewards_logp" in metrics:
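Like the `_plot_board` change above, this fixes a direct NumPy conversion of a tensor that may live on the GPU. A minimal illustration (the device choice is just for the demo):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.ones(4, device=device)
# For CUDA tensors, t.numpy() raises a TypeError; moving the tensor to host
# memory first with .cpu() is safe on any device.
arr = t.cpu().numpy()
print(arr)  # [1. 1. 1. 1.]
```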
162 changes: 77 additions & 85 deletions gflownet/policy/base.py
@@ -1,14 +1,44 @@
from abc import ABC, abstractmethod
"""
Base Policy class for GFlowNet policy models.
"""

from typing import Tuple, Union

import torch
from omegaconf import OmegaConf
from torch import nn
from omegaconf.dictconfig import DictConfig

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import set_device, set_float_precision


class ModelBase(ABC):
def __init__(self, config, env, device, float_precision, base=None):
class Policy:
def __init__(
self,
config: Union[dict, DictConfig],
env: GFlowNetEnv,
device: Union[str, torch.device],
float_precision: Union[int, torch.dtype],
base=None,
):
"""
Base Policy class for a :class:`GFlowNetAgent`.

Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
env : GFlowNetEnv
The environment used to train the :class:`GFlowNetAgent`, used to extract
needed properties.
device : str or torch.device
The device to be passed to torch tensors.
float_precision : int or torch.dtype
The floating point precision to be passed to torch tensors.
base: Policy (optional)
A base policy to be used as backbone for the backward policy.
"""
config = self._get_config(config)
# Device and float precision
self.device = set_device(device)
self.float = set_float_precision(float_precision)
@@ -19,98 +49,60 @@ def __init__(self, config, env, device, float_precision, base=None):
self.output_dim = len(self.fixed_output)
# Optional base model
self.base = base
# Policy type, defaults to uniform
self.type = config.get("type", "uniform")
# Checkpoint, defaults to None
self.checkpoint = config.get("checkpoint", None)
# TODO: This could be done better? We could store this only when using CNN policy. e.g. self.type could be "cnn"
if hasattr(env, "height"):
self.height = env.height
if hasattr(env, "width"):
self.width = env.width
# Instantiate the model
self.model, self.is_model = self.make_model()

@staticmethod
def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
"""
Returns a configuration dictionary, even if the input is None.

self.parse_config(config)

def parse_config(self, config):
# If config is null, default to uniform
Parameters
----------
config : dict or DictConfig
The configuration dictionary to set up the policy model. It may be None, in
which case an empty config is created and the defaults will be used.

Returns
-------
config : dict or DictConfig
The configuration dictionary to set up the policy model.
"""
if config is None:
config = OmegaConf.create()
config.type = "uniform"
self.checkpoint = config.get("checkpoint", None)
self.shared_weights = config.get("shared_weights", False)
self.n_hid = config.get("n_hid", None)
self.n_layers = config.get("n_layers", None)
self.tail = config.get("tail", [])
if "type" in config:
self.type = config.type
elif self.shared_weights:
self.type = self.base.type
else:
raise "Policy type must be defined if shared_weights is False"

@abstractmethod
def instantiate(self):
pass
return config

def __call__(self, states):
return self.model(states)

def make_mlp(self, activation):
def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]:
"""
Defines an MLP with no top layer activation
If share_weight == True,
baseModel (the model with which weights are to be shared) must be provided
Args
----
layers_dim : list
Dimensionality of each layer
activation : Activation
Activation function
Instantiates the model of the policy.

Returns
-------
model : torch.Tensor or torch.nn.Module
A tensor representing the output of the policy or a torch model.
is_model : bool
True if the policy is a model (for example, a neural network) and False if
it is a fixed tensor (for example to make a uniform distribution).
"""
if self.shared_weights == True and self.base is not None:
mlp = nn.Sequential(
self.base.model[:-1],
nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
),
)
return mlp
elif self.shared_weights == False:
layers_dim = (
[self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)]
)
mlp = nn.Sequential(
*(
sum(
[
[nn.Linear(idim, odim)]
+ ([activation] if n < len(layers_dim) - 2 else [])
for n, (idim, odim) in enumerate(
zip(layers_dim, layers_dim[1:])
)
],
[],
)
+ self.tail
)
)
return mlp
else:
raise ValueError(
"Base Model must be provided when shared_weights is set to True"
)


class Policy(ModelBase):
def __init__(self, config, env, device, float_precision, base=None):
super().__init__(config, env, device, float_precision, base)

self.instantiate()

def instantiate(self):
if self.type == "fixed":
self.model = self.fixed_distribution
self.is_model = False
return self.fixed_distribution, False
elif self.type == "uniform":
self.model = self.uniform_distribution
self.is_model = False
elif self.type == "mlp":
self.model = self.make_mlp(nn.LeakyReLU()).to(self.device)
self.is_model = True
return self.uniform_distribution, False
else:
raise "Policy model type not defined"

def __call__(self, states):
return self.model(states)

def fixed_distribution(self, states):
"""
Returns the fixed distribution specified by the environment.
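With this refactor, adding a new architecture amounts to subclassing `Policy` and overriding `make_model()`, as the PR title suggests. A hypothetical minimal subclass following the same pattern as the `CNNPolicy` below (the class and its `bias` option are invented for illustration):

```python
from torch import nn

from gflownet.policy.base import Policy


class TinyLinearPolicy(Policy):
    """Hypothetical example: a single linear layer as the policy model."""

    def __init__(self, **kwargs):
        # Read architecture options from the config before the base class
        # __init__ builds the model via make_model().
        config = self._get_config(kwargs["config"])
        self.bias = config.get("bias", True)
        super().__init__(**kwargs)

    def make_model(self):
        # state_dim and output_dim are derived from the env by Policy.__init__
        # before make_model() is called.
        model = nn.Linear(self.state_dim, self.output_dim, bias=self.bias)
        return model.to(self.device), True
```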
90 changes: 90 additions & 0 deletions gflownet/policy/cnn.py
@@ -0,0 +1,90 @@
import torch
from omegaconf import OmegaConf
from torch import nn

from gflownet.policy.base import Policy


class CNNPolicy(Policy):
def __init__(self, **kwargs):
config = self._get_config(kwargs["config"])
# Shared weights, defaults to False
self.shared_weights = config.get("shared_weights", False)
# Reload checkpoint, defaults to False
self.reload_ckpt = config.get("reload_ckpt", False)
# CNN features: number of layers, number of channels, kernel sizes, strides
self.n_layers = config.get("n_layers", 3)
self.channels = config.get("channels", [16] * self.n_layers)
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers)
self.strides = config.get("strides", [(1, 1)] * self.n_layers)
# Base init
super().__init__(**kwargs)

def make_model(self):
"""
Instantiates a CNN with no top layer activation.

Returns
-------
model : torch.nn.Module
A torch model containing the CNN.
is_model : bool
True because a CNN is a model.
"""
if self.shared_weights and self.base is not None:
layers = list(self.base.model.children())[:-1]
last_layer = nn.Linear(
self.base.model[-1].in_features, self.base.model[-1].out_features
)

model = nn.Sequential(*layers, last_layer).to(self.device)
return model, True

current_channels = 1
conv_module = nn.Sequential()

if len(self.kernel_sizes) != self.n_layers:
raise ValueError(
"Inconsistent dimensions kernel_sizes != n_layers, "
f"{len(self.kernel_sizes)} != {self.n_layers}"
)

for i in range(self.n_layers):
conv_module.add_module(
f"conv_{i}",
nn.Conv2d(
in_channels=current_channels,
out_channels=self.channels[i],
kernel_size=tuple(self.kernel_sizes[i]),
stride=tuple(self.strides[i]),
padding=0,
padding_mode="zeros", # Constant zero padding
),
)
conv_module.add_module(f"relu_{i}", nn.ReLU())
current_channels = self.channels[i]

dummy_input = torch.ones(
(1, 1, self.height, self.width)
) # (batch_size, channels, height, width)
try:
in_channels = conv_module(dummy_input).numel()
if in_channels >= 500_000: # TODO: this could be handled better
raise RuntimeWarning(
"Input channels for the dense layer are too big, this will "
"increase number of parameters"
)
except RuntimeError as e:
raise RuntimeError(
"Failed during convolution operation. Ensure that the kernel sizes "
"and strides are appropriate for the input dimensions."
) from e

model = nn.Sequential(
conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim)
)
return model.to(self.device), True

def __call__(self, states):
states = states.unsqueeze(1) # (batch_size, channels, height, width)
return self.model(states)
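The dummy forward pass in `make_model` above is a common trick for sizing the final linear layer without hand-deriving convolution arithmetic. A standalone version (the layer settings mirror the default cnn.yaml, and the 20x10 board size is an assumption):

```python
import torch
from torch import nn

# Two conv layers as in the default cnn.yaml: channels [16, 32],
# kernels [(3, 3), (2, 2)], strides [(1, 1), (1, 1)], one input channel.
conv = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1)),
    nn.ReLU(),
)
# Forward a dummy batch of ones to count output features instead of
# computing the output spatial size by hand.
n_features = conv(torch.ones(1, 1, 20, 10)).numel()
print(n_features)  # 32 channels * 17 * 7 = 3808 for a 20x10 board
```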
Empty file added gflownet/policy/gnn.py
Empty file.