#include <BALL/QSAR/classificationValidation.h>
Public Member Functions | |
Constructors and Destructors | |
ClassificationValidation (ClassificationModel *m) | |
Private Attributes | |
Attributes | |
BALL::Matrix< double > | confusion_matrix_ |
Vector< double > | class_results_ |
double | quality_ |
double | quality_input_test_ |
double | quality_cv_ |
ClassificationModel * | clas_model |
void(ClassificationValidation::* | qualCalculation )() |
Accessors | |
| |
void | crossValidation (int k, bool restore=1) |
double | getCVRes () |
double | getFitRes () |
void | setCVRes (double d) |
void | testInputData (bool transform=0) |
const BALL::Matrix< double > * | getConfusionMatrix () |
const BALL::Vector< double > * | getClassResults () |
void | bootstrap (int k, bool restore=1) |
const BALL::Matrix< double > & | yRandomizationTest (int runs, int k) |
double | getAccuracyCV () |
double | getAccuracyInputTest () |
void | selectStat (int s) |
void | saveToFile (string filename) const |
void | saveToFile (string filename, const double &quality_input_test, const double &predictive_quality) const |
void | readFromFile (string filename) |
void | testAllSubstances (bool transform) |
void | calculateAverageSensitivity () |
void | calculateWeightedSensitivity () |
void | calculateOverallAccuracy () |
void | calculateAverageMCC () |
void | calculateOverallMCC () |
void | calculateTDR () |
class for validation of QSAR classification models
Definition at line 48 of file classificationValidation.h.
BALL::QSAR::ClassificationValidation::ClassificationValidation | ( | ClassificationModel * | m | ) |
constructor
m | pointer to the classification model, which the object of this class should test |
void BALL::QSAR::ClassificationValidation::bootstrap | ( | int | k, | |
bool | restore = 1 | |||
) | [virtual] |
starts bootstrapping with k samples
k | number of bootstrap samples |
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::calculateAverageMCC | ( | ) | [private] |
calculate one MCC for each class and use the average
void BALL::QSAR::ClassificationValidation::calculateAverageSensitivity | ( | ) | [private] |
calculate average accuracy with the current values of TP, FP, FN, TN in matrix ClassificationValidation.predictions.
void BALL::QSAR::ClassificationValidation::calculateOverallAccuracy | ( | ) | [private] |
calculate accuracy for all classes at once
void BALL::QSAR::ClassificationValidation::calculateOverallMCC | ( | ) | [private] |
calculate MCC for all classes at once
void BALL::QSAR::ClassificationValidation::calculateTDR | ( | ) | [private] |
calculate the True Discovery Rate (only applicable to binary classification validation results).
void BALL::QSAR::ClassificationValidation::calculateWeightedSensitivity | ( | ) | [private] |
calculate weighted average accuracy of all classes. Weighted by the number of training compounds within each class
void BALL::QSAR::ClassificationValidation::crossValidation | ( | int | k, | |
bool | restore = 1 | |||
) | [virtual] |
Starts cross-validation with k steps.
Data is taken from QSARData.descriptor_matrix and is in each step divided into training- and test-data.
(Data having already been copied into Model.descriptor_matrix will be deleted)
Implements BALL::QSAR::Validation.
double BALL::QSAR::ClassificationValidation::getAccuracyCV | ( | ) |
get average accuracy value as determined after cross validation
double BALL::QSAR::ClassificationValidation::getAccuracyInputTest | ( | ) |
get average accuracy value as determined after testing of input data (see testInputData())
const BALL::Vector<double>* BALL::QSAR::ClassificationValidation::getClassResults | ( | ) |
returns a RowVector holding the one value constituting the validation result for each class if "average accuracy" or "average MCC" is chosen (see selectStat()).
const BALL::Matrix<double>* BALL::QSAR::ClassificationValidation::getConfusionMatrix | ( | ) |
return pointer to the matrix containing the number of TP, FP, TN, FN in one column for each class
double BALL::QSAR::ClassificationValidation::getCVRes | ( | ) | [virtual] |
fetches the result of cross-validation
Implements BALL::QSAR::Validation.
double BALL::QSAR::ClassificationValidation::getFitRes | ( | ) | [virtual] |
fetches the quality of fit to the input data, as calculated by testInputData()
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::readFromFile | ( | string | filename | ) | [virtual] |
restore validation-results from a file
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::saveToFile | ( | string | filename, | |
const double & | quality_input_test, | |||
const double & | predictive_quality | |||
) | const |
void BALL::QSAR::ClassificationValidation::saveToFile | ( | string | filename | ) | const [virtual] |
save the result of the applied validation methods to a file
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::selectStat | ( | int | s | ) | [virtual] |
select the desired statistic to be used for validating the models
s | if (s==1), R^2 and Q^2 are used; if (s==2), F_regr and F_cv are used. |
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::setCVRes | ( | double | d | ) | [virtual] |
set the result of cross-validation to the given value
Implements BALL::QSAR::Validation.
void BALL::QSAR::ClassificationValidation::testAllSubstances | ( | bool | transform | ) | [private] |
Tests the current model with all substances in the (unchanged) test data set
void BALL::QSAR::ClassificationValidation::testInputData | ( | bool | transform = 0 |
) | [virtual] |
Fetches input data from QSARData and tests the current (unchanged) model with all these new substances (without cross-validation!).
transform | if transform==1, the test data is transformed in the same way that the training data was transformed before predicting activities. If training and test substances are taken from the same input file, set transform to 0 |
Implements BALL::QSAR::Validation.
const BALL::Matrix<double>& BALL::QSAR::ClassificationValidation::yRandomizationTest | ( | int | runs, | |
int | k | |||
) | [virtual] |
Y randomization test
Randomizes all columns of model.Y, trains the model, runs crossValidation and testInputData and saves the resulting accuracy_input_test and accuracy_cv values to a matrix, where BALL::Matrix<double>(i,0)=accuracy_input_test and BALL::Matrix<double>(i,1)=accuracy_cv
runs | this is repeated as often as specified by 'runs' |
Implements BALL::QSAR::Validation.
pointer to the classification model, which the object of this class should test
Definition at line 148 of file classificationValidation.h.
Vector<double> BALL::QSAR::ClassificationValidation::class_results_ [private] |
RowVector holding the one value constituting the validation result for each class if "average sensitivity" or "average MCC" is chosen (see selectStat()).
Definition at line 139 of file classificationValidation.h.
BALL::Matrix<double> BALL::QSAR::ClassificationValidation::confusion_matrix_ [private] |
matrix containing the number of TP, FP, FN, TN in one column for each class
Definition at line 136 of file classificationValidation.h.
void(ClassificationValidation::* BALL::QSAR::ClassificationValidation::qualCalculation)() [private] |
Definition at line 150 of file classificationValidation.h.
Definition at line 141 of file classificationValidation.h.
Definition at line 145 of file classificationValidation.h.
Definition at line 143 of file classificationValidation.h.