[WIP] Make adding new Policy Models flexible #327
`.gitignore`:

```diff
@@ -15,3 +15,4 @@ playground/
 !docs/requirements-docs.txt
 .DS_Store
 docs/_build/
+logs
```
New YAML config for the CNN policy:

```yaml
_target_: gflownet.policy.cnn.CNNPolicy

shared: null

forward:
  n_layers: 2
  channels: [16, 32]
  kernel_sizes: [[3, 3], [2, 2]]  # Each tuple represents (height, width)
  strides: [[1, 1], [1, 1]]  # Each tuple represents (vertical_stride, horizontal_stride)
  checkpoint: null
  reload_ckpt: False

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
```
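A minimal sketch of how this config might be consumed; the config path and the base `Policy` keyword arguments (`env`, `device`, `float_precision`) are assumptions for illustration, not confirmed by this diff:

```python
from omegaconf import OmegaConf

from gflownet.policy.cnn import CNNPolicy

cfg = OmegaConf.load("config/policy/cnn.yaml")  # path assumed

# CNNPolicy reads kwargs["config"]; the remaining kwargs go to the Policy base
forward_policy = CNNPolicy(
    config=cfg.forward,  # the forward sub-config above
    env=env,  # environment providing height, width and output_dim (assumed)
    device="cpu",
    float_precision=32,
)
```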
Tetris environment (`gflownet/envs/tetris.py`):

```diff
@@ -75,6 +75,7 @@ def __init__(
         height: int = 20,
         pieces: List = ["I", "J", "L", "O", "S", "T", "Z"],
         rotations: List = [0, 90, 180, 270],
+        flatten: bool = True,
         allow_redundant_rotations: bool = False,
         allow_eos_before_full: bool = False,
         **kwargs,
```

Review comment on `flatten: bool = True`:
> If we move the flattening from the environment to the policy, then we don't need this.

```diff
@@ -87,6 +88,7 @@ def __init__(
         self.height = height
         self.pieces = pieces
         self.rotations = rotations
+        self.flatten = flatten
         self.allow_redundant_rotations = allow_redundant_rotations
         self.allow_eos_before_full = allow_eos_before_full
         self.max_pieces_per_type = 100
```

```diff
@@ -307,7 +309,9 @@ def states2policy(
             A tensor containing all the states in the batch.
         """
         states = tint(states, device=self.device, int_type=self.int)
-        return self.states2proxy(states).flatten(start_dim=1).to(self.float)
+        if self.flatten:
+            return self.states2proxy(states).flatten(start_dim=1).to(self.float)
+        return self.states2proxy(states).to(self.float)

     def state2readable(self, state: Optional[TensorType["height", "width"]] = None):
         """
```

Review comment on lines +312 to +314:
> @alexhernandezgarcia This is a temporary solution to make the CNN policy work on the Tetris env, but normally the flattening should happen inside the model, not in the environment (see my other comments). If you are okay with that, then I can update.

```diff
@@ -581,7 +585,7 @@ def _plot_board(board, ax: Axes, cellsize: int = 20, linewidth: int = 2):
         linewidth : int
             The width of the separation between cells, in pixels.
         """
-        board = board.clone().numpy()
+        board = board.clone().cpu().numpy()
         height = board.shape[0] * cellsize
         width = board.shape[1] * cellsize
         board_img = 128 * np.ones(
```
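For intuition on what the `flatten` flag changes, the two branches of `states2policy` differ only in output shape. A minimal sketch, assuming the default 20x10 Tetris board:

```python
import torch

batch = torch.zeros(8, 20, 10)  # hypothetical batch of eight 20x10 boards
print(batch.flatten(start_dim=1).shape)  # flatten=True  -> torch.Size([8, 200]), for the MLP policy
print(batch.shape)                       # flatten=False -> torch.Size([8, 20, 10]), kept 2D for the CNN policy
```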
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import torch | ||
from omegaconf import OmegaConf | ||
from torch import nn F438 | ||
|
||
from gflownet.policy.base import Policy | ||
|
||
|
||
class CNNPolicy(Policy): | ||
def __init__(self, **kwargs): | ||
config = self._get_config(kwargs["config"]) | ||
# Shared weights, defaults to False | ||
self.shared_weights = config.get("shared_weights", False) | ||
# Reload checkpoint, defaults to False | ||
self.reload_ckpt = config.get("reload_ckpt", False) | ||
# CNN features: number of layers, number of channels, kernel sizes, strides | ||
self.n_layers = config.get("n_layers", 3) | ||
self.channels = config.get("channels", [16] * self.n_layers) | ||
self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) | ||
self.strides = config.get("strides", [(1, 1)] * self.n_layers) | ||
# Base init | ||
super().__init__(**kwargs) | ||
|
||
def make_model(self): | ||
""" | ||
Instantiates a CNN with no top layer activation. | ||
|
||
Returns | ||
------- | ||
model : torch.nn.Module | ||
A torch model containing the CNN. | ||
is_model : bool | ||
True because a CNN is a model. | ||
""" | ||
if self.shared_weights and self.base is not None: | ||
layers = list(self.base.model.children())[:-1] | ||
last_layer = nn.Linear( | ||
self.base.model[-1].in_features, self.base.model[-1].out_features | ||
) | ||
|
||
model = nn.Sequential(*layers, last_layer).to(self.device) | ||
return model, True | ||
|
||
current_channels = 1 | ||
conv_module = nn.Sequential() | ||
|
||
if len(self.kernel_sizes) != self.n_layers: | ||
raise ValueError( | ||
f"Inconsistent dimensions kernel_sizes != n_layers, " | ||
"{len(self.kernel_sizes)} != {self.n_layers}" | ||
) | ||
|
||
for i in range(self.n_layers): | ||
conv_module.add_module( | ||
f"conv_{i}", | ||
nn.Conv2d( | ||
in_channels=current_channels, | ||
out_channels=self.channels[i], | ||
kernel_size=tuple(self.kernel_sizes[i]), | ||
stride=tuple(self.strides[i]), | ||
padding=0, | ||
padding_mode="zeros", # Constant zero padding | ||
), | ||
) | ||
conv_module.add_module(f"relu_{i}", nn.ReLU()) | ||
current_channels = self.channels[i] | ||
|
||
dummy_input = torch.ones( | ||
(1, 1, self.height, self.width) | ||
) # (batch_size, channels, height, width) | ||
try: | ||
in_channels = conv_module(dummy_input).numel() | ||
if in_channels >= 500_000: # TODO: this could better be handled | ||
raise RuntimeWarning( | ||
"Input channels for the dense layer are too big, this will " | ||
"increase number of parameters" | ||
) | ||
except RuntimeError as e: | ||
raise RuntimeError( | ||
"Failed during convolution operation. Ensure that the kernel sizes " | ||
"and strides are appropriate for the input dimensions." | ||
) from e | ||
|
||
model = nn.Sequential( | ||
conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) | ||
) | ||
return model.to(self.device), True | ||
|
||
def __call__(self, states): | ||
states = states.unsqueeze(1) # (batch_size, channels, height, width) | ||
return self.model(states) |
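For intuition on the dummy-input trick in `make_model`, the same shape inference can be reproduced standalone. This sketch assumes the 20x10 board of the Tetris defaults and the layer settings from the YAML config above:

```python
import torch
from torch import nn

# Conv stack from the config: channels [16, 32], kernels [(3, 3), (2, 2)],
# strides [(1, 1), (1, 1)], no padding.
conv = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1)),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1)),
    nn.ReLU(),
)
out = conv(torch.ones(1, 1, 20, 10))  # (batch, channels, height, width)
print(out.shape)    # torch.Size([1, 32, 17, 7]): rows 20->18->17, cols 10->8->7
print(out.numel())  # 3808 -> in_features of the top nn.Linear
```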