summaryrefslogtreecommitdiff
path: root/gnu/packages/machine-learning.scm
diff options
context:
space:
mode:
authorVinicius Monego <monego@posteo.net>2023-07-29 14:52:53 -0300
committerVinicius Monego <monego@posteo.net>2023-09-16 23:10:52 -0300
commit70b8682eaa94b5a6517a8fe733eb599ca2a778e3 (patch)
tree69f8d1c2b8658089775b4b7f93c5dd72ff25f1c8 /gnu/packages/machine-learning.scm
parentee17fdfe42daab9ef7849a2d204fc69179daf1bf (diff)
gnu: Add python-jaxtyping.
* gnu/packages/machine-learning.scm (python-jaxtyping): New variable.
Diffstat (limited to 'gnu/packages/machine-learning.scm')
-rw-r--r--gnu/packages/machine-learning.scm25
1 files changed, 25 insertions, 0 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm
index fd0be8d500..92a60d1616 100644
--- a/gnu/packages/machine-learning.scm
+++ b/gnu/packages/machine-learning.scm
@@ -2060,6 +2060,31 @@ physics-informed learning. It includes implementations for the PINN
MFNN (multifidelity neural network) algorithms.")
(license license:lgpl2.1+)))
+(define-public python-jaxtyping
+ (package
+ (name "python-jaxtyping")
+ (version "0.2.21")
+ (source (origin
+ (method url-fetch)
+ (uri (pypi-uri "jaxtyping" version))
+ (sha256
+ (base32
+ "19qmsnbn4wv2jl99lpn622qs49mrfxmx8s9pr5y8izzgdjq1fvii"))))
+ (build-system pyproject-build-system)
+ ;; Tests require JAX, but JAX can't be packaged because it uses the Bazel
+ ;; build system.
+ (arguments (list #:tests? #f))
+ (native-inputs (list python-hatchling))
+ (propagated-inputs (list python-numpy python-typeguard
+ python-typing-extensions))
+ (home-page "https://github.com/google/jaxtyping")
+ (synopsis
+ "Type annotations and runtime checking for JAX arrays and others")
+ (description "@code{jaxtyping} provides type annotations and runtime
+checking for shape and dtype of JAX arrays, PyTorch, NumPy, TensorFlow, and
+PyTrees.")
+ (license license:expat)))
+
;; There have been no proper releases yet.
(define-public kaldi
(let ((commit "be22248e3a166d9ec52c78dac945f471e7c3a8aa")