#include "rijndael.hh"

// Constructor: intentionally empty, it performs no work in its body.
// NOTE(review): _round_keys is not written here; presumably callers must run
// rijndael_key_schedule() before rijndael_encrypt() -- confirm against the header.
rijndael::rijndael()
{
}

// Destructor: intentionally empty, nothing is released in its body.
rijndael::~rijndael()
{
}

/**
 * Expands a 16-byte cipher key into the eleven round keys kept in _round_keys.
 *
 * Each round key is a 4x4 byte matrix indexed as [row][column].
 *
 * @param p_key The 16 key bytes; byte n is stored at row (n mod 4), column (n div 4)
 * @return Always 0
 */
int rijndael::rijndael_key_schedule(const uint8_t p_key[16]) {
  // Round key 0 is the cipher key itself, filled in column by column
  for (int col = 0; col < 4; ++col) {
    for (int row = 0; row < 4; ++row) {
      _round_keys[0][row][col] = p_key[(col << 2) | row];
    }
  }

  // Derive round keys 1..10, each one from its predecessor
  uint8_t rcon = 1;
  for (int round = 1; round <= 10; ++round) {
    const int prev = round - 1;

    // Column 0: take the last column of the previous round key, rotated up by
    // one row and passed through the S table, and combine it with the
    // previous column 0
    for (int row = 0; row < 4; ++row) {
      const int rotated_row = (row + 1) & 0x03;
      _round_keys[round][row][0] = S[_round_keys[prev][rotated_row][3]] ^ _round_keys[prev][row][0];
    }
    // Only the top byte of column 0 additionally receives the round constant
    _round_keys[round][0][0] ^= rcon;

    // Columns 1..3: previous round's column XOR the column just computed to
    // its left (each entry depends only on its left neighbour in the same row)
    for (int col = 1; col < 4; ++col) {
      for (int row = 0; row < 4; ++row) {
        _round_keys[round][row][col] = _round_keys[prev][row][col] ^ _round_keys[round][row][col - 1];
      }
    }

    // Advance the round constant through the XTIME table
    rcon = XTIME[rcon];
  }

  return 0;
}

/**
 * Encrypts one 16-byte block using the round keys currently held in _round_keys.
 *
 * The input is copied into a local 4x4 state before anything is written to
 * p_output, and p_output is only written once all rounds are complete.
 *
 * @param p_input  The 16 bytes to encrypt
 * @param p_output Receives the 16 encrypted bytes
 */
void rijndael::rijndael_encrypt(const uint8_t p_input[16], uint8_t p_output[16]) {
  const int last_round = 10;

  // Load the input column by column: byte n goes to row (n mod 4), column (n div 4)
  uint8_t block[4][4];
  for (int col = 0; col < 4; ++col) {
    for (int row = 0; row < 4; ++row) {
      block[row][col] = p_input[(col << 2) | row];
    }
  }

  // Initial key addition with round key 0
  key_add(block, _round_keys, 0);

  // Rounds 1..9 apply all four transformations
  for (int round = 1; round < last_round; ++round) {
    byte_sub(block);
    shift_row(block);
    mix_column(block);
    key_add(block, _round_keys, round);
  }

  // The last round omits mix_column
  byte_sub(block);
  shift_row(block);
  key_add(block, _round_keys, last_round);

  // Store the state back using the same byte ordering as the load
  for (int col = 0; col < 4; ++col) {
    for (int row = 0; row < 4; ++row) {
      p_output[(col << 2) | row] = block[row][col];
    }
  }
}

/**
 * XORs one round key into the state, byte by byte.
 *
 * @param p_state      The 4x4 state, modified in place
 * @param p_round_keys The table of eleven 4x4 round keys
 * @param p_round      Index of the round key to apply; not range-checked here
 */
void rijndael::key_add(uint8_t p_state[4][4], const uint8_t p_round_keys[11][4][4], const int p_round) {
  // Select the 4x4 key for this round once, then combine it with the state
  const uint8_t (*round_key)[4] = p_round_keys[p_round];
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 4; ++col) {
      p_state[row][col] ^= round_key[row][col];
    }
  }
}
/**
 * Replaces every byte of the state with its entry in the S lookup table.
 *
 * @param p_state The 4x4 state, modified in place
 */
void rijndael::byte_sub(uint8_t p_state[4][4]) {
  // Visit the 16 cells in row-major order; each substitution is independent
  for (int cell = 0; cell < 16; ++cell) {
    const int row = cell >> 2;
    const int col = cell & 0x03;
    p_state[row][col] = S[p_state[row][col]];
  }
}
/**
 * Rotates row r of the state to the left by r positions (row 0 is untouched).
 *
 * @param p_state The 4x4 state, modified in place
 */
void rijndael::shift_row(uint8_t p_state[4][4]) {
  for (int row = 1; row < 4; ++row) {
    // Snapshot the row so the rotation can read the original bytes
    uint8_t saved[4];
    for (int col = 0; col < 4; ++col) {
      saved[col] = p_state[row][col];
    }

    // A left rotation by 'row' moves the byte at column (col + row) mod 4 into column col
    for (int col = 0; col < 4; ++col) {
      p_state[row][col] = saved[(col + row) & 0x03];
    }
  }
}
/**
 * Mixes the four bytes of each state column.
 *
 * Each byte is XORed with the XOR of the whole column and with the XTIME
 * lookup of (itself XOR the byte one row below, wrapping from row 3 to row 0).
 * The XTIME table performs multiplication by x in GF(2^8).
 *
 * @param p_state The 4x4 state, modified in place
 */
void rijndael::mix_column(uint8_t p_state[4][4]) {
  for (int col = 0; col < 4; ++col) {
    // Work from a copy of the column so every output uses the original bytes
    uint8_t a[4];
    for (int row = 0; row < 4; ++row) {
      a[row] = p_state[row][col];
    }
    const uint8_t column_xor = a[0] ^ a[1] ^ a[2] ^ a[3];

    for (int row = 0; row < 4; ++row) {
      const uint8_t doubled_pair = XTIME[a[row] ^ a[(row + 1) & 0x03]];
      p_state[row][col] = a[row] ^ column_xor ^ doubled_pair;
    }
  }
}
