#!/usr/bin/env python3
'''Take values from a config file and fill them into a set of templates.
Write the result to the current directory.'''

import os, sys, re, shutil
import argparse

def file_newer(path_a, than_path_b):
  '''Return True if path_a has a later modification time than than_path_b.

  Both paths must exist; os.path.getmtime() raises OSError otherwise.
  '''
  mtime_a = os.path.getmtime(path_a)
  mtime_b = os.path.getmtime(than_path_b)
  return mtime_a > mtime_b

# Fallback for --original-config: the config_2g3g file located next to this
# script (symlinks resolved).
DEFAULT_ORIG_CONFIG = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'config_2g3g')
# Memo files in the current directory. They remember the sources used on the
# previous run, so those arguments may be omitted on the next invocation.
LAST_LOCAL_CONFIG_FILE = '.last_config'
LAST_ORIG_CONFIG_FILE = '.last_config_orig'
LAST_TMPL_DIR = '.last_templates'

# Command line:
#   SRC...  a template directory and/or a config file (told apart by type)
#   -s      only compare modification times, write nothing
#   -o      config file that supplies values missing from the local config
parser = argparse.ArgumentParser(
  description=__doc__,
  formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
  'sources',
  metavar='SRC',
  nargs='*',
  help='Pass both a template directory and a config file.',
)
parser.add_argument(
  '-s', '--check-stale',
  action='store_true',
  dest='check_stale',
  help=('only verify age of generated files vs. config and templates.'
        ' Exit nonzero when any source file is newer. Do not write anything.'),
)
parser.add_argument(
  '-o', '--original-config',
  help=('get missing variables from this file, default is config_2g3g'
        ' or the file used previously to fill an existing template dir'),
)

args = parser.parse_args()

def read_last_value(path):
  '''Return the whitespace-stripped contents of one of the .last_* memo files.'''
  with open(path) as last_file:
    return last_file.read().strip()

# Sort the positional sources into (at most) one template dir and one config
# file.
local_config_file = None
orig_config_file = args.original_config
tmpl_dir = None

for src in args.sources:
  if os.path.isdir(src):
    if tmpl_dir is not None:
      print('Error: only one template dir permitted. (%r vs. %r)' % (tmpl_dir, src))
      exit(1)
    tmpl_dir = src
  elif os.path.isfile(src):
    if local_config_file is not None:
      print('Error: only one config file permitted. (%r vs. %r)' % (local_config_file, src))
      exit(1)
    local_config_file = src
  else:
    # Without this, a mistyped path would silently fall back to the config
    # or template dir remembered from the previous run.
    print('Error: no such file or directory: %r' % src)
    exit(1)

# Anything not given on the command line is taken from the previous run.
if local_config_file is None and os.path.isfile(LAST_LOCAL_CONFIG_FILE):
  local_config_file = read_last_value(LAST_LOCAL_CONFIG_FILE)

if orig_config_file is None:
  if os.path.isfile(LAST_ORIG_CONFIG_FILE):
    orig_config_file = read_last_value(LAST_ORIG_CONFIG_FILE)
  else:
    orig_config_file = DEFAULT_ORIG_CONFIG

if tmpl_dir is None and os.path.isfile(LAST_TMPL_DIR):
  tmpl_dir = read_last_value(LAST_TMPL_DIR)

if not tmpl_dir or not os.path.isdir(tmpl_dir):
  print("Template dir does not exist: %r" % tmpl_dir)
  exit(1)

if not local_config_file or not os.path.isfile(local_config_file):
  print("No such config file: %r" % local_config_file)
  exit(1)

if not os.path.isfile(orig_config_file):
  print("No such config file: %r" % orig_config_file)
  exit(1)

# Use absolute paths for all sources, so the remembered values are unambiguous.
local_config_file = os.path.realpath(local_config_file)
orig_config_file = os.path.realpath(orig_config_file)
tmpl_dir = os.path.realpath(tmpl_dir)
net_dir = os.path.realpath(".")

print(f'using config file: {local_config_file}')
print(f'with original:     {orig_config_file}')
print(f'on templates:      {tmpl_dir}')
print(f'with NET_DIR:      {net_dir}')

# Remember the sources for the next run. --check-stale promises not to write
# anything, so the memo files are left untouched in that mode.
if not args.check_stale:
  for last_path, last_value in ((LAST_LOCAL_CONFIG_FILE, local_config_file),
                                (LAST_ORIG_CONFIG_FILE, orig_config_file),
                                (LAST_TMPL_DIR, tmpl_dir)):
    with open(last_path, 'w') as last_file:
      last_file.write(last_value)

# read in variable values from config files
# NET_DIR is the folder where fill_config.py was started
local_config = {"NET_DIR": net_dir}

# The original config is read first, so values from the local config override
# it. Each line is NAME=VALUE; one pair of surrounding double quotes is removed
# from VALUE. Lines starting with '#' and empty lines are ignored; any other
# line without '=' is a fatal error.
for config_file in [orig_config_file, local_config_file]:
  # Names already assigned in this file. NET_DIR is pre-seeded, so a config
  # file that assigns NET_DIR is reported as a duplicate too. Duplicates are
  # reported but not fatal: the later value wins.
  current_config_identifiers = {"NET_DIR"}
  with open(config_file) as config_lines:
    for line_nr, line in enumerate(config_lines, 1):
      line = line.strip('\n')

      if line.startswith('#'):
        continue

      # split at the first '=' only; the value may contain more of them
      name, sep, val = line.partition('=')
      if not sep:
        if line:
          print("Error: %r line %d: %r" % (config_file, line_nr, line))
          exit(1)
        continue

      if val.startswith('"') and val.endswith('"'):
        val = val[1:-1]

      if name in current_config_identifiers:
        print("Error: duplicate identifier in %r line %d: %r" % (config_file, line_nr, line))
      local_config[name] = val
      current_config_identifiers.add(name)

# Template syntax, applied repeatedly by the generation loop until the text
# stops changing (so variable values may themselves contain references):
#   ${NAME}      variable; NAME starts with an uppercase letter or '_'
#   ${cmd(arg)}  command; cmd starts with a lowercase letter, arg has no ')'
# Raw strings: '\$' in a plain literal is an invalid escape sequence and
# triggers a SyntaxWarning on Python 3.12+.
replace_re = re.compile(r'\$\{([A-Z_][A-Za-z0-9_]*)\}')
command_re = re.compile(r'\$\{([a-z][A-Za-z0-9_]*)\(([^)]*)\)\}')

# Running count of generated files, exposed to templates as ${_idx0}/${_idx1}.
idx = 0

def check_stale(src_path, target_path):
  '''Exit with status 1 if target_path is out of date relative to src_path.

  Used in --check-stale mode. A target that does not exist (not generated yet)
  is reported as stale too, instead of crashing in os.path.getmtime().
  '''
  if not os.path.exists(target_path):
    print()
    print('Stale: %r does not exist' % target_path)
    exit(1)
  if file_newer(src_path, target_path):
    print()
    print('Stale: %r is newer than %r' % (src_path, target_path))
    exit(1)

def replace_vars(tmpl, tmpl_dir, tmpl_src, local_config, strict=True):
  '''Substitute each ${NAME} reference in tmpl with local_config[NAME].

  tmpl_dir is accepted for signature symmetry with the other helpers and is
  not used. tmpl_src only names the template in error messages.

  With strict=True (the default), the first reference to a name missing
  from local_config prints an error and exits with status 1. With
  strict=False, such references are left untouched.

  Substitution is a single pass over the names referenced in tmpl; values
  that themselves contain ${...} may need another call.
  '''
  referenced = set()
  for name in (found.group(1) for found in replace_re.finditer(tmpl)):
    if name in local_config:
      referenced.add(name)
    elif strict:
      print('Error: undefined var %r in %r' % (name, tmpl_src))
      exit(1)

  for name in referenced:
    tmpl = tmpl.replace('${%s}' % name, local_config.get(name))
  return tmpl

def insert_includes(tmpl, tmpl_dir, tmpl_src, local_config, arg):
  '''Replace every ${include(arg)} in tmpl with the contents of tmpl_dir/arg.

  The included text is first run through handle_commands(), with the included
  file's own directory as the base for nested includes. Exits with status 1
  when the file does not exist; read errors are reported and re-raised.

  In --check-stale mode, the included file is also compared against the
  module-level `dst`, which the main loop sets to the output file it is
  currently generating.

  Returns the new template text.
  '''
  include_path = os.path.join(tmpl_dir, arg)
  if not os.path.isfile(include_path):
    print('Error: included file does not exist: %r in %r' % (include_path, tmpl_src))
    exit(1)
  try:
    with open(include_path) as incl_file:
      incl = incl_file.read()
  except (OSError, UnicodeDecodeError):
    print('Cannot read %r for %r' % (include_path, tmpl_src))
    raise
  if args.check_stale:
    check_stale(include_path, dst)

  # recurse, to follow the paths that the included bits come from
  incl = handle_commands(incl, os.path.dirname(include_path), include_path, local_config)

  return tmpl.replace('${include(%s)}' % arg, incl)

def insert_foreach(tmpl, tmpl_dir, tmpl_src, match, local_config, arg):
  '''Expand the ${foreach(arg)} ... ${foreach_end} section that starts at match.

  local_config is scanned for keys named <arg><number>_<anything>. The section
  body is emitted once per distinct (number, prefix), in ascending order.
  In each copy:
    - ${<arg>n_ is rewritten to ${<arg><number>_;
    - ${<arg>n} is rewritten to the number itself;
    - variables are substituted (strictly) until the copy stops changing.
  A newline directly after the opening marker, and one directly after the
  first following ${foreach_end}, is consumed along with the marker.

  Returns the template with the whole section replaced by its expansion.
  Raises Exception when no ${foreach_end} follows the opening marker.
  '''
  start, body_start = match.span()

  # swallow the newline that ends the ${foreach(...)} line, if any
  # (startswith() also copes with the marker being at the very end of tmpl)
  if tmpl.startswith('\n', body_start):
    body_start += 1

  # The section ends at the FIRST ${foreach_end}, with or without a newline.
  end_marker = '${foreach_end}'
  end_at = tmpl.find(end_marker, body_start)
  if end_at < 0:
    raise Exception('%r: ${foreach()} expects %r in %r' % (tmpl_src, end_marker, tmpl[body_start:]))
  after_start = end_at + len(end_marker)
  if tmpl.startswith('\n', after_start):
    after_start += 1

  before_block = tmpl[:start]
  foreach_block = tmpl[body_start:end_at]
  after_block = tmpl[after_start:]

  # figure out what items matching the foreach(FOO<number>) there are
  item_re = re.compile('(^%s([0-9]+))_.*' % re.escape(arg))
  items = set()
  for key in local_config:
    item_m = item_re.match(key)
    if item_m:
      items.add((int(item_m.group(2)), item_m.group(1)))

  expanded = [before_block]
  for nr, item in sorted(items):
    expanded_block = foreach_block
    while True:
      expanded_block_was = expanded_block
      expanded_block = expanded_block.replace('${%sn_' % arg, '${%s_' % item)
      expanded_block = expanded_block.replace('${%sn}' % arg, str(nr))
      expanded_block = replace_vars(expanded_block, tmpl_dir, tmpl_src, local_config)
      if expanded_block_was == expanded_block:
        break
    expanded.append(expanded_block)

  expanded.append(after_block)
  return ''.join(expanded)

def handle_commands(tmpl, tmpl_dir, tmpl_src, local_config):
  '''Expand the ${include(...)} and ${foreach(...)} commands found in tmpl.

  tmpl_dir is the directory include paths are resolved against; tmpl_src
  names the template in error messages.

  ${strftime(...)} is recognized but left in place. Any other command prints
  an error once per distinct occurrence text and is also left in place.
  Left-over commands do not stop the scan, so commands that follow them are
  still expanded.

  Returns the expanded template text.
  '''
  reported = set()
  pos = 0
  while True:
    # Search again from the start after each expansion: the text (and with
    # it every offset) has changed, and a foreach body may contain commands.
    m = command_re.search(tmpl, pos)
    if not m:
      break
    cmd = m.group(1)
    arg = m.group(2)
    if cmd == 'include':
      tmpl = insert_includes(tmpl, tmpl_dir, tmpl_src, local_config, arg)
      pos = 0
    elif cmd == 'foreach':
      tmpl = insert_foreach(tmpl, tmpl_dir, tmpl_src, m, local_config, arg)
      pos = 0
    else:
      if cmd != 'strftime' and m.group(0) not in reported:
        print('Error: unknown command: %r in %r' % (cmd, tmpl_src))
        reported.add(m.group(0))
      # leave this command untouched and keep scanning after it
      pos = m.end()

  return tmpl

# Matches ${FOOn} and ${FOOn_ references inside config values (see below);
# compiled once instead of once per config key.
foo_n_re = re.compile(r'\$\{([A-Za-z0-9_]*)n[_}]')

# Generate one output file in the current directory per template file.
for tmpl_name in sorted(os.listdir(tmpl_dir)):

  # omit "hidden" files
  if tmpl_name.startswith('.'):
    continue

  # omit files to be included by other files
  if tmpl_name.startswith('common_'):
    continue

  tmpl_src = os.path.join(tmpl_dir, tmpl_name)
  # NOTE: dst must stay a module-level name: insert_includes() reads it for
  # its --check-stale comparison.
  dst = tmpl_name

  # subdirectories: must not contain config files, just copy them.
  # --check-stale promises not to write anything, so they are skipped then.
  if os.path.isdir(tmpl_src):
    if not args.check_stale:
      if os.path.isdir(dst):
        shutil.rmtree(dst)
      shutil.copytree(tmpl_src, dst, symlinks=True)
    continue

  if args.check_stale:
    check_stale(local_config_file, dst)
    check_stale(orig_config_file, dst)
    check_stale(tmpl_src, dst)

  # per-file variables available to every template
  local_config['_fname'] = tmpl_name
  local_config['_name'] = os.path.splitext(tmpl_name)[0]
  local_config['_idx0'] = str(idx)
  idx += 1
  local_config['_idx1'] = str(idx)

  # If there are ${FOOn} in the value of a variable called FOO23_SOMETHING,
  # then replace that n by 23. This happens automatically in ${foreach} blocks,
  # but doing this also allows expanding the n outside of ${foreach}.
  # (Only existing keys are reassigned, so updating while iterating is safe.)
  for key, val in local_config.items():
    for m in foo_n_re.finditer(val):
      name = m.group(1)
      item_re = re.compile('^%s([0-9]+)_.*' % name)
      item_m = item_re.match(key)
      if not item_m:
        continue
      nr_in_key = item_m.group(1)
      val = val.replace('${%sn}' % name, nr_in_key)
      val = val.replace('${%sn_' % name, '${%s%s_' % (name, nr_in_key))
    local_config[key] = val

  try:
    with open(tmpl_src) as tmpl_file:
      result = tmpl_file.read()
  except (OSError, UnicodeDecodeError):
    print('Error in %r' % tmpl_src)
    raise

  # expand commands and variables until nothing changes any more
  while True:
    result_was = result
    result = handle_commands(result, tmpl_dir, tmpl_src, local_config)
    result = replace_vars(result, tmpl_dir, tmpl_src, local_config)
    if result_was == result:
      break

  if not args.check_stale:
    with open(dst, 'w') as dst_file:
      dst_file.write(result)
    shutil.copymode(tmpl_src, dst)

# vim: ts=2 sw=2 expandtab
