summaryrefslogtreecommitdiff
path: root/gnu/packages/machine-learning.scm
diff options
context:
space:
mode:
authorRicardo Wurmus <rekado@elephly.net>2023-07-10 13:13:11 +0200
committerRicardo Wurmus <rekado@elephly.net>2023-07-10 13:13:45 +0200
commitd0296970fb8ed97ac17bd4c580351af961a8c0fb (patch)
tree0a2d2b61c876da7fa72d61b1cf51b8aa37b22ed5 /gnu/packages/machine-learning.scm
parente3d9d896b540f82e4511f2bd6ae6373390ee2d4d (diff)
gnu: Add python-captum.
* gnu/packages/machine-learning.scm (python-captum): New variable.
Diffstat (limited to 'gnu/packages/machine-learning.scm')
-rw-r--r--gnu/packages/machine-learning.scm45
1 files changed, 45 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index 5b98705943..f50398b555 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -3868,6 +3868,51 @@ AI services.")
Actions for the Lightning suite of libraries.")
(license license:asl2.0)))
+(define-public python-captum
+ (package
+ (name "python-captum")
+ (version "0.6.0")
+ (source (origin
+ (method git-fetch)
+ (uri (git-reference
+ (url "https://github.com/pytorch/captum")
+ (commit (string-append "v" version))))
+ (file-name (git-file-name name version))
+ (sha256
+ (base32
+ "1h4n91ivhjxm6wj0vgqpfss2dmq4sjcp0appd08cd5naisabjyb5"))))
+ (build-system pyproject-build-system)
+ (arguments
+ (list
+ #:test-flags
+ '(list "-k"
+ ;; These two tests (out of more than 1000 tests) fail because of
+ ;; accuracy problems.
+ "not test_softmax_classification_batch_multi_target\
+ and not test_softmax_classification_batch_zero_baseline")))
+ (propagated-inputs (list python-matplotlib python-numpy python-pytorch))
+ (native-inputs (list jupyter
+ python-annoy
+ python-black
+ python-flake8
+ python-flask
+ python-flask-compress
+ python-ipython
+ python-ipywidgets
+ python-mypy
+ python-parameterized
+ python-pytest
+ python-pytest-cov
+ python-scikit-learn))
+ (home-page "https://captum.ai")
+ (synopsis "Model interpretability for PyTorch")
+ (description "Captum is a model interpretability and understanding library
+for PyTorch. Captum contains general purpose implementations of integrated
+gradients, saliency maps, smoothgrad, vargrad and others for PyTorch models.
+It has quick integration for models built with domain-specific libraries such
+as torchvision, torchtext, and others.")
+ (license license:bsd-3)))
+
(define-public python-readchar
(package
(name "python-readchar")