Implementation of stride grid search by mpecha · Pull Request #9 · permon/permonsvm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Implementation of stride grid search #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions examples/output/exbinfile_1.out
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,17 @@
number of CG steps 0
number of expansion steps 46
number of proportioning steps 1
=====================
type: binary
model parameters:
||w||=2.3081 bias=0.0000 margin=0.8665 NSV=45
L1 hinge loss:
sum(xi_i)=60.5118
objective functions:
primalObj=63.1753 dualObj=61.7252 gap=1.4501
=====================
=====================
type: binary
model performance score with training parameters C=1.000, mod=2, loss=L1:
Confusion matrix:
TP = 27 FP = 9
FN = 7 TN = 47
accuracy=82.22% precision=75.00% sensitivity=79.41%
F1=0.77 MCC=0.63 AUC_ROC=0.82 G1=0.63
=====================
F1=0.77 MCC=0.63 AUC_ROC=0.82 G1=0.63
6 changes: 3 additions & 3 deletions include/permon/private/svmimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ struct _p_SVM {
PetscReal C,C_old;
PetscReal Cp,Cp_old;
PetscReal Cn,Cn_old;
PetscReal LogCMin,LogCMax,LogCBase;
PetscReal LogCpMin,LogCpMax,LogCpBase;
PetscReal LogCnMin,LogCnMax,LogCnBase;
PetscReal logC_base,logC_start,logC_end,logC_step;
PetscReal logCp_base,logCp_start,logCp_end,logCp_step;
PetscReal logCn_base,logCn_start,logCn_end,logCn_step;
PetscInt nfolds;

SVMLossType loss_type;
Expand Down
32 changes: 14 additions & 18 deletions include/permonsvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,24 +201,20 @@ FLLOP_EXTERN PetscErrorCode SVMGetHyperOptNScoreTypes(SVM,PetscInt *);
FLLOP_EXTERN PetscErrorCode SVMGetHyperOptScoreTypes(SVM svm,const ModelScore *types[]);

/* Hyper-parameter optimization: exhaustive search over penalty candidates */
FLLOP_EXTERN PetscErrorCode SVMGridSearch(SVM svm);

/*
  Min/max/base description of the log-scale penalty range.
  NOTE(review): this diff replaces the matching struct fields with logC_{base,start,end,step};
  confirm these accessors still have implementations before using them.
*/
FLLOP_EXTERN PetscErrorCode SVMSetLogCMin(SVM svm,PetscReal logC_min);
FLLOP_EXTERN PetscErrorCode SVMGetLogCMin(SVM svm,PetscReal *logC_min);
FLLOP_EXTERN PetscErrorCode SVMSetLogCMax(SVM svm,PetscReal logC_max);
FLLOP_EXTERN PetscErrorCode SVMGetLogCMax(SVM svm,PetscReal *logC_max);
FLLOP_EXTERN PetscErrorCode SVMSetLogCBase(SVM svm,PetscReal logC_base);
FLLOP_EXTERN PetscErrorCode SVMGetLogCBase(SVM svm,PetscReal *logC_base);
FLLOP_EXTERN PetscErrorCode SVMSetLogCpMin(SVM svm,PetscReal logCp_min);
FLLOP_EXTERN PetscErrorCode SVMGetLogCpMin(SVM svm,PetscReal *logCp_min);
FLLOP_EXTERN PetscErrorCode SVMSetLogCpMax(SVM svm,PetscReal logCp_max);
FLLOP_EXTERN PetscErrorCode SVMGetLogCpMax(SVM svm,PetscReal *logCp_max);
FLLOP_EXTERN PetscErrorCode SVMSetLogCpBase(SVM svm,PetscReal logCp_base);
FLLOP_EXTERN PetscErrorCode SVMGetLogCpBase(SVM svm,PetscReal *logCp_base);
FLLOP_EXTERN PetscErrorCode SVMSetLogCnMin(SVM svm,PetscReal logCn_min);
FLLOP_EXTERN PetscErrorCode SVMGetLogCnMin(SVM svm,PetscReal *logCn_min);
FLLOP_EXTERN PetscErrorCode SVMSetLogCnMax(SVM svm,PetscReal logCn_max);
FLLOP_EXTERN PetscErrorCode SVMGetLogCnMax(SVM svm,PetscReal *logCn_max);
FLLOP_EXTERN PetscErrorCode SVMSetLogCnBase(SVM svm,PetscReal logCn_base);
FLLOP_EXTERN PetscErrorCode SVMGetLogCnBase(SVM svm,PetscReal *logCn_base);

/*
  Stride description of the grid: the binary grid search evaluates base^e
  for exponents e = start, start + step, ..., end.
*/
/* Penalty type 1: one C shared by both classes */
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetBaseLogC(SVM svm,PetscReal base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetBaseLogC(SVM svm,PetscReal *base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetStrideLogC(SVM svm,PetscReal start,PetscReal end,PetscReal step);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetStrideLogC(SVM svm,PetscReal *start,PetscReal *end,PetscReal *step);
/* Penalty type 2: separate C+ (positive class) and C- (negative class) */
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetPositiveBaseLogC(SVM svm,PetscReal base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetPositiveBaseLogC(SVM svm,PetscReal *base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetPositiveStrideLogC(SVM svm,PetscReal start,PetscReal end,PetscReal step);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetPositiveStrideLogC(SVM svm,PetscReal *start,PetscReal *end,PetscReal *step);
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetNegativeBaseLogC(SVM svm,PetscReal base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetNegativeBaseLogC(SVM svm,PetscReal *base);
FLLOP_EXTERN PetscErrorCode SVMGridSearchSetNegativeStrideLogC(SVM svm,PetscReal start,PetscReal end,PetscReal step);
FLLOP_EXTERN PetscErrorCode SVMGridSearchGetNegativeStrideLogC(SVM svm,PetscReal *start,PetscReal *end,PetscReal *step);

/* Cross-validation settings */
FLLOP_EXTERN PetscErrorCode SVMSetCrossValidationType(SVM svm,CrossValidationType type);
FLLOP_EXTERN PetscErrorCode SVMGetCrossValidationType(SVM svm,CrossValidationType *type);
Expand Down
165 changes: 56 additions & 109 deletions src/svm/impls/binary/binary.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ PetscErrorCode SVMView_Binary(SVM svm,PetscViewer v)
TRY( PetscObjectTypeCompare((PetscObject)v,PETSCVIEWERASCII,&isascii) );

if (isascii) {
TRY( PetscViewerASCIIPrintf(v,"=====================\n") );
TRY( PetscObjectPrintClassNamePrefixType((PetscObject) svm,v) );

TRY( PetscViewerASCIIPushTab(v) );
Expand Down Expand Up @@ -150,8 +149,6 @@ PetscErrorCode SVMView_Binary(SVM svm,PetscViewer v)
TRY( PetscViewerASCIIPopTab(v) );

TRY( PetscViewerASCIIPopTab(v) );

TRY( PetscViewerASCIIPrintf(v,"=====================\n") );
} else {
FLLOP_SETERRQ1(comm,PETSC_ERR_SUP,"Viewer type %s not supported for SVMViewScore", ((PetscObject)v)->type_name);
}
Expand Down Expand Up @@ -182,7 +179,6 @@ PetscErrorCode SVMViewScore_Binary(SVM svm,PetscViewer v)
TRY( SVMGetMod(svm,&mod) );
TRY( SVMGetLossType(svm,&loss_type) );

TRY( PetscViewerASCIIPrintf(v,"=====================\n") );
TRY( PetscObjectPrintClassNamePrefixType((PetscObject) svm,v) );

TRY( PetscViewerASCIIPushTab(v) );
Expand Down Expand Up @@ -214,8 +210,6 @@ PetscErrorCode SVMViewScore_Binary(SVM svm,PetscViewer v)
TRY( PetscViewerASCIIPopTab(v) );

TRY( PetscViewerASCIIPopTab(v) );

TRY( PetscViewerASCIIPrintf(v,"=====================\n") );
} else {
FLLOP_SETERRQ1(comm,PETSC_ERR_SUP,"Viewer type %s not supported for SVMViewScore", ((PetscObject)v)->type_name);
}
Expand Down Expand Up @@ -1411,141 +1405,94 @@ PetscErrorCode SVMGetModelScore_Binary(SVM svm,ModelScore score_type,PetscReal *

#undef __FUNCT__
#define __FUNCT__ "SVMGridSearchStrideLength_Binary_Private"
/*
  SVMGridSearchStrideLength_Binary_Private - count the exponents start, start+step, ..., end (inclusive).

  Input:
    svm   - SVM object (used only for the error communicator)
    start - first exponent
    end   - last exponent

  Input/Output:
    step  - on input the stride (either sign); on output |step| signed so that
            start + i*step moves from start towards end

  Output:
    n     - number of exponents, always >= 1

  Errors if step is zero, which would otherwise divide by zero and never reach end.
*/
static PetscErrorCode SVMGridSearchStrideLength_Binary_Private(SVM svm,PetscReal start,PetscReal end,PetscReal *step,PetscInt *n)
{
  MPI_Comm  comm;
  PetscReal stride,width;

  PetscFunctionBegin;
  stride = PetscAbs(*step);
  if (stride == 0.) {
    TRY( PetscObjectGetComm((PetscObject) svm,&comm) );
    FLLOP_SETERRQ1(comm,PETSC_ERR_ARG_OUTRANGE,"Grid search stride step %g is invalid, it must be nonzero",(double) *step);
  }
  width = PetscAbs(end - start);
  /* floor(width/stride) + 1 points; PETSC_SMALL absorbs round-off (e.g. 1.0/0.1 = 9.999...) */
  *n = (PetscInt) (width / stride + PETSC_SMALL) + 1;
  /* Always walk from start towards end, whatever sign the caller gave to step */
  *step = (end >= start) ? stride : -stride;
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SVMInitGridSearch_Binary_Private"
/*
  SVMInitGridSearch_Binary_Private - build the penalty candidates evaluated by the binary grid search.

  Output:
    n_out    - number of PetscReal entries in grid_out (not the number of candidates)
    grid_out - newly allocated array; the caller releases it with PetscFree()

  Penalty type 1: grid_out[i] = base^(start + i*step), one entry per candidate.
  Other penalty types (C+ and C-): candidates are stored as pairs
    grid_out[2k] = C+, grid_out[2k+1] = C-,
  with C+ in the outer loop and C- in the inner loop, so n_out = 2 * n_pos * n_neg.
*/
PetscErrorCode SVMInitGridSearch_Binary_Private(SVM svm,PetscInt *n_out,PetscReal *grid_out[])
{
  PetscInt  penalty_type;

  PetscReal base_1,start_1,end_1,step_1;
  PetscReal base_2,start_2,end_2,step_2;
  PetscReal Cp;

  PetscReal *grid;
  PetscInt  n,n_1,n_2;
  PetscInt  i,j,p;

  PetscFunctionBegin;
  TRY( SVMGetPenaltyType(svm,&penalty_type) );

  if (penalty_type == 1) {
    /* Penalty type 1: one C shared by both classes */
    TRY( SVMGridSearchGetBaseLogC(svm,&base_1) );
    TRY( SVMGridSearchGetStrideLogC(svm,&start_1,&end_1,&step_1) );
    TRY( SVMGridSearchStrideLength_Binary_Private(svm,start_1,end_1,&step_1,&n) );

    TRY( PetscMalloc1(n,&grid) );
    for (i = 0; i < n; ++i) grid[i] = PetscPowReal(base_1,start_1 + i * step_1);
  } else {
    /* Penalty type 2: different penalty for each class */
    TRY( SVMGridSearchGetPositiveBaseLogC(svm,&base_1) );
    TRY( SVMGridSearchGetPositiveStrideLogC(svm,&start_1,&end_1,&step_1) );
    TRY( SVMGridSearchGetNegativeBaseLogC(svm,&base_2) );
    TRY( SVMGridSearchGetNegativeStrideLogC(svm,&start_2,&end_2,&step_2) );

    TRY( SVMGridSearchStrideLength_Binary_Private(svm,start_1,end_1,&step_1,&n_1) );
    TRY( SVMGridSearchStrideLength_Binary_Private(svm,start_2,end_2,&step_2,&n_2) );

    n = 2 * n_1 * n_2;
    TRY( PetscMalloc1(n,&grid) );

    /* Cartesian product of C+ values (outer) and C- values (inner), stored as (C+, C-) pairs */
    p = 0;
    for (i = 0; i < n_1; ++i) {
      Cp = PetscPowReal(base_1,start_1 + i * step_1);
      for (j = 0; j < n_2; ++j) {
        grid[p++] = Cp;
        grid[p++] = PetscPowReal(base_2,start_2 + j * step_2);
      }
    }
  }

  *n_out = n;
  *grid_out = grid;
  PetscFunctionReturn(0);
}

#undef __FUNCT__
#define __FUNCT__ "SVMGridSearch_Binary"
/*
  SVMGridSearch_Binary - pick the penalty with the best cross-validation score and set it on svm.

  The candidates come from SVMInitGridSearch_Binary_Private(). For penalty type m, each
  candidate occupies m consecutive entries of the grid: C for m == 1, (C+, C-) otherwise.
  SVMCrossValidation() fills one score per candidate. The highest score wins, and ties
  keep the earliest candidate. The winner is passed to SVMSetPenalty().
*/
PetscErrorCode SVMGridSearch_Binary(SVM svm)
{
  MPI_Comm  comm;

  PetscReal *grid;
  PetscReal *scores,score_best;
  PetscInt  i,n,m,p,s;

  PetscFunctionBegin;
  TRY( SVMGetPenaltyType(svm,&m) );

  /* Initialize grid */
  TRY( SVMInitGridSearch_Binary_Private(svm,&n,&grid) );
  s = n / m; /* number of candidates */
  if (s < 1) {
    TRY( PetscObjectGetComm((PetscObject) svm,&comm) );
    FLLOP_SETERRQ1(comm,PETSC_ERR_ARG_OUTRANGE,"Grid search produced no penalty candidates (grid length %D)",n);
  }

  /* Perform cross-validation */
  TRY( PetscCalloc1(s,&scores) );
  TRY( SVMCrossValidation(svm,grid,n,scores) );

  /*
    Find best score. Start from the first candidate instead of a sentinel such as -1:
    scores like MCC can equal -1, and NaN never compares greater. With a sentinel,
    p could be left unset and then used to index grid.
  */
  p = 0;
  score_best = scores[0];
  for (i = 1; i < s; ++i) {
    if (scores[i] > score_best) {
      p = i;
      score_best = scores[i];
    }
  }

  /*
    PetscReal may be float, double or __float128 depending on the PETSc build. It must be
    cast to double and printed with %g; passing it unconverted through varargs with %f is
    undefined behaviour for non-double builds (PETSc developer guide, code style).
  */
  if (m == 1) {
    TRY( PetscInfo2(svm,"selected best C=%g (score=%g)\n",(double) grid[p],(double) score_best) );
  } else {
    TRY( PetscInfo3(svm,"selected best C+=%g, C-=%g (score=%g)\n",(double) grid[p * m],(double) grid[p * m + 1],(double) score_best) );
  }

  TRY( SVMSetPenalty(svm,m,&grid[p * m]) );

  TRY( PetscFree(grid) );
  TRY( PetscFree(scores) );
  PetscFunctionReturn(0);
}

Expand Down
Loading
0