package org.ddahl.rscala

import java.net._
import java.io._
import scala.language.dynamics

import Protocol._

/** Client side of the rscala bridge: issues requests to an R server over a pair of data
  * streams and decodes the replies.
  *
  * Every request starts with a command code from `Protocol` (EVAL, GET, SET, ...). Instances
  * are obtained from the companion object. Because the class extends `Dynamic`, `client.x`
  * is shorthand for `get("x")` and `client.x = v` is shorthand for `set("x", v)`.
  *
  * @param in       stream from which replies sent by R are read
  * @param out      stream to which requests for R are written
  * @param debugger sink for debugging messages; also holds the current debug flag
  */
class RClient private (private val in: DataInputStream, private val out: DataOutputStream, private val debugger: Debugger) extends Dynamic {

  /** R variable from which the value of the most recently evaluated snippet is fetched. */
  private val LastValue = ".rscala.last.value"

  /** Whether debugging messages are enabled. */
  def debug: Boolean = debugger.debug

  /** Enables or disables debugging both locally and in R (by sending a DEBUG request).
    * Nothing is sent when `v` equals the current setting.
    */
  def debug_=(v: Boolean): Unit = {
    if ( v != debug ) {
      if ( debug ) debugger.msg("Sending DEBUG request.")
      out.writeInt(DEBUG)
      out.writeInt(if ( v ) 1 else 0)
      out.flush()
      debugger.debug = v
    }
  }

  /** Sends an EXIT request to R. No reply is read. */
  def exit(): Unit = {
    if ( debug ) debugger.msg("Sending EXIT request.")
    out.writeInt(EXIT)
    out.flush()
  }

  /** Evaluates `snippet` in R.
    *
    * The output string R sends back is printed to standard output when it is non-empty.
    *
    * @param snippet  R code to evaluate
    * @param evalOnly when `true` (the default) return `null`; otherwise fetch the snippet's
    *                 value with [[get]] and return it
    * @return `null`, or the value of the snippet when `evalOnly` is `false`
    * @throws RuntimeException if R reports that evaluation failed
    */
  def eval(snippet: String, evalOnly: Boolean = true): Any = {
    if ( debug ) debugger.msg("Sending EVAL request.")
    out.writeInt(EVAL)
    Helper.writeString(out,snippet)
    out.flush()
    val status = in.readInt()
    val output = Helper.readString(in)
    if ( output != "" ) println(output)
    if ( status != OK ) throw new RuntimeException("Error in R evaluation.")
    if ( evalOnly ) null else get(LastValue)._1
  }

  // Typed evaluation: evaluate the snippet, then fetch its value converted exactly as the
  // matching getXY method does (suffix 0 = scalar, 1 = vector, 2 = matrix).
  def evalI0(snippet: String): Int = { eval(snippet,true); getI0(LastValue) }
  def evalD0(snippet: String): Double = { eval(snippet,true); getD0(LastValue) }
  def evalB0(snippet: String): Boolean = { eval(snippet,true); getB0(LastValue) }
  def evalS0(snippet: String): String = { eval(snippet,true); getS0(LastValue) }
  def evalI1(snippet: String): Array[Int] = { eval(snippet,true); getI1(LastValue) }
  def evalD1(snippet: String): Array[Double] = { eval(snippet,true); getD1(LastValue) }
  def evalB1(snippet: String): Array[Boolean] = { eval(snippet,true); getB1(LastValue) }
  def evalS1(snippet: String): Array[String] = { eval(snippet,true); getS1(LastValue) }
  def evalI2(snippet: String): Array[Array[Int]] = { eval(snippet,true); getI2(LastValue) }
  def evalD2(snippet: String): Array[Array[Double]] = { eval(snippet,true); getD2(LastValue) }
  def evalB2(snippet: String): Array[Array[Boolean]] = { eval(snippet,true); getB2(LastValue) }
  def evalS2(snippet: String): Array[Array[String]] = { eval(snippet,true); getS2(LastValue) }

  /** `client.x` reads the R variable `x`; see [[get]]. */
  def selectDynamic(identifier: String): (Any,String) = get(identifier)

  /** `client.x = v` assigns the R variable `x`; see [[set]]. */
  def updateDynamic(identifier: String)(value : Any): Unit = set(identifier,value)

  /** Assigns `value` to the R variable `identifier`, optionally through an index.
    *
    * Supported values: `null` or `Unit` (written as NULLTYPE), `Int`, `Double`, `Boolean`,
    * `String`, one-dimensional arrays of those, and rectangular two-dimensional arrays of
    * those. A ragged two-dimensional array is written as UNSUPPORTED_STRUCTURE.
    *
    * @param identifier     name of the R variable
    * @param value          value to assign
    * @param index          if non-empty, the index expression to assign through
    * @param singleBrackets when indexing, send SET_SINGLE (presumably `[`) rather than
    *                       SET_DOUBLE (presumably `[[`)
    * @throws RuntimeException if `value`'s type is unsupported (nothing has been written in
    *                          that case), or if R reports an error for an indexed assignment
    */
  def set(identifier: String, value: Any, index: String = "", singleBrackets: Boolean = true): Unit = {
    if ( debug ) debugger.msg("Setting: "+identifier)
    // Decide how to encode the value *before* writing anything. An unsupported type then
    // throws without leaving a partially written request in the stream; such a request would
    // desynchronize every later exchange with R.
    val writeValue = valueWriter(value)
    if ( index == "" ) out.writeInt(SET)
    else {
      out.writeInt(if ( singleBrackets ) SET_SINGLE else SET_DOUBLE)
      Helper.writeString(out,index)
    }
    Helper.writeString(out,identifier)
    writeValue()
    out.flush()
    // Only indexed assignments are acknowledged with a status code.
    if ( index != "" ) checkStatus()
  }

  /** Returns a thunk that writes the type tag and payload of `value` to `out`.
    *
    * @throws RuntimeException if `value` has no supported encoding
    */
  private def valueWriter(value: Any): () => Unit = {
    if ( value == null || value.isInstanceOf[Unit] ) {
      if ( debug ) debugger.msg("... which is null")
      () => out.writeInt(NULLTYPE)
    } else {
      val c = value.getClass
      if ( debug ) debugger.msg("... whose class is: "+c)
      if ( debug ) debugger.msg("... and whose value is: "+value)
      value match {
        case vv: Array[Int] => () => {
          writeVectorHeader(vv.length,INTEGER)
          vv.foreach(x => out.writeInt(x))
        }
        case vv: Array[Double] => () => {
          writeVectorHeader(vv.length,DOUBLE)
          vv.foreach(x => out.writeDouble(x))
        }
        case vv: Array[Boolean] => () => {
          writeVectorHeader(vv.length,BOOLEAN)
          vv.foreach(x => out.writeInt(if ( x ) 1 else 0))
        }
        case vv: Array[String] => () => {
          writeVectorHeader(vv.length,STRING)
          vv.foreach(x => Helper.writeString(out,x))
        }
        case vv: Array[Array[Int]] => () => {
          // Ragged input must still produce a type tag, or R is left waiting for one.
          if ( Helper.isMatrix(vv) ) {
            writeMatrixHeader(vv.length,ncolOf(vv),INTEGER)
            for ( row <- vv; x <- row ) out.writeInt(x)
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        }
        case vv: Array[Array[Double]] => () => {
          if ( Helper.isMatrix(vv) ) {
            writeMatrixHeader(vv.length,ncolOf(vv),DOUBLE)
            for ( row <- vv; x <- row ) out.writeDouble(x)
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        }
        case vv: Array[Array[Boolean]] => () => {
          if ( Helper.isMatrix(vv) ) {
            writeMatrixHeader(vv.length,ncolOf(vv),BOOLEAN)
            for ( row <- vv; x <- row ) out.writeInt(if ( x ) 1 else 0)
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        }
        case vv: Array[Array[String]] => () => {
          if ( Helper.isMatrix(vv) ) {
            writeMatrixHeader(vv.length,ncolOf(vv),STRING)
            for ( row <- vv; x <- row ) Helper.writeString(out,x)
          } else out.writeInt(UNSUPPORTED_STRUCTURE)
        }
        case x: Int => () => {
          out.writeInt(ATOMIC)
          out.writeInt(INTEGER)
          out.writeInt(x)
        }
        case x: Double => () => {
          out.writeInt(ATOMIC)
          out.writeInt(DOUBLE)
          out.writeDouble(x)
        }
        case x: Boolean => () => {
          out.writeInt(ATOMIC)
          out.writeInt(BOOLEAN)
          out.writeInt(if ( x ) 1 else 0)
        }
        case x: String => () => {
          out.writeInt(ATOMIC)
          out.writeInt(STRING)
          Helper.writeString(out,x)
        }
        case _ =>
          if ( c.isArray ) throw new RuntimeException("Unsupported array type: "+c.getName)
          else throw new RuntimeException("Unsupported non-array type: "+c.getName)
      }
    }
  }

  /** Writes the VECTOR tag, the element count and the element type code. */
  private def writeVectorHeader(length: Int, typeCode: Int): Unit = {
    out.writeInt(VECTOR)
    out.writeInt(length)
    out.writeInt(typeCode)
  }

  /** Writes the MATRIX tag, the dimensions and the element type code. */
  private def writeMatrixHeader(nrow: Int, ncol: Int, typeCode: Int): Unit = {
    out.writeInt(MATRIX)
    out.writeInt(nrow)
    out.writeInt(ncol)
    out.writeInt(typeCode)
  }

  /** Column count of a rectangular matrix: the first row's length, or 0 when there are no rows. */
  private def ncolOf[A](rows: Array[Array[A]]): Int = if ( rows.length > 0 ) rows(0).length else 0

  /** Reads a status code. If it is not OK, reads and prints R's output string, then throws. */
  private def checkStatus(): Unit = {
    val status = in.readInt()
    if ( status != OK ) {
      val output = Helper.readString(in)
      if ( output != "" ) println(output)
      throw new RuntimeException("Error in R evaluation.")
    }
  }

  /** Fetches the value of `identifier` from R.
    *
    * @return the decoded value paired with a tag naming its Scala type: "Null", "Int",
    *         "Double", "Boolean", "String", "Array[T]" for vectors, or "Array[Array[T]]" for
    *         matrices (one inner array per row)
    * @throws RuntimeException if the identifier is undefined, its type is unsupported, or the
    *                          reply does not follow the protocol
    */
  def get(identifier: String): (Any,String) = {
    if ( debug ) debugger.msg("Getting: "+identifier)
    out.writeInt(GET)
    Helper.writeString(out,identifier)
    out.flush()
    in.readInt() match {
      case NULLTYPE =>
        if ( debug ) debugger.msg("Getting null.")
        (null,"Null")
      case ATOMIC =>
        if ( debug ) debugger.msg("Getting atomic.")
        in.readInt() match {
          case INTEGER => (in.readInt(),"Int")
          case DOUBLE => (in.readDouble(),"Double")
          case BOOLEAN => (( in.readInt() != 0 ),"Boolean")
          case STRING => (Helper.readString(in),"String")
          case _ => throw new RuntimeException("Protocol error")
        }
      case VECTOR =>
        if ( debug ) debugger.msg("Getting vector...")
        val length = in.readInt()
        if ( debug ) debugger.msg("... of length: "+length)
        in.readInt() match {
          case INTEGER => (Array.fill(length) { in.readInt() },"Array[Int]")
          case DOUBLE => (Array.fill(length) { in.readDouble() },"Array[Double]")
          case BOOLEAN => (Array.fill(length) { ( in.readInt() != 0 ) },"Array[Boolean]")
          case STRING => (Array.fill(length) { Helper.readString(in) },"Array[String]")
          case _ => throw new RuntimeException("Protocol error")
        }
      case MATRIX =>
        if ( debug ) debugger.msg("Getting matrix...")
        val nrow = in.readInt()
        val ncol = in.readInt()
        if ( debug ) debugger.msg("... of dimensions: "+nrow+","+ncol)
        // Elements arrive row by row: the outer fill builds rows, the inner fill columns.
        in.readInt() match {
          case INTEGER => (Array.fill(nrow) { Array.fill(ncol) { in.readInt() } },"Array[Array[Int]]")
          case DOUBLE => (Array.fill(nrow) { Array.fill(ncol) { in.readDouble() } },"Array[Array[Double]]")
          case BOOLEAN => (Array.fill(nrow) { Array.fill(ncol) { ( in.readInt() != 0 ) } },"Array[Array[Boolean]]")
          case STRING => (Array.fill(nrow) { Array.fill(ncol) { Helper.readString(in) } },"Array[Array[String]]")
          case _ => throw new RuntimeException("Protocol error")
        }
      case UNDEFINED_IDENTIFIER => throw new RuntimeException("Undefined identifier")
      case UNSUPPORTED_STRUCTURE => throw new RuntimeException("Unsupported data type")
      case _ => throw new RuntimeException("Protocol error")
    }
  }

  // Typed getters. Suffix 0 returns a scalar: the atomic value, or the first element of a
  // vector (an empty vector throws ArrayIndexOutOfBoundsException). Suffix 1 returns a vector,
  // wrapping atomics in a one-element array. Suffix 2 accepts matrices only.
  // Conversions: Boolean <-> 1/0; numbers -> Boolean via "!= 0"; String -> Boolean is true
  // unless it equals "false" ignoring case; String -> number via toInt/toDouble (may throw
  // NumberFormatException). Any other source type throws RuntimeException.

  def getI0(identifier: String): Int = get(identifier) match {
    case (a,"Int") => a.asInstanceOf[Int]
    case (a,"Double") => a.asInstanceOf[Double].toInt
    case (a,"Boolean") => if (a.asInstanceOf[Boolean]) 1 else 0
    case (a,"String") => a.asInstanceOf[String].toInt
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]](0)
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]](0).toInt
    case (a,"Array[Boolean]") => if ( a.asInstanceOf[Array[Boolean]](0) ) 1 else 0
    case (a,"Array[String]") => a.asInstanceOf[Array[String]](0).toInt
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Int")
  }

  def getD0(identifier: String): Double = get(identifier) match {
    case (a,"Int") => a.asInstanceOf[Int].toDouble
    case (a,"Double") => a.asInstanceOf[Double]
    case (a,"Boolean") => if (a.asInstanceOf[Boolean]) 1.0 else 0.0
    case (a,"String") => a.asInstanceOf[String].toDouble
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]](0).toDouble
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]](0)
    case (a,"Array[Boolean]") => if ( a.asInstanceOf[Array[Boolean]](0) ) 1.0 else 0.0
    case (a,"Array[String]") => a.asInstanceOf[Array[String]](0).toDouble
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Double")
  }

  def getB0(identifier: String): Boolean = get(identifier) match {
    case (a,"Int") => a.asInstanceOf[Int] != 0
    case (a,"Double") => a.asInstanceOf[Double] != 0.0
    case (a,"Boolean") => a.asInstanceOf[Boolean]
    case (a,"String") => a.asInstanceOf[String].toLowerCase != "false"
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]](0) != 0
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]](0) != 0.0
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]](0)
    case (a,"Array[String]") => a.asInstanceOf[Array[String]](0).toLowerCase != "false"
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Boolean")
  }

  def getS0(identifier: String): String = get(identifier) match {
    case (a,"Int") => a.asInstanceOf[Int].toString
    case (a,"Double") => a.asInstanceOf[Double].toString
    case (a,"Boolean") => a.asInstanceOf[Boolean].toString
    case (a,"String") => a.asInstanceOf[String]
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]](0).toString
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]](0).toString
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]](0).toString
    case (a,"Array[String]") => a.asInstanceOf[Array[String]](0)
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to String")
  }

  def getI1(identifier: String): Array[Int] = get(identifier) match {
    case (a,"Int") => Array(a.asInstanceOf[Int])
    case (a,"Double") => Array(a.asInstanceOf[Double].toInt)
    case (a,"Boolean") => Array(if (a.asInstanceOf[Boolean]) 1 else 0)
    case (a,"String") => Array(a.asInstanceOf[String].toInt)
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]]
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]].map(_.toInt)
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]].map(x => if (x) 1 else 0)
    case (a,"Array[String]") => a.asInstanceOf[Array[String]].map(_.toInt)
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Int]")
  }

  def getD1(identifier: String): Array[Double] = get(identifier) match {
    case (a,"Int") => Array(a.asInstanceOf[Int].toDouble)
    case (a,"Double") => Array(a.asInstanceOf[Double])
    case (a,"Boolean") => Array(if (a.asInstanceOf[Boolean]) 1.0 else 0.0)
    case (a,"String") => Array(a.asInstanceOf[String].toDouble)
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]].map(_.toDouble)
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]]
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]].map(x => if (x) 1.0 else 0.0)
    case (a,"Array[String]") => a.asInstanceOf[Array[String]].map(_.toDouble)
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Double]")
  }

  def getB1(identifier: String): Array[Boolean] = get(identifier) match {
    case (a,"Int") => Array(a.asInstanceOf[Int] != 0)
    case (a,"Double") => Array(a.asInstanceOf[Double] != 0.0)
    case (a,"Boolean") => Array(a.asInstanceOf[Boolean])
    case (a,"String") => Array(a.asInstanceOf[String].toLowerCase != "false")
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]].map(_ != 0)
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]].map(_ != 0.0)
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]]
    case (a,"Array[String]") => a.asInstanceOf[Array[String]].map(_.toLowerCase != "false")
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Boolean]")
  }

  def getS1(identifier: String): Array[String] = get(identifier) match {
    case (a,"Int") => Array(a.asInstanceOf[Int].toString)
    case (a,"Double") => Array(a.asInstanceOf[Double].toString)
    case (a,"Boolean") => Array(a.asInstanceOf[Boolean].toString)
    case (a,"String") => Array(a.asInstanceOf[String])
    case (a,"Array[Int]") => a.asInstanceOf[Array[Int]].map(_.toString)
    case (a,"Array[Double]") => a.asInstanceOf[Array[Double]].map(_.toString)
    case (a,"Array[Boolean]") => a.asInstanceOf[Array[Boolean]].map(_.toString)
    case (a,"Array[String]") => a.asInstanceOf[Array[String]]
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[String]")
  }

  def getI2(identifier: String): Array[Array[Int]] = get(identifier) match {
    case (a,"Array[Array[Int]]") => a.asInstanceOf[Array[Array[Int]]]
    case (a,"Array[Array[Double]]") => a.asInstanceOf[Array[Array[Double]]].map(_.map(_.toInt))
    case (a,"Array[Array[Boolean]]") => a.asInstanceOf[Array[Array[Boolean]]].map(_.map(x => if (x) 1 else 0))
    case (a,"Array[Array[String]]") => a.asInstanceOf[Array[Array[String]]].map(_.map(_.toInt))
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Array[Int]]")
  }

  def getD2(identifier: String): Array[Array[Double]] = get(identifier) match {
    case (a,"Array[Array[Int]]") => a.asInstanceOf[Array[Array[Int]]].map(_.map(_.toDouble))
    case (a,"Array[Array[Double]]") => a.asInstanceOf[Array[Array[Double]]]
    case (a,"Array[Array[Boolean]]") => a.asInstanceOf[Array[Array[Boolean]]].map(_.map(x => if (x) 1.0 else 0.0))
    case (a,"Array[Array[String]]") => a.asInstanceOf[Array[Array[String]]].map(_.map(_.toDouble))
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Array[Double]]")
  }

  def getB2(identifier: String): Array[Array[Boolean]] = get(identifier) match {
    case (a,"Array[Array[Int]]") => a.asInstanceOf[Array[Array[Int]]].map(_.map(_ != 0))
    case (a,"Array[Array[Double]]") => a.asInstanceOf[Array[Array[Double]]].map(_.map(_ != 0.0))
    case (a,"Array[Array[Boolean]]") => a.asInstanceOf[Array[Array[Boolean]]]
    case (a,"Array[Array[String]]") => a.asInstanceOf[Array[Array[String]]].map(_.map(_.toLowerCase != "false"))
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Array[Boolean]]")
  }

  def getS2(identifier: String): Array[Array[String]] = get(identifier) match {
    case (a,"Array[Array[Int]]") => a.asInstanceOf[Array[Array[Int]]].map(_.map(_.toString))
    case (a,"Array[Array[Double]]") => a.asInstanceOf[Array[Array[Double]]].map(_.map(_.toString))
    case (a,"Array[Array[Boolean]]") => a.asInstanceOf[Array[Array[Boolean]]].map(_.map(_.toString))
    case (a,"Array[Array[String]]") => a.asInstanceOf[Array[Array[String]]]
    case (_,tp) => throw new RuntimeException(s"Unable to cast ${tp} to Array[Array[String]]")
  }

}

/** Factory for [[RClient]].
  *
  * Creates a client either by launching an R process that runs the rscala server and
  * connecting to it, or by wrapping streams that are already connected.
  */
object RClient {

  import scala.sys.process._
  import java.util.concurrent.LinkedBlockingQueue
  import scala.util.control.NonFatal

  // OS family derived from the `os.name` system property; "unix" is folded into "linux".
  private val OS = sys.props("os.name").toLowerCase match {
    case s if s.startsWith("""windows""") => "windows"
    case s if s.startsWith("""linux""") => "linux"
    case s if s.startsWith("""unix""") => "linux"
    case s if s.startsWith("""mac""") => "macintosh"
    case _ => throw new RuntimeException("Unrecognized OS")
  }

  // Command-line flags passed to R: Windows uses --ess where the other platforms use
  // --interactive. The "unix" case is unreachable because OS never takes that value.
  private val defaultArguments = OS match {
    case "windows" =>    Array[String]("--vanilla","--silent","--slave","--ess")
    case "linux" =>      Array[String]("--vanilla","--silent","--slave","--interactive")
    case "unix" =>       Array[String]("--vanilla","--silent","--slave","--interactive")
    case "macintosh" =>  Array[String]("--vanilla","--silent","--slave","--interactive")
  }

  // R executable: looked up in the registry on Windows, otherwise `R` from the PATH.
  private lazy val defaultRCmd = OS match {
    case "windows" =>   findROnWindows
    case "linux" =>     """R"""
    case "unix" =>      """R"""
    case "macintosh" => """R"""
  }

  /** Locates `R.exe` from the `InstallPath` value under `Software\R-core\R` in the Windows
    * registry. Both HKEY_LOCAL_MACHINE and HKEY_CURRENT_USER are queried; when both succeed,
    * the HKEY_CURRENT_USER result is used.
    *
    * @throws RuntimeException if neither registry root yields an install path
    */
  private def findROnWindows: String = {
    val NEWLINE = sys.props("line.separator")
    def lookup(root: String): Option[String] = {
      val out = new StringBuilder()
      val logger = ProcessLogger((o: String) => { out.append(o); out.append(NEWLINE) },(e: String) => {})
      try {
        ("reg query \"" + root + "\\Software\\R-core\\R\" /v \"InstallPath\"") ! logger
        out.toString.split(NEWLINE)
          .find(_.matches("""^\s*InstallPath\s*.*"""))
          .flatMap(_.split("REG_SZ").lift(1))
          .map(_.trim() + """\bin\R.exe""")
      } catch {
        // Best effort: a failed query for one root just means "not found there".
        case NonFatal(_) => None
      }
    }
    List("HKEY_LOCAL_MACHINE","HKEY_CURRENT_USER").flatMap(lookup).lastOption.getOrElse(
      throw new RuntimeException("Cannot locate R using Windows registry."))
  }

  /** Drains `input` line by line, echoing each line prefixed by `label` while debugging is
    * enabled, and closes the stream when it ends (or fails).
    */
  private def reader(debugger: Debugger, label: String)(input: InputStream): Unit = {
    val in = new BufferedReader(new InputStreamReader(input))
    try {
      Iterator.continually(in.readLine()).takeWhile(_ != null).foreach { line =>
        if ( debugger.debug ) println(label+line)
      }
    } finally in.close()
  }

  /** Renders a file's absolute path as the body of a single-quoted R string literal.
    *
    * Platform separators become '/' through a literal (non-regex) replacement; on Windows the
    * separator `\` is not a valid regular expression. Single quotes are escaped so the path
    * cannot terminate the literal early.
    */
  private def rPathLiteral(file: File): String =
    file.getAbsolutePath.replace(File.separatorChar,'/').replace("'","\\'")

  /** Starts R with the platform's default command and debugging disabled. */
  def apply(): RClient = apply(defaultRCmd)

  /** Launches `rCmd` with the default flags, starts the rscala server inside it through
    * standard input, connects, and sends an OK handshake.
    *
    * A fresh temporary file is handed both to `rscala:::newSockets` in R and to `ScalaSockets`
    * here, presumably so the two sides can agree on the socket ports.
    *
    * @param rCmd  name or path of the R executable
    * @param debug whether to enable debugging on both sides
    */
  def apply(rCmd: String, debug: Boolean = false): RClient = {
    val command = rCmd +: defaultArguments
    val processCmd = Process(command)
    val debugger = new Debugger(debug)
    // The stdin callback runs on a thread owned by scala.sys.process. Handing the writer over
    // through a blocking queue gives this thread a proper happens-before edge; a plain var
    // polled with Thread.sleep does not.
    val stdinWriters = new LinkedBlockingQueue[PrintWriter]()
    val processIO = new ProcessIO(
      o => stdinWriters.put(new PrintWriter(o)),
      reader(debugger,"STDOUT DEBUG: "),
      reader(debugger,"STDERR DEBUG: "),
      true
    )
    val portsFile = File.createTempFile("rscala-","")
    processCmd.run(processIO)
    val cmd = stdinWriters.take()
    cmd.println("library(rscala)")
    cmd.print("rscala:::rServe(rscala:::newSockets('")
    cmd.print(rPathLiteral(portsFile))
    cmd.print("',debug=")
    cmd.print( if ( debug ) "TRUE" else "FALSE" )
    cmd.println("))")
    cmd.println(raw"q(save='no')")
    cmd.flush()
    val sockets = new ScalaSockets(portsFile.getAbsolutePath,debugger)
    sockets.out.writeInt(OK)
    sockets.out.flush()
    apply(sockets.in,sockets.out,debugger)
  }

  /** Wraps streams that are already connected to an rscala server. */
  def apply(in: DataInputStream, out: DataOutputStream, debugger: Debugger): RClient = new RClient(in,out,debugger)

}

