Model caching by roussel-ryan · Pull Request #253 · xopt-org/Xopt

Model caching #253


Merged · 10 commits · Dec 2, 2024
147 changes: 129 additions & 18 deletions docs/examples/single_objective_bayes_opt/time_dependent_bo.ipynb
@@ -88,9 +88,6 @@
" return (x_ - k(t_)) ** 2\n",
"\n",
"\n",
"start_time = time.time()\n",
"\n",
"\n",
"# create callable function for Xopt\n",
"def f(inputs):\n",
" x_ = inputs[\"x\"]\n",
@@ -121,19 +118,7 @@
"\n",
"vocs = VOCS(variables=variables, objectives=objectives)\n",
"\n",
"evaluator = Evaluator(function=f)\n",
"generator = TDUpperConfidenceBoundGenerator(\n",
" vocs=vocs,\n",
" beta=0.01,\n",
" added_time=0.1,\n",
" forgetting_time=20.0,\n",
")\n",
"generator.n_monte_carlo_samples = N_MC_SAMPLES\n",
"generator.numerical_optimizer.n_restarts = NUM_RESTARTS\n",
"generator.max_travel_distances = [0.1]\n",
"\n",
"X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)\n",
"X"
"evaluator = Evaluator(function=f)"
]
},
{
@@ -158,9 +143,22 @@
},
"outputs": [],
"source": [
"generator = TDUpperConfidenceBoundGenerator(\n",
" vocs=vocs,\n",
" beta=0.01,\n",
" added_time=0.1,\n",
" forgetting_time=20.0,\n",
")\n",
"generator.n_monte_carlo_samples = N_MC_SAMPLES\n",
"generator.numerical_optimizer.n_restarts = NUM_RESTARTS\n",
"generator.max_travel_distances = [0.1]\n",
"\n",
"start_time = time.time()\n",
"\n",
"X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)\n",
"X.random_evaluate(1)\n",
"\n",
"for _ in trange(300):\n",
"for i in trange(300):\n",
" # note that in this example we can ignore warnings if computation time is greater\n",
" # than added time\n",
" with warnings.catch_warnings():\n",
@@ -284,12 +282,125 @@
" ax2.set_title(\"acquisition function at last time step\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Run Time Dependent BO with Model Caching\n",
"Instead of retraining the GP model hyperparameters at every step, we can instead hold\n",
"on to previously determined model parameters by setting\n",
"`use_catched_hyperparameters=True` in the model constructor. This reduces the time\n",
"needed to make decisions, leading to faster feedback when addressing time-critical\n",
"optimization tasks. However, this can come at the cost of model accuracy when the\n",
"target function changes behavior (change in lengthscale for example)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"generator = TDUpperConfidenceBoundGenerator(\n",
" vocs=vocs,\n",
" beta=0.01,\n",
" added_time=0.1,\n",
" forgetting_time=20.0,\n",
")\n",
"generator.n_monte_carlo_samples = N_MC_SAMPLES\n",
"generator.numerical_optimizer.n_restarts = NUM_RESTARTS\n",
"generator.max_travel_distances = [0.1]\n",
"\n",
"start_time = time.time()\n",
"\n",
"X = Xopt(evaluator=evaluator, generator=generator, vocs=vocs)\n",
"X.random_evaluate(1)\n",
"\n",
"for i in trange(300):\n",
" # note that in this example we can ignore warnings if computation time is greater\n",
" # than added time\n",
" if i == 50:\n",
" X.generator.gp_constructor.use_cached_hyperparameters = True\n",
"\n",
" with warnings.catch_warnings():\n",
" warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
" X.step()\n",
" time.sleep(0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot total computation time\n",
"ax = X.generator.computation_time.sum(axis=1).plot()\n",
"ax.set_xlabel(\"Iteration\")\n",
"ax.set_ylabel(\"total BO computation time (s)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = X.data\n",
"\n",
"xbounds = generator.vocs.bounds\n",
"tbounds = [data[\"time\"].min(), data[\"time\"].max()]\n",
"\n",
"model = X.generator.model\n",
"n = 100\n",
"t = torch.linspace(*tbounds, n, dtype=torch.double)\n",
"x = torch.linspace(*xbounds.flatten(), n, dtype=torch.double)\n",
"tt, xx = torch.meshgrid(t, x)\n",
"pts = torch.hstack([ele.reshape(-1, 1) for ele in (tt, xx)]).double()\n",
"\n",
"tt, xx = tt.numpy(), xx.numpy()\n",
"\n",
"# NOTE: the model inputs are such that t is the last dimension\n",
"gp_pts = torch.flip(pts, dims=[-1])\n",
"\n",
"gt_vals = g(gp_pts.T[0], gp_pts.T[1] - start_time)\n",
"\n",
"with torch.no_grad():\n",
" post = model.posterior(gp_pts)\n",
"\n",
" mean = post.mean\n",
" std = torch.sqrt(post.variance)\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.set_title(\"model mean\")\n",
" ax.set_xlabel(\"unix time\")\n",
" ax.set_ylabel(\"x\")\n",
" c = ax.pcolor(tt, xx, mean.reshape(n, n))\n",
" ax.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\", label=\"samples\")\n",
"\n",
" ax.plot(t, k(t - start_time), \"C3--\", label=\"ideal path\", zorder=10)\n",
" ax.legend()\n",
" fig.colorbar(c)\n",
"\n",
" fig2, ax2 = plt.subplots()\n",
" ax2.set_title(\"model uncertainty\")\n",
" ax2.set_xlabel(\"unix time\")\n",
" ax2.set_ylabel(\"x\")\n",
" c = ax2.pcolor(tt, xx, std.reshape(n, n))\n",
" fig2.colorbar(c)\n",
"\n",
" fig3, ax3 = plt.subplots()\n",
" ax3.set_title(\"ground truth value\")\n",
" ax3.set_xlabel(\"unix time\")\n",
" ax3.set_ylabel(\"x\")\n",
" c = ax3.pcolor(tt, xx, gt_vals.reshape(n, n))\n",
" fig3.colorbar(c)\n",
"\n",
" ax2.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")\n",
" ax3.plot(data[\"time\"].to_numpy(), data[\"x\"].to_numpy(), \"oC1\")"
]
}
],
"metadata": {
46 changes: 23 additions & 23 deletions xopt/generators/bayesian/base_model.py
@@ -37,10 +37,10 @@ class ModelConstructor(XoptBaseModel, ABC):
Convenience wrapper around `build_model` for use with VOCs (Variables, Objectives,
Constraints, Statics).

build_single_task_gp(train_X, train_Y, **kwargs)
build_single_task_gp(X, Y, train=True, **kwargs)
Utility method for creating and training simple SingleTaskGP models.

build_heteroskedastic_gp(train_X, train_Y, train_Yvar, **kwargs)
build_heteroskedastic_gp(X, Y, Yvar, train=True, **kwargs)
Utility method for creating and training heteroskedastic SingleTaskGP models.

"""
@@ -122,16 +122,18 @@ def build_model_from_vocs(
)

@staticmethod
def build_single_task_gp(train_X: Tensor, train_Y: Tensor, **kwargs) -> Model:
def build_single_task_gp(X: Tensor, Y: Tensor, train=True, **kwargs) -> Model:
"""
Utility method for creating and training simple SingleTaskGP models.

Parameters
----------
train_X : Tensor
X : Tensor
Training data for input variables.
train_Y : Tensor
Y : Tensor
Training data for outcome variables.
train : bool, optional
Flag to specify whether hyperparameter training should take place (default: True).
**kwargs
Additional keyword arguments for model configuration.

@@ -141,29 +143,32 @@ def build_single_task_gp(train_X: Tensor, train_Y: Tensor, **kwargs) -> Model:
The trained SingleTaskGP model.

"""
if train_X.shape[0] == 0 or train_Y.shape[0] == 0:
if X.shape[0] == 0 or Y.shape[0] == 0:
raise ValueError("no data found to train model!")
model = SingleTaskGP(train_X, train_Y, **kwargs)
model = SingleTaskGP(X, Y, **kwargs)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
if train:
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
return model

@staticmethod
def build_heteroskedastic_gp(
train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor, **kwargs
X: Tensor, Y: Tensor, Yvar: Tensor, train: bool = True, **kwargs
) -> Model:
"""
Utility method for creating and training heteroskedastic SingleTaskGP models.

Parameters
----------
train_X : Tensor
X : Tensor
Training data for input variables.
train_Y : Tensor
Y : Tensor
Training data for outcome variables.
train_Yvar : Tensor
Yvar : Tensor
Training data for outcome variable variances.
train : bool, optional
Flag to specify whether hyperparameter training should take place (default: True).
**kwargs
Additional keyword arguments for model configuration.

@@ -182,15 +187,10 @@ def build_heteroskedastic_gp(

warnings.filterwarnings("ignore")

if train_X.shape[0] == 0 or train_Y.shape[0] == 0:
if X.shape[0] == 0 or Y.shape[0] == 0:
raise ValueError("no data found to train model!")
model = XoptHeteroskedasticSingleTaskGP(
train_X,
train_Y,
train_Yvar,
**kwargs,
)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
model = XoptHeteroskedasticSingleTaskGP(X, Y, Yvar, **kwargs)
if train:
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
return model
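
A minimal sketch of the new `train` flag (illustration only; the random tensors below are not part of the PR): building with `train=False` constructs the model but skips `fit_gpytorch_mll`, so hyperparameters can be copied in from a previously trained model.

import torch
from xopt.generators.bayesian.base_model import ModelConstructor

# made-up training data for illustration
X = torch.rand(10, 2, dtype=torch.double)
Y = torch.sin(X.sum(dim=-1, keepdim=True))

# default behavior: build the GP and fit its hyperparameters
trained = ModelConstructor.build_single_task_gp(X, Y)

# new path: build without fitting, then reuse hyperparameters from the trained model
cached = ModelConstructor.build_single_task_gp(X, Y, train=False)
cached.load_state_dict(trained.state_dict())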
31 changes: 30 additions & 1 deletion xopt/generators/bayesian/models/standard.py
@@ -83,6 +83,12 @@ class StandardModelConstructor(ModelConstructor):
description="specify custom noise prior for the GP likelihood, "
"overwrites value specified by use_low_noise_prior",
)
use_cached_hyperparameters: Optional[bool] = Field(
False,
description="flag to specify if cached hyperparameters should be used in "
"model creation instead of training",
)
_hyperparameter_store: Optional[Dict] = None

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

@@ -177,6 +183,14 @@ def build_model(
tkwargs = {"dtype": dtype, "device": device}
models = []

# validate if model caching can be used if requested
if self.use_cached_hyperparameters:
if self._hyperparameter_store is None:
raise RuntimeWarning(
"cannot use cached hyperparameters, hyperparameter store empty, "
"training GP model hyperparameters instead"
)

covar_modules = deepcopy(self.covar_modules)
mean_modules = deepcopy(self.mean_modules)
for outcome_name in outcome_names:
Expand Down Expand Up @@ -208,6 +222,7 @@ def build_model(
train_X.to(**tkwargs),
train_Y.to(**tkwargs),
likelihood=self.likelihood,
train=not self.use_cached_hyperparameters,
**kwargs,
)
)
@@ -219,6 +234,7 @@
train_X.to(**tkwargs),
train_Y.to(**tkwargs),
train_Yvar.to(**tkwargs),
train=not self.use_cached_hyperparameters,
**kwargs,
)
)
@@ -234,7 +250,20 @@
f"could not be added to the model."
)

return ModelListGP(*models)
full_model = ModelListGP(*models)

# if specified, use cached model hyperparameters
if self.use_cached_hyperparameters:
store = {
name: ele.to(**tkwargs)
for name, ele in self._hyperparameter_store.items()
}
full_model.load_state_dict(store)

# cache model hyperparameters
self._hyperparameter_store = full_model.state_dict()

return full_model

def build_mean_module(
self, name, mean_modules, input_transform, outcome_transform
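
Combined with the notebook change above, the intended caching flow looks roughly like this (a sketch, assuming a populated `data` frame and `vocs` object as used elsewhere in Xopt):

from xopt.generators.bayesian.models.standard import StandardModelConstructor

constructor = StandardModelConstructor()

# first build trains the GP and fills _hyperparameter_store with its state dict
model = constructor.build_model_from_vocs(data=data, vocs=vocs)

# later builds load the cached state dict and skip hyperparameter training
constructor.use_cached_hyperparameters = True
model = constructor.build_model_from_vocs(data=data, vocs=vocs)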
1 change: 1 addition & 0 deletions xopt/generators/bayesian/models/time_dependent.py
@@ -22,6 +22,7 @@ def build_model(
dtype: torch.dtype = torch.double,
device: Union[torch.device, str] = "cpu",
) -> ModelListGP:
# get model input names
new_input_names = deepcopy(input_names)
new_input_names += ["time"]

12 changes: 9 additions & 3 deletions xopt/tests/generators/bayesian/test_mobo.py
@@ -12,6 +12,7 @@
from xopt.base import Xopt
from xopt.evaluator import Evaluator
from xopt.generators.bayesian.mobo import MOBOGenerator
from xopt.numerical_optimizer import GridOptimizer
from xopt.resources.test_functions.tnk import (
evaluate_TNK,
tnk_reference_point,
@@ -36,9 +37,13 @@ def test_script(self):
evaluator = Evaluator(function=evaluate_TNK)
reference_point = tnk_reference_point

gen = MOBOGenerator(vocs=tnk_vocs, reference_point=reference_point)
gen = MOBOGenerator(
vocs=tnk_vocs,
reference_point=reference_point,
numerical_optimizer=GridOptimizer(n_grid_points=2),
)
gen = deepcopy(gen)
gen.n_monte_carlo_samples = 20
gen.n_monte_carlo_samples = 1

for ele in [gen]:
dump = ele.model_dump()
@@ -66,6 +71,7 @@ def test_parallel(self):
reference_point=reference_point,
use_pf_as_initial_points=True,
)
gen.n_monte_carlo_samples = 1
gen.add_data(test_data)

gen.generate(2)
@@ -168,7 +174,7 @@ def test_initial_conditions(self):
reference_point=reference_point,
use_pf_as_initial_points=True,
)
gen.numerical_optimizer.max_time = 1.0
gen.n_monte_carlo_samples = 1
gen.add_data(test_data)
initial_points = gen._get_initial_conditions()
