/*
webid-oidc, implementation of the Solid specification
Copyright (C) 2020, 2021 Vivien Kraus
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
#ifndef H_UTILITIES_INCLUDED
#define H_UTILITIES_INCLUDED
#ifdef HAVE_CONFIG_H
#include
#endif /* HAVE_CONFIG_H */
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
/* Scheme symbols used when converting crypto keys to and from Scheme
   alists.  do_ecc_curve_load compares them with scm_is_eq against the
   result of scm_string_to_symbol, so they are expected to hold interned
   symbols.  NOTE(review): nothing in this part of the file initialises
   them; presumably an init routine elsewhere does -- confirm.  Because
   they are `static` in a header, every translation unit that includes
   this file gets its own copies, each of which must be initialised.  */
/* Curve names, used as the value of the crv entry.  */
static SCM p256;
static SCM p384;
static SCM p521;
/* Alist keys.  The names appear to mirror the JWK parameter names of
   RFC 7518 -- presumably; confirm against the initialiser.  */
static SCM kcrv;  /* curve name entry */
static SCM kx;    /* EC point, x coordinate */
static SCM ky;    /* EC point, y coordinate */
static SCM kd;    /* EC private scalar, or RSA private exponent */
static SCM kn;    /* RSA modulus (nettle field n) */
static SCM ke;    /* RSA public exponent (nettle field e) */
static SCM kp;    /* RSA first prime (nettle field p) */
static SCM kq;    /* RSA second prime (nettle field q) */
static SCM kdp;   /* RSA d mod (p-1), nettle field a */
static SCM kdq;   /* RSA d mod (q-1), nettle field b */
static SCM kqi;   /* RSA q^-1 mod p, nettle field c */
/* Return, as a Scheme string, the base64url encoding of the LENGTH bytes
   at DATA, with any trailing '=' padding removed. */
static SCM wrap_bytevector (size_t length, uint8_t * data);
/* Return the unpadded base64url encoding of the bigint z, exported as
   big-endian bytes of its absolute value without leading zero bytes. */
static SCM wrap_mpz_t (mpz_t z);
/* Return an alist for a key.  The point alist has crv, x and y entries;
   the scalar alist has crv and d; the EC key-pair alist is the point
   alist followed by the scalar alist (so crv appears twice).  The RSA
   public alist has n and e; the private alist has d, p, q, dp, dq and qi
   (taken from nettle's d, p, q, a, b, c); the RSA key-pair alist is the
   public alist followed by the private one. */
static SCM wrap_ecc_point (const struct ecc_curve *crv,
const struct ecc_point *x);
static SCM wrap_ecc_scalar (const struct ecc_curve *crv,
const struct ecc_scalar *x);
static SCM wrap_ecc_key_pair (const struct ecc_curve *crv,
const struct ecc_point *point,
const struct ecc_scalar *scalar);
static SCM wrap_rsa_public_key (struct rsa_public_key *x);
static SCM wrap_rsa_private_key (struct rsa_private_key *x);
static SCM wrap_rsa_key_pair (struct rsa_public_key *pub,
struct rsa_private_key *key);
/* Decode DATA, a Scheme string holding base64url text, into a buffer
   allocated with scm_malloc and store the decoded length in *size; the
   caller must free the buffer.  On failure, return NULL if throw_if_fail
   is zero; otherwise a decoding failure throws 'base64-decoding-error. */
static uint8_t *get_as_bytevector (SCM data, size_t *size, int throw_if_fail);
/* Decode DATA as with get_as_bytevector and store the bytes in x, read
   as an unsigned big-endian integer.  Return 1 on success, 0 on failure. */
static int do_mpz_t_load (mpz_t x, SCM data, int throw_if_fail);
/* Map a curve name (one of the symbols held in p256, p384, p521, or a
   string naming one of them) to the matching nettle curve.  Unknown
   names give NULL, or throw 'unsupported-crv if throw_if_fail is set. */
static const struct ecc_curve *do_ecc_curve_load (SCM data,
int throw_if_fail);
/* Fill x from the alist DATA (same keys as the wrap_* functions) and
   return nonzero on success, 0 if an entry is missing or undecodable or
   nettle rejects the values.  These never throw on bad base64.
   NOTE(review): x presumably has to be initialised beforehand (curve /
   mpz fields) -- confirm at the call sites. */
static int do_ecc_point_load (struct ecc_point *x, SCM data);
static int do_ecc_scalar_load (struct ecc_scalar *x, SCM data);
static int do_rsa_public_key_load (struct rsa_public_key *x, SCM data);
static int do_rsa_private_key_load (struct rsa_private_key *x, SCM data);
/* Register x to be cleared (mpz_clear, ecc_*_clear, rsa_*_clear) when
   the current dynwind context ends, normally or by a non-local exit.
   Must be called between scm_dynwind_begin and scm_dynwind_end. */
static void dynwind_mpz_t_clear (mpz_t * x);
static void dynwind_ecc_point_clear (struct ecc_point *x);
static void dynwind_ecc_scalar_clear (struct ecc_scalar *x);
static void dynwind_rsa_public_key_clear (struct rsa_public_key *x);
static void dynwind_rsa_private_key_clear (struct rsa_private_key *x);
/* Return, as a Scheme string, the base64url encoding of the LENGTH bytes
   at DATA with the trailing '=' padding removed.

   The encoder writes its bulk output and its final group into a single
   scratch buffer, which is released by the dynwind context whether or
   not the string conversion throws.  */
static inline SCM
wrap_bytevector (size_t length, uint8_t * data)
{
  struct base64_encode_ctx encoder;
  char *buffer;
  size_t encoded;
  SCM ret;
  base64url_encode_init (&encoder);
  scm_dynwind_begin (0);
  /* Room for the update output plus the final (possibly padded) group;
     never zero, even for an empty input.  */
  buffer = scm_malloc (BASE64_ENCODE_LENGTH (length)
                       + BASE64_ENCODE_FINAL_LENGTH);
  scm_dynwind_free (buffer);
  encoded = base64_encode_update (&encoder, buffer, length, data);
  encoded += base64_encode_final (&encoder, buffer + encoded);
  /* base64url in this context is used without padding.  */
  while (encoded != 0 && buffer[encoded - 1] == '=')
    {
      encoded--;
    }
  /* scm_from_utf8_stringn copies the bytes, so BUFFER can be freed.  */
  ret = scm_from_utf8_stringn (buffer, encoded);
  scm_dynwind_end ();
  return ret;
}
/* Export the absolute value of DATA as big-endian bytes, without leading
   zero bytes, into a buffer allocated with scm_malloc.  The number of
   bytes is stored in *LENGTH and the caller owns the buffer.

   Zero is exported as a single 0x00 byte: mpz_sizeinbase reports one
   bit for zero, so *LENGTH is 1, but mpz_export writes nothing for it
   and would leave that byte uninitialised (and trip the assert).  */
static uint8_t *
export_mpz_t (mpz_t data, size_t *length)
{
  size_t written = 0;
  uint8_t *ret;
  *length = (mpz_sizeinbase (data, 2) + 7) / 8;
  ret = scm_malloc (*length);
  mpz_export (ret, &written, 1, 1, 1, 0, data);
  if (written == 0)
    {
      /* DATA is zero: mpz_export produced no words.  */
      ret[0] = 0;
      written = 1;
    }
  assert (*length == written);
  return ret;
}
/* Encode the bigint DATA as an unpadded base64url Scheme string.  The
   temporary byte export is released by the dynwind context, including
   when encoding throws.  */
static inline SCM
wrap_mpz_t (mpz_t data)
{
  uint8_t *bytes;
  size_t n_bytes;
  SCM encoded;
  scm_dynwind_begin (0);
  bytes = export_mpz_t (data, &n_bytes);
  scm_dynwind_free (bytes);
  encoded = wrap_bytevector (n_bytes, bytes);
  scm_dynwind_end ();
  return encoded;
}
static SCM
wrap_ecc_curve (const struct ecc_curve *crv)
{
static const struct ecc_curve *p_256 = NULL;
static const struct ecc_curve *p_384 = NULL;
static const struct ecc_curve *p_521 = NULL;
static int init = 0;
if (!init)
{
init = 1;
p_256 = nettle_get_secp_256r1 ();
p_384 = nettle_get_secp_384r1 ();
p_521 = nettle_get_secp_521r1 ();
}
if (crv == p_256)
{
return p256; /* the symbol */
}
if (crv == p_384)
{
return p384;
}
if (crv == p_521)
{
return p521;
}
abort ();
return SCM_UNDEFINED;
}
static inline SCM
wrap_ecc_point (const struct ecc_curve *crv, const struct ecc_point *point)
{
mpz_t x, y;
SCM ret;
scm_dynwind_begin (0);
mpz_init (x);
dynwind_mpz_t_clear (&x);
mpz_init (y);
dynwind_mpz_t_clear (&y);
ecc_point_get (point, x, y);
ret =
scm_list_3 (scm_cons (kcrv, wrap_ecc_curve (crv)),
scm_cons (kx, wrap_mpz_t (x)), scm_cons (ky, wrap_mpz_t (y)));
scm_dynwind_end ();
return ret;
}
/* Return the private EC scalar SCALAR on curve CRV as a two-entry alist
   keyed by kcrv and kd, with the scalar encoded by wrap_mpz_t.

   NOTE(review): wrap_mpz_t drops leading zero bytes, whereas RFC 7518
   section 6.2.2.1 asks for d to use the full byte length of the curve
   order.  Confirm whether consumers require that.  */
static inline SCM
wrap_ecc_scalar (const struct ecc_curve *crv, const struct ecc_scalar *scalar)
{
  mpz_t value;
  SCM crv_entry, d_entry, alist;
  scm_dynwind_begin (0);
  mpz_init (value);
  dynwind_mpz_t_clear (&value);
  ecc_scalar_get (scalar, value);
  crv_entry = scm_cons (kcrv, wrap_ecc_curve (crv));
  d_entry = scm_cons (kd, wrap_mpz_t (value));
  alist = scm_list_2 (crv_entry, d_entry);
  scm_dynwind_end ();
  return alist;
}
/* Return an EC key pair as one alist: the entries from wrap_ecc_point
   followed by those from wrap_ecc_scalar.  The curve entry therefore
   appears twice; an assq lookup finds the first one.  */
static inline SCM
wrap_ecc_key_pair (const struct ecc_curve *crv, const struct ecc_point *point,
                   const struct ecc_scalar *scalar)
{
  SCM public_part = wrap_ecc_point (crv, point);
  SCM private_part = wrap_ecc_scalar (crv, scalar);
  return scm_append (scm_list_2 (public_part, private_part));
}
/* Return the RSA public key X as a two-entry alist keyed by kn (the
   modulus) and ke (the public exponent), both encoded by wrap_mpz_t.  */
static inline SCM
wrap_rsa_public_key (struct rsa_public_key *x)
{
  SCM n_entry = scm_cons (kn, wrap_mpz_t (x->n));
  SCM e_entry = scm_cons (ke, wrap_mpz_t (x->e));
  return scm_list_2 (n_entry, e_entry);
}
/* Return the RSA private key X as a six-entry alist, in the order
   d, p, q, dp, dq, qi.  Each value is encoded by wrap_mpz_t.  Nettle
   stores the CRT values dp, dq and qi in the fields a, b and c.  */
static inline SCM
wrap_rsa_private_key (struct rsa_private_key *x)
{
  SCM d_entry = scm_cons (kd, wrap_mpz_t (x->d));
  SCM p_entry = scm_cons (kp, wrap_mpz_t (x->p));
  SCM q_entry = scm_cons (kq, wrap_mpz_t (x->q));
  SCM dp_entry = scm_cons (kdp, wrap_mpz_t (x->a));
  SCM dq_entry = scm_cons (kdq, wrap_mpz_t (x->b));
  SCM qi_entry = scm_cons (kqi, wrap_mpz_t (x->c));
  return scm_list_n (d_entry, p_entry, q_entry,
                     dp_entry, dq_entry, qi_entry, SCM_UNDEFINED);
}
/* Return an RSA key pair as one alist: the public entries (n, e)
   followed by the private ones (d, p, q, dp, dq, qi).  */
static inline SCM
wrap_rsa_key_pair (struct rsa_public_key *pub, struct rsa_private_key *key)
{
  SCM public_part = wrap_rsa_public_key (pub);
  SCM private_part = wrap_rsa_private_key (key);
  return scm_append (scm_list_2 (public_part, private_part));
}
/* Decode DATA, a Scheme string holding base64url text, into a buffer
   allocated with scm_malloc, and store the number of decoded bytes in
   *SIZE.  The caller owns the returned buffer and must free it.

   Failure handling:
   - If THROW_IF_FAIL is zero, return NULL when DATA is not a string or
     cannot be decoded.
   - If THROW_IF_FAIL is nonzero, text that cannot be decoded throws
     'base64-decoding-error with (DATA) as arguments.
   - Also with THROW_IF_FAIL nonzero, a non-string DATA reaches
     scm_to_utf8_stringn, which raises Guile's type error.
   On decoding failure the scratch buffer is freed before returning or
   throwing.

   NOTE(review): base64_decode_final is not called, presumably because
   the input is unpadded base64url, which nettle's final-state check
   would reject.  Confirm before adding it.  */
static inline uint8_t *
get_as_bytevector (SCM data, size_t *size, int throw_if_fail)
{
  uint8_t *ret = NULL;
  size_t data_length;
  char *data_str = NULL;
  struct base64_decode_ctx decoder;
  int ok;
  if (!scm_is_string (data) && !throw_if_fail)
    {
      return NULL;
    }
  base64url_decode_init (&decoder);
  scm_dynwind_begin (0);
  data_str = scm_to_utf8_stringn (data, &data_length);
  scm_dynwind_free (data_str);
  /* The extra byte keeps an empty input from requesting a zero-sized
     block; scm_malloc reports out-of-memory if the libc returns NULL
     for that.  RET is deliberately not registered with
     scm_dynwind_free: on success it goes to the caller, and nothing
     between here and scm_dynwind_end can throw.  */
  ret = scm_malloc (BASE64_DECODE_LENGTH (data_length) + 1);
  ok = base64_decode_update (&decoder, size, ret, data_length, data_str);
  scm_dynwind_end ();
  if (!ok)
    {
      free (ret);
      ret = NULL;
      if (throw_if_fail)
        {
          SCM base64_decoding_error =
            scm_from_utf8_symbol ("base64-decoding-error");
          scm_throw (base64_decoding_error, scm_list_1 (data));
        }
    }
  return ret;
}
/* Decode DATA with get_as_bytevector and import the bytes into X as an
   unsigned big-endian integer.  Return 1 on success and 0 if decoding
   produced nothing, in which case X is left untouched.  With
   THROW_IF_FAIL set, decoding errors throw instead; the decoded buffer
   is always released by the dynwind context.  */
static inline int
do_mpz_t_load (mpz_t x, SCM data, int throw_if_fail)
{
  uint8_t *bytes;
  size_t n_bytes;
  int loaded;
  scm_dynwind_begin (0);
  bytes = get_as_bytevector (data, &n_bytes, throw_if_fail);
  loaded = (bytes != NULL);
  if (loaded)
    {
      scm_dynwind_free (bytes);
      mpz_import (x, n_bytes, 1, 1, 1, 0, bytes);
    }
  scm_dynwind_end ();
  return loaded;
}
/* Map CRV to a nettle curve.  CRV must be one of the symbols held in
   p256, p384 or p521, or a string that converts to one of them.  For an
   unknown name, throw 'unsupported-crv (with the symbol or object as
   argument) when THROW_IF_FAIL is set; otherwise return NULL.  */
static inline const struct ecc_curve *
do_ecc_curve_load (SCM crv, int throw_if_fail)
{
  const struct ecc_curve *curve = NULL;
  /* Strings are accepted by interning them first.  */
  if (scm_is_string (crv))
    {
      crv = scm_string_to_symbol (crv);
    }
  if (scm_is_eq (crv, p256))
    {
      curve = nettle_get_secp_256r1 ();
    }
  else if (scm_is_eq (crv, p384))
    {
      curve = nettle_get_secp_384r1 ();
    }
  else if (scm_is_eq (crv, p521))
    {
      curve = nettle_get_secp_521r1 ();
    }
  else if (throw_if_fail)
    {
      scm_throw (scm_from_utf8_symbol ("unsupported-crv"), scm_list_1 (crv));
    }
  return curve;
}
/* Set POINT from the kx and ky entries of the alist DATA.  Return 1 on
   success, or 0 if either coordinate is missing or undecodable, or if
   ecc_point_set rejects the pair.  The temporary coordinates are
   cleared when the dynwind context ends.  */
static inline int
do_ecc_point_load (struct ecc_point *point, SCM data)
{
  mpz_t x, y;
  int ok = 0;
  scm_dynwind_begin (0);
  mpz_init (x);
  dynwind_mpz_t_clear (&x);
  mpz_init (y);
  dynwind_mpz_t_clear (&y);
  if (do_mpz_t_load (x, scm_assq_ref (data, kx), 0)
      && do_mpz_t_load (y, scm_assq_ref (data, ky), 0))
    {
      ok = (ecc_point_set (point, x, y) != 0);
    }
  scm_dynwind_end ();
  return ok;
}
/* Set SCALAR from the kd entry of the alist DATA.  Return 1 on success,
   or 0 if the entry is missing or undecodable, or if ecc_scalar_set
   rejects the value.  */
static inline int
do_ecc_scalar_load (struct ecc_scalar *scalar, SCM data)
{
  mpz_t value;
  int ok = 0;
  scm_dynwind_begin (0);
  mpz_init (value);
  dynwind_mpz_t_clear (&value);
  if (do_mpz_t_load (value, scm_assq_ref (data, kd), 0))
    {
      ok = (ecc_scalar_set (scalar, value) != 0);
    }
  scm_dynwind_end ();
  return ok;
}
/* Fill the RSA public key X from the kn and ke entries of DATA, then
   let nettle prepare it.  Return 1 on success and 0 on the first
   failure; later fields are not touched after a failure.  */
static inline int
do_rsa_public_key_load (struct rsa_public_key *x, SCM data)
{
  if (!do_mpz_t_load (x->n, scm_assq_ref (data, kn), 0))
    {
      return 0;
    }
  if (!do_mpz_t_load (x->e, scm_assq_ref (data, ke), 0))
    {
      return 0;
    }
  return rsa_public_key_prepare (x) != 0;
}
/* Fill the RSA private key X from DATA, field by field, in the order
   d, p, q, dp, dq, qi.  Nettle stores dp, dq and qi in a, b and c.
   Stop at the first missing or undecodable entry and return 0;
   otherwise return whether rsa_private_key_prepare accepts the key.  */
static inline int
do_rsa_private_key_load (struct rsa_private_key *x, SCM data)
{
  SCM keys[] = { kd, kp, kq, kdp, kdq, kqi };
  mpz_t *fields[] = { &x->d, &x->p, &x->q, &x->a, &x->b, &x->c };
  size_t i;
  for (i = 0; i < sizeof keys / sizeof keys[0]; i++)
    {
      if (!do_mpz_t_load (*fields[i], scm_assq_ref (data, keys[i]), 0))
        {
          return 0;
        }
    }
  return rsa_private_key_prepare (x) != 0;
}
static void
do_mpz_t_clear (void *ptr)
{
mpz_t *z = ptr;
mpz_clear (*z);
}
static void
do_rsa_public_key_clear (void *ptr)
{
struct rsa_public_key *pub = ptr;
rsa_public_key_clear (pub);
}
static void
do_rsa_private_key_clear (void *ptr)
{
struct rsa_private_key *key = ptr;
rsa_private_key_clear (key);
}
static void
do_ecc_point_clear (void *ptr)
{
struct ecc_point *point = ptr;
ecc_point_clear (point);
}
static void
do_ecc_scalar_clear (void *ptr)
{
struct ecc_scalar *scalar = ptr;
ecc_scalar_clear (scalar);
}
/* Registration helpers.  Each arranges for its argument to be released
   by the matching do_*_clear callback when the current dynwind context
   is left.  Because of SCM_F_WIND_EXPLICITLY, this happens on a normal
   scm_dynwind_end as well as on a non-local exit.  They must be called
   between scm_dynwind_begin and scm_dynwind_end.  */
static inline void
dynwind_mpz_t_clear (mpz_t * z)
{
  void *target = z;
  scm_dynwind_unwind_handler (do_mpz_t_clear, target, SCM_F_WIND_EXPLICITLY);
}

static inline void
dynwind_rsa_public_key_clear (struct rsa_public_key *pub)
{
  void *target = pub;
  scm_dynwind_unwind_handler (do_rsa_public_key_clear, target,
                              SCM_F_WIND_EXPLICITLY);
}

static inline void
dynwind_rsa_private_key_clear (struct rsa_private_key *key)
{
  void *target = key;
  scm_dynwind_unwind_handler (do_rsa_private_key_clear, target,
                              SCM_F_WIND_EXPLICITLY);
}

static inline void
dynwind_ecc_point_clear (struct ecc_point *point)
{
  void *target = point;
  scm_dynwind_unwind_handler (do_ecc_point_clear, target,
                              SCM_F_WIND_EXPLICITLY);
}

static inline void
dynwind_ecc_scalar_clear (struct ecc_scalar *scalar)
{
  void *target = scalar;
  scm_dynwind_unwind_handler (do_ecc_scalar_clear, target,
                              SCM_F_WIND_EXPLICITLY);
}
#endif /* not H_UTILITIES_INCLUDED */