summaryrefslogtreecommitdiff
path: root/gnu/packages/machine-learning.scm
diff options
context:
space:
mode:
Diffstat (limited to 'gnu/packages/machine-learning.scm')
-rw-r--r--gnu/packages/machine-learning.scm45
1 files changed, 45 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 5b98705943..f50398b555 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -3868,6 +3868,51 @@ AI services.")
Actions for the Lightning suite of libraries.")
(license license:asl2.0)))
+(define-public python-captum
+ (package
+ (name "python-captum")
+ (version "0.6.0")
+ (source (origin
+ (method git-fetch)
+ (uri (git-reference
+ (url "https://github.com/pytorch/captum")
+ (commit (string-append "v" version))))
+ (file-name (git-file-name name version))
+ (sha256
+ (base32
+ "1h4n91ivhjxm6wj0vgqpfss2dmq4sjcp0appd08cd5naisabjyb5"))))
+ (build-system pyproject-build-system)
+ (arguments
+ (list
+ #:test-flags
+ '(list "-k"
+ ;; These two tests (out of more than 1000 tests) fail because of
+ ;; accuracy problems.
+ "not test_softmax_classification_batch_multi_target\
+ and not test_softmax_classification_batch_zero_baseline")))
+ (propagated-inputs (list python-matplotlib python-numpy python-pytorch))
+ (native-inputs (list jupyter
+ python-annoy
+ python-black
+ python-flake8
+ python-flask
+ python-flask-compress
+ python-ipython
+ python-ipywidgets
+ python-mypy
+ python-parameterized
+ python-pytest
+ python-pytest-cov
+ python-scikit-learn))
+ (home-page "https://captum.ai")
+ (synopsis "Model interpretability for PyTorch")
+ (description "Captum is a model interpretability and understanding library
+for PyTorch. Captum contains general purpose implementations of integrated
+gradients, saliency maps, smoothgrad, vargrad and others for PyTorch models.
+It has quick integration for models built with domain-specific libraries such
+as torchvision, torchtext, and others.")
+ (license license:bsd-3)))
+
(define-public python-readchar
(package
(name "python-readchar")