import {
  crypto_core_ristretto255_is_valid_point,
  crypto_core_ristretto255_scalar_random,
  crypto_scalarmult_ristretto255,
  crypto_core_ristretto255_scalar_invert,
  crypto_core_ristretto255_from_hash,
} from 'libsodium-wrappers-sumo';

/**
 * Provides essential functions for executing an
 * [OPRF](https://en.wikipedia.org/wiki/Pseudorandom_function_family#Oblivious_pseudorandom_functions).
 *
 * OPRF allows secure data (usually passwords) to be salted on the server,
 * without the server ever receiving the data in a readable form.
 *
 * This implementation uses Ristretto to avoid issues with Curve25519 cofactor
 * attacks.
 *
 * NOTE: This OPRF implementation uses the DH-OPRF operations found in
 * [OPAQUE](https://datatracker.ietf.org/doc/draft-krawczyk-cfrg-opaque/?include_text=1).
 */
export namespace OPRF {
  /** Byte length of an encoded Ristretto255 scalar (libsodium's `crypto_core_ristretto255_SCALARBYTES`). */
  const SCALAR_BYTES = 32;

  /**
   * Result of {@link mask}: the random blinding scalar together with the
   * blinded point.
   */
  export interface Alpha {
    /** Random blinding scalar; keep it private and pass it to {@link unmask}. */
    mask: Uint8Array;
    /** `mask * input`, the blinded point that is sent to the other party. */
    point: Uint8Array;
  }

  /**
   * Thrown when an argument is not a valid encoded Ristretto255 point.
   *
   * It is also thrown when a scalar argument (a salt or a mask) has the wrong
   * length, so that existing callers catching this class keep working.
   */
  export class InvalidPointError extends Error {
    constructor(message?: string) {
      super(message);
      // Without this, stack traces and `String(err)` report a generic "Error".
      this.name = 'InvalidPointError';
    }
  }

  /** Throws {@link InvalidPointError} unless `point` is a valid encoded Ristretto255 point. */
  const assertPoint = (point: Uint8Array, message: string): void => {
    if (!crypto_core_ristretto255_is_valid_point(point)) {
      throw new InvalidPointError(message);
    }
  };

  /**
   * Throws {@link InvalidPointError} unless `scalar` has the byte length of a
   * Ristretto255 scalar. Scalars are not point encodings, so they must not be
   * checked with `crypto_core_ristretto255_is_valid_point`.
   */
  const assertScalar = (scalar: Uint8Array, message: string): void => {
    if (scalar.length !== SCALAR_BYTES) {
      throw new InvalidPointError(message);
    }
  };

  /**
   * Mask the input so it can be transmitted to another party securely.
   *
   * @param input The sensitive data that needs to be sent, in Ristretto point
   * form (see {@link makePoint}).
   *
   * @returns     The random mask used for computing `alpha`, as well as the
   *              masked input (`mask * input`).
   *
   * @throws InvalidPointError if `input` is not a valid Ristretto255 point.
   */
  export const mask = (input: Uint8Array): Alpha => {
    assertPoint(input, 'input is invalid');

    const computedMask = crypto_core_ristretto255_scalar_random();
    const alpha = crypto_scalarmult_ristretto255(computedMask, input);

    return {
      mask: computedMask,
      point: alpha,
    };
  };

  /**
   * Map a hash onto a Ristretto255 point so it can be passed to {@link mask}.
   *
   * @param input A hash of the secret data, not the raw plaintext. libsodium
   *              requires `crypto_core_ristretto255_HASHBYTES` (64) bytes,
   *              e.g. a SHA-512 digest.
   *
   * @returns     The resulting Ristretto point.
   */
  export const makePoint = (input: Uint8Array): Uint8Array =>
    crypto_core_ristretto255_from_hash(input);

  /**
   * Generate a new salt (the OPRF key scalar) that can be used for OPRF inputs.
   *
   * @returns A random Ristretto255 scalar. Store it so the calculation can be
   *          repeated.
   */
  export const salt = (): Uint8Array => crypto_core_ristretto255_scalar_random();

  /**
   * Apply a salt to a masked input.
   *
   * @param k     A pre-computed salt (a scalar from {@link salt}).
   * @param alpha Masked input from a client program (a point).
   *
   * @returns     `beta`, the salted masked input (`k * alpha`).
   *
   * @throws InvalidPointError if `k` is not scalar-sized or `alpha` is not a
   *         valid point.
   */
  // This is not a const so it can be easily mocked in tests
  // eslint-disable-next-line prefer-const
  export let mult = (k: Uint8Array, alpha: Uint8Array): Uint8Array => {
    // `k` is a scalar, not a point encoding, so it is only length-checked.
    // Validating it with `is_valid_point` would reject many valid salts.
    assertScalar(k, 'K is invalid');
    assertPoint(alpha, 'alpha is invalid');

    return crypto_scalarmult_ristretto255(k, alpha);
  };

  /**
   * Remove a mask in order to reveal the properly salted input.
   *
   * @param beta         Salted, masked output from the server (a point).
   * @param computedMask The original mask (scalar) used to compute `alpha`.
   *
   * @returns            A securely-salted input (`mask^-1 * beta`).
   *
   * @throws InvalidPointError if `beta` is not a valid point or
   *         `computedMask` is not scalar-sized.
   */
  export const unmask = (beta: Uint8Array, computedMask: Uint8Array): Uint8Array => {
    assertPoint(beta, 'beta is invalid');
    // The mask is the scalar returned by `mask`, not a point encoding.
    assertScalar(computedMask, 'mask is invalid');

    const inverse = crypto_core_ristretto255_scalar_invert(computedMask);
    return crypto_scalarmult_ristretto255(inverse, beta);
  };
}
