This function enables or disables automatic differentiation using the JAX package in Python. When a model includes many parameters, this can considerably speed up the computation of standard errors and improve their accuracy.
Arguments
- autodiff
Logical flag. If TRUE, enables automatic differentiation with JAX. If FALSE (default), disables automatic differentiation and reverts to finite difference methods.
- install
Logical flag. If TRUE, installs the marginaleffects Python package via reticulate::py_install(). Default is FALSE. This is only necessary if you are self-managing a Python installation.
Details
Automatic differentiation needs to be enabled once per session.
When autodiff = TRUE, this function:
- Imports the marginaleffects.autodiff Python module via reticulate::import()
- Sets the internal jacobian function to use JAX-based automatic differentiation
- Provides faster and more accurate gradient computation for supported models
- Falls back on the default finite difference method for unsupported models and calls
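The jacobian mentioned above enters the delta method, where the variance of a quantity such as a predicted probability is approximated by J V J', with V the coefficient covariance matrix. The base R sketch below shows what a finite difference fallback computes: a forward-difference Jacobian of predicted probabilities with respect to the coefficients. It illustrates the general technique and is not marginaleffects' internal code. The helper names pred_fun() and fd_jacobian() and the step size h = 1e-7 are hypothetical choices. Automatic differentiation replaces this approximate J with one computed to machine precision.

```r
# Illustrative sketch only (not marginaleffects' internal code):
# delta-method standard errors with a forward finite-difference Jacobian.
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)
beta <- coef(mod)
X <- model.matrix(mod)

# Hypothetical helper: predicted probabilities as a function of the coefficients
pred_fun <- function(b) plogis(drop(X %*% b))

# Hypothetical helper: forward differences, one Jacobian column per coefficient
fd_jacobian <- function(f, b, h = 1e-7) {
  f0 <- f(b)
  J <- matrix(NA_real_, nrow = length(f0), ncol = length(b))
  for (j in seq_along(b)) {
    b_step <- b
    b_step[j] <- b_step[j] + h
    J[, j] <- (f(b_step) - f0) / h
  }
  J
}

J <- fd_jacobian(pred_fun, beta)

# Delta method: Var(f(beta_hat)) is approximated by J V J'
se_fd <- sqrt(diag(J %*% vcov(mod) %*% t(J)))

# predict.glm() applies the delta method with the analytic derivative
se_exact <- unname(predict(mod, type = "response", se.fit = TRUE)$se.fit)

# Small but nonzero: the truncation error of the finite difference step
max(abs(se_fd - se_exact))
```

The finite difference version needs one extra model evaluation per coefficient and its accuracy depends on the step size h. Both costs grow with the number of parameters, which is where automatic differentiation helps most.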
Currently supports:
- Model types: lm, glm, ols, lrm
- Functions: predictions() and comparisons(), along with the avg_ and plot_ variants
- type: "response" or "link"
- by: TRUE, FALSE, or a character vector
- comparison: "difference" and "ratio"
For unsupported models or options, the function automatically falls back to finite difference methods with a warning.
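The supported options listed above can be combined in ordinary calls. The following sketch assumes that marginaleffects and a working Python setup (see Python Configuration) are available. The glm model mirrors the Examples section. The lm model and the choice of by = "am" are illustrative additions, not taken from the original documentation.

```r
library(marginaleffects)

# Enable JAX-based automatic differentiation (needs a working Python setup)
autodiff(TRUE)

# A supported glm model, as in the Examples section
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)

# Supported function and type: unit-level predictions on the link scale
p_link <- predictions(mod, type = "link")

# Supported comparison: average ratios instead of differences
cmp_ratio <- avg_comparisons(mod, comparison = "ratio")

# Supported model type and `by` option: average predictions by subgroup
# (illustrative lm model; am is a predictor here, not the outcome)
mod_lm <- lm(mpg ~ hp + wt + am, data = mtcars)
p_by <- avg_predictions(mod_lm, by = "am")

# Revert to finite difference methods
autodiff(FALSE)
```

Options outside this list, such as other comparison types, are still accepted but are computed with finite differences after a warning.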
Python Configuration
By default, no manual configuration of Python should be necessary. On most
machines, unless you have explicitly configured reticulate, reticulate
defaults to an automatically managed ephemeral virtual environment with all
Python requirements declared via reticulate::py_require().
If you prefer to use a manually managed Python installation, you can tell
reticulate which Python executable or environment to use.
reticulate selects a Python installation using its Order of Discovery.
As a convenience, autodiff(install = TRUE) installs the marginaleffects Python
package in a self-managed virtual environment.
To specify an alternate Python executable:
library(reticulate)
use_python("/usr/local/bin/python")
To use a virtual environment:
use_virtualenv("myenv")
These configuration commands must be run before calling autodiff().
Examples
if (FALSE) { # \dontrun{
# Install the Python package (only needed once)
autodiff(install = TRUE)
# Enable automatic differentiation
autodiff(TRUE)
# Fit a model and compute marginal effects
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)
avg_comparisons(mod) # Will use JAX for faster computation
# Disable automatic differentiation
autodiff(FALSE)
} # }