#include <Common.h>
#include <System/SysAll.h>
#include <UI/UIAll.h>
#include "MathLib.h"
#include "EvalExpr.h"
#include "ConvStrToDbl.h"

/* defines. */
/* Maximum parser recursion depth.  Every parsing level (addsub, muldiv,
   sign, power, afunc, paren) increments the depth counter, so each level
   of parenthesis nesting consumes six of these. */
#define NESTMAX	40
/* pi; the literal carries more digits than a double holds and is rounded
   to the nearest double. */
#define PI	3.14159265358979323846

/* global variables */
/* Factor that converts the user's angle unit to radians.  sin/cos/tan
   multiply their argument by it; asin/acos/atan divide their result by it.
   PI / 180 means degrees (the initial setting); 1 means radians.
   Updated by setdegree() and setradian(). */
static double angleunit = PI / 180.;

/* prototypes: one function per grammar level, lowest precedence first.
   Each takes the in/out text cursor, the result slot and the current
   recursion depth, and returns 0 or an error code. */
static int  addsub(char** exprs, double* x, int nest);  /* binary + and -     */
static int  muldiv(char** exprs, double* x, int nest);  /* binary * and /     */
static int  sign(char** exprs, double* x, int nest);    /* unary + and -      */
static int  power(char** exprs, double* x, int nest);   /* right-assoc ^      */
static int  afunc(char** exprs, double* x, int nest);   /* sin, sqrt, log ... */
static int  paren(char** exprs, double* x, int nest);   /* ( expression )     */
static int  number(char** exprs, double* x);            /* literal, pi, $var  */
static void chopspace(char** exprs);                    /* skip blanks/tabs   */
static int  compword(char* str, char* word);            /* whole-word prefix  */


/*
 * Make the trigonometric functions work in degrees: sin/cos/tan take
 * their argument in degrees and asin/acos/atan return degrees.
 */
void setdegree(void)
{
  angleunit = PI / 180.0;   /* radians per degree */
}


/*
 * Make the trigonometric functions work in radians (no conversion).
 */
void setradian(void)
{
  angleunit = 1.0;
}


/*
 * Lowest-precedence level: a left-associative chain of terms joined by
 * binary '+' or '-'.  Each term is parsed by muldiv().
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * or the first error reported by a lower level.
 */
static int addsub(char** exprs, double* x, int nest)
{
  double rhs;
  int error;
  char op;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);
  error = muldiv(exprs, x, nest);
  if (error) return error;

  for (;;) {
    chopspace(exprs);
    op = **exprs;
    if (op != '+' && op != '-')
      break;                      /* end of the additive chain */

    (*exprs)++;
    error = muldiv(exprs, &rhs, nest);
    if (error) return error;

    if (op == '+')
      *x += rhs;
    else
      *x -= rhs;
  }
  return 0;
}


/*
 * Multiplicative level: a left-associative chain of factors joined by
 * '*' or '/'.  Each factor is parsed by sign().  Division by zero is not
 * checked here; the result is whatever the floating-point division gives.
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * or the first error reported by a lower level.
 */
static int muldiv(char** exprs, double* x, int nest)
{
  double rhs;
  int error;
  char op;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);
  error = sign(exprs, x, nest);
  if (error) return error;

  for (;;) {
    chopspace(exprs);
    op = **exprs;
    if (op != '*' && op != '/')
      break;                      /* end of the multiplicative chain */

    (*exprs)++;
    error = sign(exprs, &rhs, nest);
    if (error) return error;

    if (op == '*')
      *x *= rhs;
    else
      *x /= rhs;
  }
  return 0;
}


/*
 * Unary-sign level: an optional leading '+' or '-' applied to a power().
 * The sign binds looser than '^', so "-2^2" evaluates to -4.
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * or the error reported by power().
 */
static int sign(char** exprs, double* x, int nest)
{
  double operand;
  int error;
  char lead;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);
  lead = **exprs;
  if (lead != '+' && lead != '-')
    return power(exprs, x, nest);     /* no sign: parse straight into *x */

  (*exprs)++;
  error = power(exprs, &operand, nest);
  if (error) return error;

  *x = (lead == '-') ? -operand : operand;
  return 0;
}


/*
 * Exponent level: base [ '^' exponent ].  The exponent is parsed by a
 * recursive call to power(), which makes '^' right-associative
 * ("2^3^2" is 2^(3^2)).  The base is parsed by afunc().
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * or the first error reported by a lower level.
 */
static int power(char** exprs, double* x, int nest)
{
  double exponent;
  int error;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);
  error = afunc(exprs, x, nest);
  if (error) return error;

  chopspace(exprs);
  if (**exprs != '^')
    return 0;                         /* plain base, no exponent */

  (*exprs)++;
  error = power(exprs, &exponent, nest);
  if (error) return error;

  *x = pow(*x, exponent);
  return 0;
}


/*
 * Function-application level: an optional built-in function keyword
 * followed by its argument, or else a paren() operand.
 *
 * The argument is parsed by a recursive afunc() call, so calls chain
 * without parentheses ("sin cos 0").  Application also binds tighter than
 * '^': "sin 2^2" is (sin 2)^2.  Trigonometric functions convert through
 * angleunit (degrees or radians, see setdegree/setradian).
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * or the first error reported by a lower level.
 */
static int afunc(char** exprs, double* x, int nest)
{
  /* Built-in unary functions, in the order their keywords are tried. */
  enum {
    AF_SQRT, AF_LOG, AF_LN, AF_EXP,
    AF_ASINH, AF_ACOSH, AF_ATANH, AF_SINH, AF_COSH, AF_TANH,
    AF_ASIN, AF_ACOS, AF_ATAN, AF_SIN, AF_COS, AF_TAN,
    AF_CEIL, AF_FLOOR, AF_ABS, AF_RINT
  } fn;
  int namelen;
  int error;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);

  /* Identify the keyword at the cursor.  compword() rejects a match that
     is followed by another letter or '_', so "sin" never matches the
     front of "sinh". */
  if      (compword(*exprs, "sqrt"))  { fn = AF_SQRT;  namelen = 4; }
  else if (compword(*exprs, "log"))   { fn = AF_LOG;   namelen = 3; }
  else if (compword(*exprs, "ln"))    { fn = AF_LN;    namelen = 2; }
  else if (compword(*exprs, "exp"))   { fn = AF_EXP;   namelen = 3; }
  else if (compword(*exprs, "asinh")) { fn = AF_ASINH; namelen = 5; }
  else if (compword(*exprs, "acosh")) { fn = AF_ACOSH; namelen = 5; }
  else if (compword(*exprs, "atanh")) { fn = AF_ATANH; namelen = 5; }
  else if (compword(*exprs, "sinh"))  { fn = AF_SINH;  namelen = 4; }
  else if (compword(*exprs, "cosh"))  { fn = AF_COSH;  namelen = 4; }
  else if (compword(*exprs, "tanh"))  { fn = AF_TANH;  namelen = 4; }
  else if (compword(*exprs, "asin"))  { fn = AF_ASIN;  namelen = 4; }
  else if (compword(*exprs, "acos"))  { fn = AF_ACOS;  namelen = 4; }
  else if (compword(*exprs, "atan"))  { fn = AF_ATAN;  namelen = 4; }
  else if (compword(*exprs, "sin"))   { fn = AF_SIN;   namelen = 3; }
  else if (compword(*exprs, "cos"))   { fn = AF_COS;   namelen = 3; }
  else if (compword(*exprs, "tan"))   { fn = AF_TAN;   namelen = 3; }
  else if (compword(*exprs, "ceil"))  { fn = AF_CEIL;  namelen = 4; }
  else if (compword(*exprs, "floor")) { fn = AF_FLOOR; namelen = 5; }
  else if (compword(*exprs, "abs"))   { fn = AF_ABS;   namelen = 3; }
  else if (compword(*exprs, "rint"))  { fn = AF_RINT;  namelen = 4; }
  else {
    /* No function keyword: a parenthesised group or a plain operand. */
    return paren(exprs, x, nest);
  }

  /* Step over the keyword and evaluate the argument into *x. */
  *exprs += namelen;
  error = afunc(exprs, x, nest);
  if (error) return error;

  switch (fn) {
  case AF_SQRT:  *x = sqrt(*x);              break;
  case AF_LOG:   *x = log10(*x);             break;
  case AF_LN:    *x = log(*x);               break;
  case AF_EXP:   *x = exp(*x);               break;
  case AF_ASINH: *x = asinh(*x);             break;
  case AF_ACOSH: *x = acosh(*x);             break;
  case AF_ATANH: *x = atanh(*x);             break;
  case AF_SINH:  *x = sinh(*x);              break;
  case AF_COSH:  *x = cosh(*x);              break;
  case AF_TANH:  *x = tanh(*x);              break;
  case AF_ASIN:  *x = asin(*x) / angleunit;  break;   /* radians -> user unit */
  case AF_ACOS:  *x = acos(*x) / angleunit;  break;
  case AF_ATAN:  *x = atan(*x) / angleunit;  break;
  case AF_SIN:   *x = sin(*x * angleunit);   break;   /* user unit -> radians */
  case AF_COS:   *x = cos(*x * angleunit);   break;
  case AF_TAN:   *x = tan(*x * angleunit);   break;
  case AF_CEIL:  *x = ceil(*x);              break;
  case AF_FLOOR: *x = floor(*x);             break;
  case AF_ABS:   *x = fabs(*x);              break;
  case AF_RINT:  *x = rint(*x);              break;
  default:                                   break;
  }
  return 0;
}


/*
 * Primary level: either '(' expression ')' or a number() operand.
 * The bracketed expression restarts at the lowest precedence (addsub).
 *
 * Returns 0 on success, ERR_DEEP_NESTING when the recursion limit is hit,
 * ERR_PAREN_NOT_MATCH when the closing ')' is missing, or the error
 * reported by addsub()/number().
 */
static int paren(char** exprs, double* x, int nest)
{
  int error;

  if (++nest >= NESTMAX) return ERR_DEEP_NESTING;

  chopspace(exprs);
  if (**exprs != '(')
    return number(exprs, x);          /* not a group: plain operand */

  (*exprs)++;
  error = addsub(exprs, x, nest);
  if (error) return error;

  chopspace(exprs);
  if (**exprs != ')') return ERR_PAREN_NOT_MATCH;
  (*exprs)++;
  return 0;
}


/*
 * Parse one leaf operand at *exprs: a numeric literal, the constant
 * "pi"/"PI", or a variable reference beginning with '$'.
 *
 * exprs: in/out text cursor.  On success it is advanced past the operand.
 *        When a literal was partly consumed but conversion reported an
 *        error, it is left where ConvStrToDbl stopped.
 * x:     receives the operand's value.
 *
 * Returns 0 on success.  Otherwise returns one of:
 *   ERR_NULL_FOUND    - end of text where an operand was expected
 *   ERR_VAR_UNDEFINED - GetVarBySymbol failed
 *   ERR_NOT_A_NUMBER  - the text is not a valid operand
 *
 * Declared static to match its prototype (it is file-local).
 */
static int number(char** exprs, double* x)
{
  int error;
  char* endptr;

  chopspace(exprs);
  /* End of input: compare with the NUL character, not the NULL pointer
     constant (which may be defined as ((void*)0)). */
  if ((*exprs)[0] == '\0') return ERR_NULL_FOUND;

  /* Defensive default so endptr is never read uninitialised, in case
     ConvStrToDbl leaves it untouched on some failure path. */
  endptr = *exprs;
  error = ConvStrToDbl(x, *exprs, &endptr);

  if (endptr == *exprs) {
    /* Nothing consumed: not a numeric literal. */
    if (compword(*exprs, "pi") || compword(*exprs, "PI")) {
      /* constant PI */
      (*exprs) += 2;
      *x = PI;
      return 0;
    }
    if (**exprs == '$') {
      /* referring to a variable; GetVarBySymbol receives the cursor and
         presumably advances it past the symbol -- TODO confirm */
      error = GetVarBySymbol(exprs, x);
      if (error != 0) return ERR_VAR_UNDEFINED;
      return 0;
    }
    return ERR_NOT_A_NUMBER;
  }

  /* Something was consumed: move the cursor past it, whether or not the
     conversion itself succeeded. */
  *exprs = endptr;
  return (error != 0) ? ERR_NOT_A_NUMBER : 0;
}


/*
 * Public entry point: evaluate the whole expression text at *exprs and
 * store its value in *result.
 *
 * Returns 0 on success, or the first error code raised while parsing.
 * On error, *exprs is left wherever parsing stopped.  Anything other than
 * spaces/tabs after a complete expression yields ERR_EXCESS_CHARACTERS.
 */
int evaluate(char** exprs, double* result)
{
  int status;

  status = addsub(exprs, result, 0);
  if (status == 0) {
    chopspace(exprs);
    if (**exprs != '\0')
      status = ERR_EXCESS_CHARACTERS;
  }
  return status;
}


/*
 * Advance the text cursor *exprs past any run of spaces and tabs.
 * Other whitespace (e.g. newlines) is not skipped.
 */
static void chopspace(char** exprs)
{
  char* cursor = *exprs;

  while (*cursor == ' ' || *cursor == '\t')
    cursor++;
  *exprs = cursor;
}

/*
 * Whole-word prefix test: returns true when str begins with word and the
 * character right after it is not a letter (a-z, A-Z) or '_'.  Digits and
 * other characters may follow, so "sin30" still matches "sin".
 * Returns false otherwise.
 */
static int compword(char* str, char* word)
{
  int len = StrLen(word);
  char follow;

  if (StrNCompare(str, word, len) != 0)
    return false;

  /* str matched len characters, so str[len] exists (possibly the NUL). */
  follow = str[len];
  if (follow == '_' ||
      (follow >= 'a' && follow <= 'z') ||
      (follow >= 'A' && follow <= 'Z'))
    return false;

  return true;
}
