cross validation and grid search #33
I'll have a look at fixing this. Basically you'd just need to add these to the model .pyx files:
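The snippet referred to above was not preserved in this thread. As a hedged sketch, making a model class usable with sklearn's cross-validation tools essentially means adding get_params/set_params methods. The class and parameter names below are illustrative stand-ins (FTRL-style alpha, beta, L1, L2), not the actual wordbatch code; in the real .pyx files these would go on the cdef class:

```python
class FTRLLike:
    """Illustrative stand-in for a wordbatch model class; the real
    methods would be added to the model's .pyx file."""

    def __init__(self, alpha=0.1, beta=1.0, L1=1.0, L2=1.0):
        self.alpha, self.beta, self.L1, self.L2 = alpha, beta, L1, L2

    def get_params(self, deep=True):
        # Return the constructor arguments, as sklearn's clone() expects
        return {"alpha": self.alpha, "beta": self.beta,
                "L1": self.L1, "L2": self.L2}

    def set_params(self, **params):
        # Set each parameter back onto the estimator and return self,
        # matching the sklearn convention
        for key, value in params.items():
            setattr(self, key, value)
        return self

model = FTRLLike()
print(model.set_params(alpha=0.05).get_params())
# → {'alpha': 0.05, 'beta': 1.0, 'L1': 1.0, 'L2': 1.0}
```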
This could work. I fixed most of the incompatibilities with sklearn earlier, but it seems I missed this. For model hyperparameter estimation, I've used a Gaussian random search algorithm described in Section 6.1.4 of my thesis: https://arxiv.org/pdf/1602.02332.pdf . You can use any of the Python hyperparameter packages, such as cmaes, hyperopt, skopt, Ray Tune, etc.; there are a lot of them by now. I'll try to add hyperparameter optimization to the package when I get the chance, since it's useful with the supported models and makes good use of the supported distributed computing backends.
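This is not the thesis implementation, just a minimal sketch of the general idea behind a Gaussian random search: sample candidate points from a normal distribution centred on the best point found so far, and keep any improvement. All names and the toy objective are illustrative:

```python
import numpy as np

def gaussian_random_search(objective, x0, sigma=0.5, iters=200, seed=0):
    """Minimal Gaussian random search: perturb the incumbent best point
    with Gaussian noise and accept candidates that improve the objective."""
    rng = np.random.default_rng(seed)
    best_x = np.asarray(x0, dtype=float)
    best_f = objective(best_x)
    for _ in range(iters):
        cand = best_x + rng.normal(scale=sigma, size=best_x.shape)
        f = objective(cand)
        if f < best_f:  # greedy acceptance of improvements
            best_x, best_f = cand, f
    return best_x, best_f

# Toy objective with its minimum at (1, 2); in practice this would be
# a cross-validation score as a function of model hyperparameters
obj = lambda x: (x[0] - 1.0) ** 2 + (x[1] - 2.0) ** 2
x, f = gaussian_random_search(obj, [0.0, 0.0])
```

In practice you would also decay sigma over iterations so the search narrows around the optimum; this sketch keeps it fixed for brevity.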
This is a bit more complicated to fix. It seems sklearn would need a fix as well. For ftrl.pyx, get_params can be defined:
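The original snippet isn't preserved here. As a sketch of the change being described: the model state arrays w, z, n become optional constructor arguments and are returned by get_params, so sklearn's clone() can re-create the estimator. Names follow the thread; other hyperparameters are omitted for brevity, and in the real Cython code the arguments would be typed (np.ndarray w=None, etc.):

```python
class FTRLSketch:
    """Sketch of the ftrl.pyx changes described above; not the real class."""

    def __init__(self, alpha=0.1, w=None, z=None, n=None):
        # Model state arrays are accepted as optional constructor
        # arguments so a clone can be built from get_params() output
        self.alpha = alpha
        self.w, self.z, self.n = w, z, n

    def get_params(self, deep=True):
        return {"alpha": self.alpha, "w": self.w, "z": self.z, "n": self.n}

clone_params = FTRLSketch(alpha=0.2).get_params()
reconstructed = FTRLSketch(**clone_params)  # roughly what clone() does internally
```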
Also, the estimator __init__ function needs the model parameters w, z, n added as optional arguments (np.ndarray w=None, np.ndarray z=None, np.ndarray n=None). After this you'll still get the following validation error from sklearn:
You can debug sklearn base.py clone() to print the variables:
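The debug snippet itself isn't preserved in the thread. As a sketch, the idea is to print both comparisons for each parameter inside clone()'s validation loop; the two dicts below are hypothetical stand-ins for clone()'s locals (the params passed to the constructor, and the params read back from the new estimator):

```python
import numpy as np

# Hypothetical stand-ins for the two parameter dicts clone() compares
new_object_params = {"alpha": 0.1, "w": np.zeros(3)}
params_set = {"alpha": 0.1, "w": np.zeros(3)}  # re-created, so a new array object

# Debug print of the kind described: identity vs value comparison
for name in new_object_params:
    param1 = new_object_params[name]
    param2 = params_set[name]
    print(name, param1 is param2, np.all(param1 == param2))
```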
Which prints out this:
So the first "param1 is param2" comparison fails, whereas the "param1 == param2" comparison works. This is due to the difference in how Python's "is" and "==" comparisons work: "is" tests object identity, while "==" tests value equality. Here any float or numpy array will fail the identity comparison, so the validation raises an error. I'm not sure why they use "is" instead of "==" in the clone() validation, since this validation should be comparing the values of different objects, not their references. I'll set up a ticket for the sklearn developers about fixing the above. If that gets fixed, there aren't many changes left to resolve this issue.
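The failure mode can be reproduced outside sklearn. A minimal demonstration of why an identity comparison rejects a value-equal numpy array:

```python
import numpy as np

w = np.zeros(4)
w_same = w            # same object
w_copy = np.array(w)  # equal values, but a new object (e.g. a cast inside __init__)

print(w is w_same)                # True: identical references
print(w is w_copy)                # False: identity check fails
print(bool((w == w_copy).all()))  # True: values are equal
```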
Thank you for looking into this!
I would like to use FM_FTRL in an sklearn cross-validation pipeline. This throws an error, and the same error is thrown when trying to pass an FM_FTRL model to GridSearchCV. Can you provide some guidance on how to make this work?
I can see in this thread that you tuned hyperparameters with random search. Can you provide guidance on that?
Thank you!