cross validation and grid search · Issue #33 · anttttti/Wordbatch · GitHub

cross validation and grid search #33

Open
marketneutral opened this issue Oct 24, 2019 · 3 comments

Comments

@marketneutral

I would like to use FM_FTRL in an sklearn cross-validation pipeline, e.g.,

from sklearn.model_selection import cross_val_score
from wordbatch.models import FM_FTRL

modelF = FM_FTRL(
    alpha=0.01,    # learning rate
    beta=0.1,
    L1=0.00001,
    L2=0.10,
    D=X_train.shape[1],
    alpha_fm=0.01,
    L2_fm=0.0,
    init_fm=0.01,
    D_fm=50,
    e_noise=0.0001,
    iters=5,
    inv_link='sigmoid',
    threads=4
)

cv_scores = cross_val_score(modelF, X_train.tocsc(), y_train_fm.target.values,
                            scoring='roc_auc', cv=time_split)

This throws:

TypeError: Cannot clone object '<wordbatch.models.fm_ftrl.FM_FTRL object at 0x557056cfbfa0>' (type <class 'wordbatch.models.fm_ftrl.FM_FTRL'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' method.

The same error is thrown when passing an FM_FTRL model to GridSearchCV.

Can you provide some guidance on how to make this work?

I can see in this thread that you tuned hyperparameters with random search. Can you provide guidance on that?

Thank you!

@anttttti
Copy link
Owner

I'll have a look at fixing this. Basically you'd just need to add these to the model .pyx files:

def get_params(self, deep=False):  return self.__getstate__()
def set_params(self, **params):  self.__setstate__(params)

This could work. I fixed most of the incompatibilities with sklearn earlier, but it seems I missed this.
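
In the meantime, a thin user-side adapter that exposes the sklearn estimator API and delegates to FM_FTRL can unblock cross_val_score. A minimal sketch, with a hypothetical FMFTRLEstimator class that surfaces only a few of the parameters:

from sklearn.base import BaseEstimator, ClassifierMixin
from wordbatch.models import FM_FTRL

class FMFTRLEstimator(BaseEstimator, ClassifierMixin):
    """Hypothetical sklearn adapter; not part of Wordbatch."""
    def __init__(self, alpha=0.01, beta=0.1, L1=0.00001, L2=0.1, D_fm=50, iters=5):
        # BaseEstimator derives get_params/set_params from these attributes
        self.alpha = alpha
        self.beta = beta
        self.L1 = L1
        self.L2 = L2
        self.D_fm = D_fm
        self.iters = iters

    def fit(self, X, y):
        self.model_ = FM_FTRL(alpha=self.alpha, beta=self.beta, L1=self.L1,
                              L2=self.L2, D=X.shape[1], D_fm=self.D_fm,
                              iters=self.iters, inv_link='sigmoid')
        self.model_.fit(X, y)
        return self

    def decision_function(self, X):
        # FM_FTRL.predict with inv_link='sigmoid' returns scores roc_auc can rank
        return self.model_.predict(X)

With that, the original call should go through as cross_val_score(FMFTRLEstimator(), X_train.tocsc(), y_train_fm.target.values, scoring='roc_auc', cv=time_split).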

On model hyperparameter estimation, I've used a Gaussian random search algorithm described in Section 6.1.4 of my thesis: https://arxiv.org/pdf/1602.02332.pdf . You can use any of the Python hyperparameter packages, such as cmaes, hyperopt, skopt, Ray Tune, etc.; there are a lot of them by now. I'll try to add hyperparameter optimization to the package when I get the chance, since it's useful with the supported models and makes good use of the supported distributed computing backends.
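
For a concrete picture, here is a minimal log-space Gaussian random search loop. This is an illustrative sketch only, not the thesis implementation; evaluate is a placeholder for fitting FM_FTRL and returning a validation AUC:

import numpy as np

rng = np.random.default_rng(0)

def evaluate(params):
    # Placeholder objective: in practice, fit FM_FTRL with these
    # hyperparameters and return the validation AUC.
    return -abs(np.log(params["alpha"]) - np.log(0.05))

best = {"alpha": 0.01, "beta": 0.1, "L2": 0.1}   # assumed starting point
best_score = evaluate(best)

for _ in range(50):
    # Perturb the incumbent in log-space with Gaussian noise; keep improvements.
    trial = {k: float(np.exp(np.log(v) + rng.normal(scale=0.5)))
             for k, v in best.items()}
    score = evaluate(trial)
    if score > best_score:
        best, best_score = trial, score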

@anttttti
Owner

This is a bit more complicated to fix. It seems sklearn would need a fix as well.

For ftrl.pyx, get_params can be defined:

def get_params(self, deep=False):
    param_names = ["alpha", "beta", "L1", "L2", "e_clip", "D", "init", "seed", "iters",
                   "w", "z", "n", "inv_link", "threads", "bias_term", "verbose"]
    # __getstate__ returns the parameter values as a tuple in this order
    params = {name: value for name, value in zip(param_names, self.__getstate__())}
    # inv_link is stored internally as an int flag; map it back to its string form
    params["inv_link"] = "sigmoid" if params["inv_link"] == 1 else "identity"
    return params

Also, the estimator __init__ function needs the model parameters w, z, n added as optional arguments (np.ndarray w=None, np.ndarray z=None, np.ndarray n=None).
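
The constructor change is needed because sklearn's clone() rebuilds the estimator from its own get_params() output, so every returned key (including w, z, n) must be an accepted __init__ argument. Roughly, using SGDClassifier purely for illustration:

from sklearn.linear_model import SGDClassifier  # any sklearn estimator, for illustration

est = SGDClassifier(alpha=0.1)
params = est.get_params(deep=False)
rebuilt = type(est)(**params)   # approximately what sklearn.base.clone() does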

After this you'll still get the following validation error from sklearn:

RuntimeError: Cannot clone object <wordbatch.models.ftrl.FTRL object at 0x56245d510ee0>, as the constructor either does not set or modifies parameter alpha

You can debug sklearn's base.py clone() by printing the variables:

print(name, param1, param2, type(param1), type(param2), param1 is param2, param1==param2)

This prints:

alpha 0.1 0.1 <class 'float'> <class 'float'> False True

So the "param1 is param2" identity comparison fails, whereas the "param1 == param2" value comparison succeeds. This is due to the difference in how Python's "is" and "==" comparisons work: "is" checks object identity, while "==" checks value equality. Here any float or numpy array will fail the identity check, so the validation raises an error. I'm not sure why they use "is" instead of "==" in the clone() validation, since this validation should be comparing the values of different objects, not their references.
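
A quick illustration of the distinction (in FTRL's case the parameters live as C-level values, so each get_params call presumably boxes a fresh Python float, and identity can never hold):

a = float("0.1")   # one float object
b = float("0.1")   # a second object with the same value
print(a is b, a == b)   # in CPython: False True (distinct objects, equal values)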

I'll set up a ticket for the sklearn developers to fix the above. If that gets fixed, there aren't many changes left to resolve this issue.

@marketneutral
Author

Thank you for looking into this!
