Wolfram Research

Function Repository Resource:

CrossValidateModel

Source Notebook

Check the quality of a data fitting model by splitting the data into training and validation sets multiple times

Contributed by: Sjoerd Smit

ResourceFunction["CrossValidateModel"][data,fitfun]

splits data into training and validation sets multiple times in different ways, then applies fitfun to the training data and tests the result using the validation set.

ResourceFunction["CrossValidateModel"][data,dist]

performs cross-validation by using EstimatedDistribution to fit distribution dist to data.

ResourceFunction["CrossValidateModel"][data,<|lbl1→fitfun1,lbl2→fitfun2,…|>]

cross-validates all listed models, using the same set of training-validation splits for each model.

ResourceFunction["CrossValidateModel"][data,{dist1,dist2,…}]

performs cross-validation on multiple distributions by fitting them with EstimatedDistribution.

Details and Options

The recognized formats for data are listed in the left-hand column below. The right-hand column shows examples of symbols that take the same data format. Note that Dataset is currently not supported.
ResourceFunction["CrossValidateModel"] divides the data repeatedly into randomly chosen training and validation sets. Each time, fitfun is applied to the training set. The returned model is then validated against the validation set. The returned result is of the form {<|"FittedModel" → model1, "ValidationResult" → result1|>, …}.
For any distribution that satisfies DistributionParameterQ, the syntax for fitting distributions is the same as for EstimatedDistribution, which is the function that is being used internally to fit the data. The returned values of EstimatedDistribution are found under the "FittedModel" keys of the output. The "ValidationResult" lists the average negative LogLikelihood on the validation set, which is the default loss function for distribution fits.
If the user does not specify a validation function explicitly using the "ValidationFunction" option, ResourceFunction["CrossValidateModel"] tries to automatically apply a sensible validation loss metric for results returned by fitfun (see table below the next item). If no usable validation method is found, the validation set itself will be returned so that the user can perform their own validation afterward.
The following table lists the types of models that are recognized; the functions that produce such models; and the default loss function applied to that type of model:
An explicit validation function can be provided with the "ValidationFunction" option. This function takes the fit result as a first argument and a validation set as a second argument. If multiple models are specified as an Association in the second argument of ResourceFunction["CrossValidateModel"], different validation functions for each model can be specified by passing an Association to the "ValidationFunction" option.
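As an illustration (a hypothetical sketch, not taken from the examples below), a custom validation function for a distribution fit could compute the mean negative log-likelihood on the validation set explicitly, mirroring the default loss for distribution fits:

```
(* sketch of a custom "ValidationFunction" for distribution fits:
   the first argument is the fitted model, the second the validation data *)
meanNegLogLikelihood = Function[{fittedDist, validationData},
   -LogLikelihood[fittedDist, validationData]/Length[validationData]
   ];
ResourceFunction["CrossValidateModel"][
 RandomVariate[NormalDistribution[], 100],
 NormalDistribution[\[Mu], \[Sigma]],
 "ValidationFunction" -> meanNegLogLikelihood
 ]
```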
The Method option can be used to configure how the training and validation sets are generated. The following types of sampling are supported:
"KFold" (default) splits the dataset into k subsets (default: k=5) and trains the model k times, using each partition as the validation set once
"LeaveOneOut" fits the data as many times as there are elements in the dataset, using each element for validation once
"RandomSubSampling" splits the dataset randomly into training and validation sets (default: 80%/20%) repeatedly (default: five times), or uses a custom sampling function
"BootStrap" uses bootstrap samples (generated with RandomChoice) to fit the model repeatedly, without validation
The default Method setting uses k-fold validation with five folds. This means that the dataset is randomly split into five partitions, each of which is used as the validation set once: the model is trained five times on 4/5 of the dataset and then tested on the remaining 1/5. The "KFold" method has two sub-options:
"Folds" (default 5): number of partitions in which to split the dataset
"Runs" (default 1): number of times to perform k-fold validation (each time with a new random partitioning of the data)
The "LeaveOneOut" method, also known as the jackknife method, is essentially k-fold validation where the number of folds is equal to the number of data points. Since it can be quite computationally expensive, it is usually a good idea to use parallelization with this method. It does have the "Runs" sub-option like the "KFold" method, but for deterministic fitting procedures like EstimatedDistribution and LinearModelFit, there is no value in performing more than one run since each run will yield the exact same results (up to a random permutation).

The method "RandomSubSampling" splits the dataset into training/validation sets randomly and has the following sub-options:

"Runs" (default 1): number of times to resample the data into training/validation sets
ValidationSet (default Scaled[1/5]): number of samples to use for validation; when specified as Scaled[f], a fraction f of the dataset will be used for validation
"SamplingFunction" (default Automatic): function that specifies how to sub-sample the data
For the option "SamplingFunction", the function fun[nData,nVal] should return an Association with the keys "TrainingSet" and "ValidationSet". Each key should contain a list of integers indicating indices into the dataset.
The default sampling function corresponding to Automatic is Function[{nData,nVal},AssociationThread[{"TrainingSet","ValidationSet"},TakeDrop[RandomSample[Range[nData]],nData-nVal]]].
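Evaluating this default function directly illustrates the format that any "SamplingFunction" must return; for example, for ten data points with two reserved for validation:

```
(* the default sampling function from above, applied by hand *)
sampler = Function[{nData, nVal},
   AssociationThread[{"TrainingSet", "ValidationSet"},
    TakeDrop[RandomSample[Range[nData]], nData - nVal]]];
sampler[10, 2]
(* returns an Association of the form
   <|"TrainingSet" -> {...8 indices...}, "ValidationSet" -> {...2 indices...}|>,
   with the indices drawn randomly from Range[10] *)
```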
Bootstrap sampling is useful to get a sense of the range of possible models that can be fitted to the data. In a bootstrap sample, the original dataset is sampled with replacement (using RandomChoice), so the bootstrap samples can be larger than the original dataset. No validation sets will be generated when using bootstrap sampling. The following sub-options are supported:
"Runs" (default 5): number of bootstrap samples generated
"BootStrapSize" (default Scaled[1]): number of elements to generate in each bootstrap sample; when specified as Scaled[f], a fraction f of the dataset will be used
The "ValidationFunction" option can be used to specify a custom function that gets applied to the fit result and the validation data.
The "ParallelQ" option can be used to parallelize the computation using ParallelTable. Sub-options for ParallelTable can be specified as "ParallelQ" → {True, opts…}.

Examples

Basic Examples

Generate some data and visualize it:

In[1]:=
data = RandomVariate[PoissonDistribution[2], 100];
Histogram[data, {1}, PlotRange -> All]
Out[2]=

Cross-validate the data by fitting a PoissonDistribution to five different training/validation splits of the data:

In[3]:=
val = ResourceFunction["CrossValidateModel"][data, PoissonDistribution[\[Lambda]]];
Dataset[val]
Out[4]=

The fitted values of λ can be extracted using Part:

In[5]:=
val[[All, "FittedModel", 1]]
Out[5]=

Show the spread of the obtained values of λ and negative LogLikelihoods that were found:

In[6]:=
BoxWhiskerChart[{{val[[All, "FittedModel", 1]], val[[All, "ValidationResult"]]}}, "Outliers", ChartLabels -> {"\[Lambda]", "Loss (-LogLikelihood)"}]
Out[6]=

Compare with a fit by another distribution. A lower negative log-likelihood indicates a better fit:

In[7]:=
dists = {GeometricDistribution[p], PoissonDistribution[\[Lambda]]};
val2 = ResourceFunction["CrossValidateModel"][data, dists];
BoxWhiskerChart[{Merge[val2[[All, "ValidationResult"]], Identity]}, "Outliers", ChartLabels -> Automatic, PlotLabel -> "Loss (-LogLikelihood)"]
Out[9]=

Scope

CrossValidateModel works with LearnDistribution:

In[10]:=
ResourceFunction["CrossValidateModel"][
  RandomVariate[LaplaceDistribution[], 100], Function[LearnDistribution[#]]][[All, "ValidationResult"]]
Out[10]=

CrossValidateModel works with LinearModelFit, GeneralizedLinearModelFit and NonlinearModelFit. For fit models that return a FittedModel object, CrossValidateModel will calculate the RootMeanSquare of the residuals of the fit:

In[11]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
crossVal = ResourceFunction["CrossValidateModel"][
   data,
   <|
    "Linear" -> Function[LinearModelFit[#, {x, y}, {x, y}]],
    "Quadratic" -> Function[LinearModelFit[#, {x, y, x y, x^2, y^2}, {x, y}]],
    "Trig" -> Function[LinearModelFit[#, {Sin[x], Cos[y]}, {x, y}]]
    |>
   ];
BoxWhiskerChart[
 Merge[crossVal[[All, "ValidationResult"]], Identity], "Outliers", ChartLabels -> Automatic, PlotLabel -> "Loss (residual RootMeanSquare)"]
Out[13]=

CrossValidateModel also works with Fit if you specify that the returned model is a Function:

In[14]:=
crossVal = ResourceFunction["CrossValidateModel"][
   data,
   Function[Fit[#, {1, x, y}, {x, y}, "Function"]]
   ];
Dataset[crossVal]
Out[15]=

To use the parameter rules generated by FindFit with the default validation option, it is necessary to specify the option "ValidationFunction" → {Automatic, {fitExpr, {var1, var2, …}}} to pass the fit expression and the independent variables to the default validation function:

In[16]:=
fitExpression = a + Sin[d + b x + c y];
crossVal = ResourceFunction["CrossValidateModel"][
   data,
   Function[FindFit[#, fitExpression, {a, b, c, d}, {x, y}]],
   "ValidationFunction" -> {
     Automatic,
     {(* specify the fit expression and the independent variables *) fitExpression,
      {x, y}
      }
     }
   ];
Dataset[crossVal]
Out[18]=

When the fitting function uses Predict or Classify, the validation will be done with PredictorMeasurements or ClassifierMeasurements:

In[19]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
ResourceFunction["CrossValidateModel"][
  data[[All, {1, 2}]] -> data[[All, 3]], Function[Predict[#, TimeGoal -> 5]]] // Short
Out[20]=
In[21]:=
ResourceFunction["CrossValidateModel"][
  ExampleData[{"MachineLearning", "Titanic"}, "TrainingData"], Function[Classify[#, TimeGoal -> 5]]] // Short
Out[21]=

CrossValidateModel can also be used with NetTrain. First define a network and pre-train it:

In[22]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
net = NetTrain[NetChain[{10, Ramp, 10, Ramp, LinearLayer[]}], data[[All, {1, 2}]] -> data[[All, 3]], TimeGoal -> 5]
Out[23]=

Perform the cross-validation. It is recommended to make sure that NetTrain returns a NetTrainResultsObject by using All as the third argument. The returned "ValidationResult" is the average loss of the trained network evaluated on the validation set:

In[24]:=
val = ResourceFunction["CrossValidateModel"][
   data[[All, {1, 2}]] -> data[[All, 3]], Function[NetTrain[net, #, All, TimeGoal -> 5]]];
val[[All, "ValidationResult"]]
Out[25]=

Options

Method

The sub-options "Folds" and "Runs" are supported for the "KFold" method. Note how the total number of model fits equals "Folds" × "Runs":

In[26]:=
crossVal = ResourceFunction["CrossValidateModel"][
   RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]],
   Method -> {"KFold", "Folds" -> 10, "Runs" -> 4}
   ];
Length[crossVal]
Out[27]=

Here is an example input using the "LeaveOneOut" method:

In[28]:=
crossVal = ResourceFunction["CrossValidateModel"][
   RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]],
   Method -> "LeaveOneOut",
   "ParallelQ" -> True
   ];
Length[crossVal]
Out[29]=

Use the "RandomSubSampling" method. The "SamplingFunction" is set in such a way that it is equivalent to the default option "SamplingFunction" → Automatic:

In[30]:=
crossVal = ResourceFunction["CrossValidateModel"][
   RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]],
   Method -> {"RandomSubSampling",
     "Runs" -> 100,
     ValidationSet -> Scaled[0.1],
     (* This illustrates how the default "SamplingFunction" can be replicated
        as an explicit Function. The sampling function accepts the number of
        data points and the number of validation points as inputs; the output
        should be an Association with two lists of indices 1 <= i <= nData. *)
     "SamplingFunction" -> Function[{nData, nVal},
       AssociationThread[
        {"TrainingSet", "ValidationSet"},
        TakeDrop[RandomSample[Range[nData]], nData - nVal]
        ]
       ]
     },
   "ParallelQ" -> True
   ] // Short
Out[30]=

In this example, we use bootstrap sampling to estimate the error bars on the value of λ:

In[31]:=
bootStrap = ResourceFunction["CrossValidateModel"][
   RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]],
   Method -> {"BootStrap", "Runs" -> 1000, "BootStrapSize" -> Scaled[1]},
   "ParallelQ" -> True
   ];
BoxWhiskerChart[bootStrap[[All, "FittedModel", 1]], "Outliers", ChartLabels -> {"\[Lambda]"}]
Out[32]=

ValidationFunction

Specify a function that calculates the CDF of the fitted distribution at the points given by the validation data. We know that if the fitted distribution is the true distribution, the obtained values should be distributed as UniformDistribution[0, 1], as can be illustrated with TransformedDistribution:

In[33]:=
PDF[TransformedDistribution[
  CDF[NormalDistribution[\[Mu], \[Sigma]], x], x \[Distributed] NormalDistribution[\[Mu], \[Sigma]]]]
Out[33]=

By calculating the CDF of the validation data, we can combine the results and use a QuantilePlot to judge if the validation data follows the same distribution as the fit. The quantile plot suggests that the ExponentialDistribution gives the best fit (as expected):

In[34]:=
cdfValues = ResourceFunction["CrossValidateModel"][
   RandomVariate[ExponentialDistribution[1], 1000],
   {HalfNormalDistribution[\[Sigma]], ExponentialDistribution[\[Lambda]], LogNormalDistribution[\[Mu], \[Sigma]]},
   "ValidationFunction" -> Function[{fittedDistribution, validationData},
     CDF[fittedDistribution, validationData] (* this should be uniformly distributed if the fit is good *)
     ]
   ];
QuantilePlot[Merge[cdfValues[[All, "ValidationResult"]], Flatten], UniformDistribution[], PlotLegends -> Keys[cdfValues[[1, 1]]]]
Out[35]=

For fitting functions that return a FittedModel (like NonlinearModelFit), you can specify the validation function as "ValidationFunction" → {Automatic, Function[…]}. In that case, the function will be applied to the fitted values and the true values in the validation set:

In[36]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
ResourceFunction["CrossValidateModel"][
  data,
  Function[
   NonlinearModelFit[#, amp*Sin[a x + b y + c] + d, {amp, a, b, c, d}, {x, y}]],
  "ValidationFunction" -> {Automatic,
    Function[{fittedVals, trueVals}, <|
      "MeanAbsResiduals" -> Mean[Abs[fittedVals - trueVals]],
      "ComparisonPlot" -> ListPlot[SortBy[First]@Transpose[{fittedVals, trueVals}]]
      |>
     ]
    }
  ][[All, "ValidationResult"]]
Out[37]=

For Predict and Classify, the form "ValidationFunction" → {Automatic, prop} can be used to extract specific properties from PredictorMeasurements or ClassifierMeasurements. As an example, we can generate multiple prediction comparison plots and show them together in a single graph:

In[38]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
ResourceFunction["CrossValidateModel"][
   data[[All, {1, 2}]] -> data[[All, 3]],
   Function[Predict[#, TimeGoal -> 5]],
   "ValidationFunction" -> {Automatic, "ComparisonPlot"}
   ][[All, "ValidationResult"]] // Show
Out[39]=

For neural networks trained with NetTrain, the form "ValidationFunction" → {Automatic, args} can be used to supply arguments to NetMeasurements[trainedNet, validationData, args]:

In[40]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
net = NetTrain[NetChain[{10, Ramp, 10, Ramp, LinearLayer[]}], data[[All, {1, 2}]] -> data[[All, 3]], TimeGoal -> 5];
ResourceFunction["CrossValidateModel"][
  data[[All, {1, 2}]] -> data[[All, 3]], Function[NetTrain[net, #, All, TimeGoal -> 5]],
  "ValidationFunction" -> {Automatic, "MeanDeviation"}
  ][[All, "ValidationResult"]]
Out[41]=

With "ValidationFunction" → None, no validation will be performed:

In[42]:=
ResourceFunction["CrossValidateModel"][
 RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]], "ValidationFunction" -> None]
Out[42]=

ParallelQ

Use parallelization to speed up the cross-validation when doing many validation runs. Sub-options for ParallelTable can be specified as "ParallelQ" → {True, opts…}:

In[43]:=
crossVal = ResourceFunction["CrossValidateModel"][
   RandomVariate[PoissonDistribution[2], 100], PoissonDistribution[\[Lambda]],
   Method -> {"KFold", "Folds" -> 10, "Runs" -> 500},
   "ParallelQ" -> {True, Method -> "CoarsestGrained"}
   ];
BoxWhiskerChart[crossVal[[All, "ValidationResult"]], "Outliers"]
Out[44]=

Applications

Try a number of different continuous distributions to fit some ExampleData. The BoxWhiskerChart shows that the LogNormalDistribution, GammaDistribution and FrechetDistribution have the lowest losses and are therefore the best candidates:

In[45]:=
data = ExampleData[{"Statistics", "RiverLengths"}];
dists = {
   NormalDistribution[\[Mu], \[Sigma]], CauchyDistribution[a, b], HalfNormalDistribution[\[Sigma]],
   RayleighDistribution[\[Sigma]], LogNormalDistribution[\[Mu], \[Sigma]],
   GammaDistribution[\[Alpha], \[Beta]], FrechetDistribution[\[Alpha], \[Beta], \[Mu]], ExponentialDistribution[\[Lambda]], ParetoDistribution[k, \[Alpha]], PowerDistribution[k, \[Alpha]]
   };
val = ResourceFunction["CrossValidateModel"][data, dists,
   Method -> {"KFold", "Runs" -> 10},
   "ParallelQ" -> True
   ];
BoxWhiskerChart[{Merge[val[[All, "ValidationResult"]], Identity]}, "Outliers", ChartLabels -> Thread@Rotate[dists, 90 Degree], PlotLabel -> "Loss (-LogLikelihood)"]
Out[46]=

Neat Examples

Compare the residual RootMeanSquare of NonlinearModelFit and Predict:

In[47]:=
data = Flatten[
   Table[{x, y, Sin[x + y] + RandomVariate[NormalDistribution[0, 0.2]]}, {x, -5, 5, 0.2}, {y, -5, 5, 0.2}], 1];
crossVal = ResourceFunction["CrossValidateModel"][
   data,
   <| "NonlinearModelFit" -> Function[
      NonlinearModelFit[#, Sin[a x + b y + c], {a, b, c}, {x, y}]],
    "Predict" -> Function[Predict[#[[All, ;; 2]] -> #[[All, 3]]]]
    |>,
   "ValidationFunction" -> <|
     (* the default validation function will be used for NonlinearModelFit,
        so it does not need to be specified *)
     "Predict" -> Function[
       RootMeanSquare@
        PredictorMeasurements[#1, #2[[All, ;; 2]] -> #2[[All, 3]], "Residuals"]]
     |>,
   "ParallelQ" -> True
   ];
BoxWhiskerChart[
 Merge[crossVal[[All, "ValidationResult"]], Identity], "Outliers", ChartLabels -> Automatic, PlotLabel -> "Loss (residual RootMeanSquare)"]
Out[48]=

Resource History

Related Resources

License Information