bart.bart#
Module: bart.bart#
Inheritance diagram for ISLP.bart.bart:
Classes#
BART#
- class ISLP.bart.bart.BART(num_trees=200, num_particles=10, max_stages=5000, split_prob=<function BART.<lambda>>, min_depth=0, std_scale=2, split_prior=None, ndraw=10, burnin=100, sigma_prior=(5, 0.9), num_quantile=50, random_state=None, n_jobs=-1)#
Bases:
BaseEnsemble, RegressorMixin
Particle Gibbs BART sampling step
- Parameters:
- num_particles : int
Number of particles for the conditional SMC sampler. Defaults to 10.
- max_stages : int
Maximum number of iterations of the conditional SMC sampler. Defaults to 5000.
Notes
This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces several changes. The changes will be properly documented soon.
References
[Lakshminarayanan2015] Lakshminarayanan, B., Roy, D. M., and Teh, Y. W. (2015). Particle Gibbs for Bayesian Additive Regression Trees. arXiv.
- Attributes:
base_estimator_: Estimator used to grow the ensemble.
Methods
get_params([deep]): Get parameters for this estimator.
init_particles(base_particle, sigmasq, resid): Initialize particles.
score(X, y[, sample_weight]): Return the coefficient of determination of the prediction.
set_params(**params): Set the parameters of this estimator.
fit
predict
staged_predict
- __init__(num_trees=200, num_particles=10, max_stages=5000, split_prob=<function BART.<lambda>>, min_depth=0, std_scale=2, split_prior=None, ndraw=10, burnin=100, sigma_prior=(5, 0.9), num_quantile=50, random_state=None, n_jobs=-1)#
- property base_estimator_#
Estimator used to grow the ensemble.
- fit(X, Y, sample_weight=None)#
- get_params(deep=True)#
Get parameters for this estimator.
- Parameters:
- deep : bool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
- params : dict
Parameter names mapped to their values.
- init_particles(base_particle: ParticleTree, sigmasq: float, resid: ndarray) -> ndarray#
Initialize particles
- predict(X)#
- score(X, y, sample_weight=None)#
Return the coefficient of determination of the prediction.
The coefficient of determination \(R^2\) is defined as \((1 - \frac{u}{v})\), where \(u\) is the residual sum of squares ((y_true - y_pred) ** 2).sum() and \(v\) is the total sum of squares ((y_true - y_true.mean()) ** 2).sum(). The best possible score is 1.0 and it can be negative (because the model can be arbitrarily worse). A constant model that always predicts the expected value of y, disregarding the input features, would get a \(R^2\) score of 0.0.
- Parameters:
- X : array-like of shape (n_samples, n_features)
Test samples. For some estimators this may be a precomputed kernel matrix or a list of generic objects instead with shape (n_samples, n_samples_fitted), where n_samples_fitted is the number of samples used in the fitting for the estimator.
- y : array-like of shape (n_samples,) or (n_samples, n_outputs)
True values for X.
- sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
- Returns:
- score : float
\(R^2\) of self.predict(X) w.r.t. y.
Notes
The \(R^2\) score used when calling score on a regressor uses multioutput='uniform_average' from version 0.23 to keep consistent with default value of r2_score(). This influences the score method of all the multioutput regressors (except for MultiOutputRegressor).
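The \(R^2\) definition used by score can be checked by hand. The following is a minimal sketch in plain Python, not the library's implementation; the helper name r2_manual and the sample values are illustrative assumptions. It reproduces the three cases described above: a perfect fit, a constant prediction at the mean of y, and a prediction that is worse than the constant model.

```python
def r2_manual(y_true, y_pred):
    """Compute 1 - u/v as defined for score (hypothetical helper)."""
    mean_true = sum(y_true) / len(y_true)
    # u: residual sum of squares
    u = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # v: total sum of squares
    v = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - u / v


y_true = [3.0, -0.5, 2.0, 7.0]            # illustrative values
y_mean = sum(y_true) / len(y_true)

r2_perfect = r2_manual(y_true, y_true)                       # exact fit
r2_constant = r2_manual(y_true, [y_mean] * len(y_true))      # always predicts the mean
r2_worse = r2_manual(y_true, [10.0] * len(y_true))           # far from every target
```

Here r2_perfect is 1.0, r2_constant is 0.0, and r2_worse is negative, matching the description of the best possible score, the constant model, and an arbitrarily bad model.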
- set_params(**params)#
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
- **params : dict
Estimator parameters.
- Returns:
- self : estimator instance
Estimator instance.
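The <component>__<parameter> naming convention can be illustrated with a small stand-in. The sketch below is an assumption-laden toy, not scikit-learn's or ISLP's code: the Component class and the parameter names scaler, reg, alpha and with_mean are hypothetical. It only shows how a key is split at the first double underscore and forwarded to the named component.

```python
class Component:
    """Hypothetical stand-in for an estimator that stores its parameters in a dict."""

    def __init__(self, **params):
        self.params = dict(params)

    def set_params(self, **params):
        for key, value in params.items():
            # Split "reg__alpha" into ("reg", "__", "alpha"); plain keys have no delimiter.
            name, delim, sub_key = key.partition("__")
            if delim:
                # Nested key: forward the remainder to the named component.
                self.params[name].set_params(**{sub_key: value})
            else:
                self.params[name] = value
        return self  # returning self allows chained calls


model = Component(scaler=Component(with_mean=True), reg=Component(alpha=1.0), verbose=False)
returned = model.set_params(reg__alpha=0.5, scaler__with_mean=False, verbose=True)
```

After the call, the nested reg component holds alpha=0.5, the nested scaler holds with_mean=False, and the top-level verbose flag is True. The call returns the same object, mirroring the documented return value.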
- staged_predict(X, start_idx=0)#
SampleSplittingVariable#
- class ISLP.bart.bart.SampleSplittingVariable(alpha_prior, random_state)#
Bases:
object
Methods
rvs
- __init__(alpha_prior, random_state)#
Sample splitting variables proportional to alpha_prior.
This is equivalent to sampling weights from a Dirichlet distribution with parameter alpha_prior and then using those weights to sample from the available splitting variables. This enforces sparsity.
- rvs()#
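The two-step description given for SampleSplittingVariable (Dirichlet weights first, then a draw of the splitting variable) can be sketched with the standard library alone. This is an illustration of the description under stated assumptions, not the class's actual implementation: the function names, the seed, and the three-element alpha_prior are all hypothetical. Dirichlet weights are obtained by normalizing independent Gamma draws.

```python
import random


def sample_dirichlet(alpha, rng):
    """Draw one weight vector from Dirichlet(alpha) by normalizing Gamma(alpha_i, 1) draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def sample_split_variable(alpha, rng):
    """Two-step draw: sample Dirichlet weights, then pick one variable index with them."""
    weights = sample_dirichlet(alpha, rng)
    return rng.choices(list(range(len(alpha))), weights=weights, k=1)[0]


rng = random.Random(0)                 # fixed seed, chosen arbitrarily
alpha_prior = [8.0, 1.0, 1.0]          # hypothetical prior over three predictors

weights = sample_dirichlet(alpha_prior, rng)
counts = [0] * len(alpha_prior)
for _ in range(2000):
    counts[sample_split_variable(alpha_prior, rng)] += 1
```

Averaged over many draws, each index is selected with frequency close to its share of alpha_prior, so the first predictor dominates counts in this example.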