sklearn.linear_model.Lasso System

From GM-RKB

A sklearn.linear_model.Lasso System is a linear least-squares L1-regularized regression system within sklearn.linear_model (that implements a LASSO algorithm to solve a LASSO task).

  • Context:
    • Usage:
1) Import the Lasso regression model from scikit-learn: from sklearn.linear_model import Lasso
2) Create design matrix X and response vector Y
3) Create a Lasso regression object: lasso = Lasso(alpha=alpha[, fit_intercept=True, normalize=False, ...]) (the normalize parameter was deprecated in scikit-learn 1.0 and removed in 1.2)
4) Choose method(s):
  • Fit model with coordinate descent: lasso.fit(X, Y[, check_input])
  • Predict Y using the linear model with estimated coefficients: Y_pred = lasso.predict(X)
  • Return coefficient of determination (R^2) of the prediction: lasso.score(X,Y[, sample_weight=w])
  • Compute elastic net path with coordinate descent: lasso.path(X, Y[, l1_ratio, eps, n_alphas, ...]) (a static method; l1_ratio=1.0 gives the pure Lasso path)
  • Get estimator parameters: lasso.get_params([deep])
  • Set estimator parameters: lasso.set_params(**params)
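The usage steps above can be sketched end to end as follows. This is a minimal illustration, assuming scikit-learn 1.2 or later; the synthetic data from make_regression and the alpha values are illustrative choices, not part of the original page.

```python
# Sketch of the usage steps on synthetic data (assumes scikit-learn >= 1.2;
# the data set and alpha values are illustrative assumptions).
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso

# Step 2: design matrix X (100 samples, 10 features) and response vector Y;
# only 3 of the 10 features actually influence Y.
X, Y = make_regression(n_samples=100, n_features=10, n_informative=3,
                       noise=5.0, random_state=0)

# Step 3: create the Lasso object with an explicit regularization strength.
lasso = Lasso(alpha=1.0, fit_intercept=True)

# Step 4: fit with coordinate descent, then predict and score.
lasso.fit(X, Y)
Y_pred = lasso.predict(X)
r2 = lasso.score(X, Y)  # coefficient of determination R^2 on the training data

# Read the parameters back, change alpha, and refit.
params = lasso.get_params()
lasso.set_params(alpha=10.0)
lasso.fit(X, Y)

# path is a static method that computes an elastic-net regularization path;
# l1_ratio=1.0 requests the pure L1 (Lasso) case.
alphas, coefs, dual_gaps = Lasso.path(X, Y, l1_ratio=1.0)

print(f"R^2 at alpha=1.0: {r2:.3f}")
print("alpha after set_params:", lasso.get_params()["alpha"])
print("nonzero coefficients at alpha=10:", np.count_nonzero(lasso.coef_))
print("path shapes:", alphas.shape, coefs.shape)
```

The coefs array returned by path has one row per feature and one column per alpha on the path.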
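Input:
# Assumes scikit-learn >= 0.18 and < 1.2 (load_boston was removed in 1.2).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold  # replaces the removed sklearn.cross_validation module
from sklearn.linear_model import Lasso
from sklearn.datasets import load_boston

boston = load_boston()
# Append a constant column of ones; this is redundant with fit_intercept=True
# (its coefficient stays 0) but is kept from the original example.
x = np.array([np.concatenate((v, [1])) for v in boston.data])
y = boston.target

# Fit on the full data set and plot predicted vs. actual values.
lasso = Lasso(fit_intercept=True, alpha=0.5)
lasso.fit(x, y)
p = lasso.predict(x)
plt.plot(p, y, 'ro')

# Training RMSE.
err = p - y
total_error = np.dot(err, err)
rmse_train = np.sqrt(total_error / len(p))

# 10-fold cross-validation RMSE: every sample is predicted exactly once.
kf = KFold(n_splits=10)
xval_err = 0
for train, test in kf.split(x):
    lasso.fit(x[train], y[train])
    p = lasso.predict(x[test])
    e = p - y[test]
    xval_err += np.dot(e, e)
rmse_10cv = np.sqrt(xval_err / len(x))

Output:
(Figure: boston lasso10fold.png, predicted values on the x-axis vs. actual values on the y-axis.)
Method: Lasso Regression
RMSE on training set: 4.9141
RMSE on 10-fold CV: 5.7368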
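Because load_boston is no longer available in current scikit-learn, the same training-RMSE vs. 10-fold-CV-RMSE procedure from the example can be sketched on the bundled diabetes data set. This assumes scikit-learn 1.2 or later; the data set is a stand-in chosen for illustration, so its RMSE values will not match the Boston results above. cross_val_predict replaces the manual fold loop.

```python
# Sketch of the training-RMSE vs. 10-fold-CV-RMSE procedure on current
# scikit-learn (assumes >= 1.2); load_diabetes stands in for the removed
# Boston data, so the numbers differ from the original output.
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold, cross_val_predict

X, y = load_diabetes(return_X_y=True)
lasso = Lasso(alpha=0.5, fit_intercept=True)

# Training RMSE: fit on all samples and predict the same samples.
p_train = lasso.fit(X, y).predict(X)
rmse_train = np.sqrt(np.mean((p_train - y) ** 2))

# 10-fold CV RMSE: each sample is predicted by a model trained on the other
# nine folds; unshuffled folds mirror the KFold call in the example.
kf = KFold(n_splits=10)
p_cv = cross_val_predict(lasso, X, y, cv=kf)
rmse_10cv = np.sqrt(np.mean((p_cv - y) ** 2))

print(f"RMSE on training set: {rmse_train:.4f}")
print(f"RMSE on 10-fold CV: {rmse_10cv:.4f}")
```

Since every sample lands in exactly one test fold, the mean of the squared out-of-fold errors equals the summed fold errors divided by len(x), which is what the original loop computes.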


References

2017a

Linear Model trained with L1 prior as regularizer (aka the Lasso)
The optimization objective for Lasso is:
(1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
Technically the Lasso model is optimizing the same objective function as the Elastic Net with l1_ratio=1.0 (no L2 penalty).
Read more in the User Guide.
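The quoted objective and the Elastic Net equivalence can be checked numerically. The sketch below assumes scikit-learn 1.2 or later; the synthetic data, the alpha value, and the helper name lasso_objective are hypothetical choices for illustration. fit_intercept=False is used so that the objective applies exactly as written, without an intercept term.

```python
# Sketch: evaluate (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
# for a fitted model and compare Lasso with ElasticNet(l1_ratio=1.0).
# Data, alpha, and lasso_objective are illustrative assumptions.
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import ElasticNet, Lasso

X, y = make_regression(n_samples=200, n_features=8, n_informative=3,
                       noise=1.0, random_state=0)
alpha = 0.5


def lasso_objective(w, X, y, alpha):
    """Hypothetical helper: the quoted Lasso objective for coefficients w."""
    n_samples = X.shape[0]
    residual = y - X @ w
    return residual @ residual / (2 * n_samples) + alpha * np.abs(w).sum()


# No intercept, so the fitted coef_ is the full w in the objective.
lasso = Lasso(alpha=alpha, fit_intercept=False).fit(X, y)
enet = ElasticNet(alpha=alpha, l1_ratio=1.0, fit_intercept=False).fit(X, y)

print("max |coef difference|:", np.max(np.abs(lasso.coef_ - enet.coef_)))
print("objective at Lasso solution:", lasso_objective(lasso.coef_, X, y, alpha))
# The minimizer should do at least as well as the all-zeros vector.
print("objective at w = 0:", lasso_objective(np.zeros(X.shape[1]), X, y, alpha))
```

With l1_ratio=1.0 the Elastic Net's L2 term has weight zero, so both estimators minimize the same function and should return matching coefficients.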
