package sparklyr

import java.io._
import java.net.{InetAddress, ServerSocket}
import java.util.Arrays

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.{SparkEnv, SparkException}

import scala.util.Try

object Utils {
  /**
   * Utilities for collecting columns / Datasets back to R
   */

  /**
   * Collects one column of `df` to the driver as an array of booleans.
   *
   * Every cell is cast with `asInstanceOf[Boolean]`; no type check is made.
   */
  def collectColumnBoolean(df: DataFrame, colName: String): Array[Boolean] = {
    val column = df.select(colName).rdd
    val flags = column.map { row => row(0).asInstanceOf[Boolean] }
    flags.collect()
  }

  /**
   * Collects one column of `df` to the driver as an array of ints.
   *
   * Cells that are not an `Int` (including null) become `Int.MinValue`.
   */
  def collectColumnInteger(df: DataFrame, colName: String): Array[Int] = {
    val values = df.select(colName).rdd.map { row =>
      row(0) match {
        case i: Int => i
        case _      => scala.Int.MinValue
      }
    }
    values.collect()
  }

  /**
   * Collects one column of `df` to the driver as an array of doubles.
   *
   * Cells that are not a `Double` (including null) become `NaN`.
   */
  def collectColumnDouble(df: DataFrame, colName: String): Array[Double] = {
    val values = df.select(colName).rdd.map { row =>
      row(0) match {
        case d: Double => d
        case _         => scala.Double.NaN
      }
    }
    values.collect()
  }

  /**
   * Collects one column of `df` and joins its values into a single string.
   *
   * Cells that are not a `String` (including null) are rendered as "<NA>".
   * Values are joined with `separator`; a non-empty result also gets a
   * trailing `separator`.
   */
  def collectColumnString(df: DataFrame, colName: String, separator: String): String = {
    val cells = df.select(colName).rdd.map { row =>
      row(0) match {
        case s: String => s
        case _         => "<NA>"
      }
    }.collect()

    val joined = cells.mkString(separator)
    if (joined.isEmpty) joined else joined + separator
  }

  /**
   * Collects one column of `df` to the driver without converting the values.
   */
  def collectColumnDefault(df: DataFrame, colName: String): Array[Any] = {
    val column = df.select(colName)
    column.rdd.map(_.apply(0)).collect()
  }

  /**
   * Collects one column of `df`, choosing the collector from `colType`.
   *
   * `colType` is compared against the literal names "BooleanType",
   * "IntegerType", "DoubleType" and "StringType"; any other value uses
   * `collectColumnDefault`. `separator` is only used for "StringType".
   * The result type is inferred from the branches, so it is either an
   * array or (for strings) a single joined `String`.
   */
  def collectColumn(df: DataFrame, colName: String, colType: String, separator: String) = {
    colType match {
      case "BooleanType" => collectColumnBoolean(df, colName)
      case "IntegerType" => collectColumnInteger(df, colName)
      case "DoubleType"  => collectColumnDouble(df, colName)
      case "StringType"  => collectColumnString(df, colName, separator)
      case _             => collectColumnDefault(df, colName)
    }
  }

  /**
   * Extracts column `idx` from already-collected rows as ints:
   * 1 for true, 0 for false, and `Int.MinValue` for anything that is
   * not a `Boolean` (including null).
   */
  def collectImplBoolean(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case flag: Boolean => if (flag) 1 else 0
        case _             => scala.Int.MinValue
      }
    }
  }

  /**
   * Extracts an array-of-booleans column from already-collected rows.
   *
   * Each cell is cast to a `WrappedArray`; every element becomes 1 for
   * true, 0 for false, and `Int.MinValue` for anything that is not a
   * `Boolean` (including null).
   */
  def collectImplBooleanArrArr(local: Array[Row], idx: Integer): Array[Array[Int]] = {
    local.map { row =>
      val flags = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      // Inspect the element itself; casting the enclosing array to Boolean
      // would throw a ClassCastException for every Boolean element.
      flags.map {
        case flag: Boolean => if (flag) 1 else 0
        case _             => scala.Int.MinValue
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows as ints; cells
   * that are not an `Int` (including null) become `Int.MinValue`.
   */
  def collectImplInteger(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case i: Int => i
        case _      => scala.Int.MinValue
      }
    }
  }

  /**
   * Extracts an array-of-ints column from already-collected rows;
   * elements that are not an `Int` become `Int.MinValue`.
   */
  def collectImplIntegerArrArr(local: Array[Row], idx: Integer): Array[Array[Int]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case i: Int => i
        case _      => scala.Int.MinValue
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows as doubles; cells
   * that are not a `Double` (including null) become `NaN`.
   */
  def collectImplDouble(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case d: Double => d
        case _         => scala.Double.NaN
      }
    }
  }

  /**
   * Extracts an array-of-doubles column from already-collected rows;
   * elements that are not a `Double` become `NaN`.
   */
  def collectImplDoubleArrArr(local: Array[Row], idx: Integer): Array[Array[Double]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case d: Double => d
        case _         => scala.Double.NaN
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, widening each
   * `Float` to a double; any other value (including null) becomes `NaN`.
   */
  def collectImplFloat(local: Array[Row], idx: Integer): Array[Double]  = {
    local.map { row =>
      row(idx) match {
        case f: Float => f.toDouble
        case _        => scala.Double.NaN
      }
    }
  }

  /**
   * Extracts an array-of-floats column from already-collected rows,
   * widening each `Float` to a double; other elements become `NaN`.
   */
  def collectImplFloatArrArr(local: Array[Row], idx: Integer): Array[Array[Double]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case f: Float => f.toDouble
        case _        => scala.Double.NaN
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, widening each
   * `Byte` to an int; any other value (including null) becomes
   * `Int.MinValue`.
   */
  def collectImplByte(local: Array[Row], idx: Integer): Array[Int] = {
    local.map { row =>
      row(idx) match {
        case b: Byte => b.toInt
        case _       => scala.Int.MinValue
      }
    }
  }

  /**
   * Extracts an array-of-bytes column from already-collected rows,
   * widening each `Byte` to an int; other elements become `Int.MinValue`.
   */
  def collectImplByteArrArr(local: Array[Row], idx: Integer): Array[Array[Int]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case b: Byte => b.toInt
        case _       => scala.Int.MinValue
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, widening each
   * `Short` to an int; any other value (including null) becomes
   * `Int.MinValue`.
   */
  def collectImplShort(local: Array[Row], idx: Integer): Array[Int] = {
    local.map { row =>
      row(idx) match {
        case s: Short => s.toInt
        case _        => scala.Int.MinValue
      }
    }
  }

  /**
   * Extracts an array-of-shorts column from already-collected rows,
   * widening each `Short` to an int; other elements become `Int.MinValue`.
   */
  def collectImplShortArrArr(local: Array[Row], idx: Integer): Array[Array[Int]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case s: Short => s.toInt
        case _        => scala.Int.MinValue
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, converting each
   * `Long` to a double; any other value (including null) becomes `NaN`.
   */
  def collectImplLong(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case l: Long => l.toDouble
        case _       => scala.Double.NaN
      }
    }
  }

  /**
   * Extracts an array-of-longs column from already-collected rows,
   * converting each `Long` to a double; other elements become `NaN`.
   */
  def collectImplLongArrArr(local: Array[Row], idx: Integer): Array[Array[Double]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case l: Long => l.toDouble
        case _       => scala.Double.NaN
      }.toArray
    }
  }

  /**
   * Renders column `idx` of already-collected rows as one joined string.
   *
   * Non-null cells are converted with `toString()`; null cells become
   * "<NA>". Values are joined with `separator`, and a non-empty result
   * also gets a trailing `separator`.
   */
  def collectImplForceString(local: Array[Row], idx: Integer, separator: String) = {
    val cells = local.map { row =>
      Option(row(idx)).fold("<NA>")(_.toString())
    }

    val joined = cells.mkString(separator)
    if (joined.isEmpty) joined else joined + separator
  }

  /**
   * Renders an array column of already-collected rows as arrays of strings.
   *
   * Non-null elements are converted with `toString()`; null elements
   * become "<NA>". `separator` is not used by this method.
   */
  def collectImplForceStringArrArr(local: Array[Row], idx: Integer, separator: String): Array[Array[String]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map(item => Option(item).fold("<NA>")(_.toString())).toArray
    }
  }

  /**
   * Joins the string column `idx` of already-collected rows into one string.
   *
   * Cells that are not a `String` (including null) become "<NA>". Values
   * are joined with `separator`, and a non-empty result also gets a
   * trailing `separator`.
   */
  def collectImplString(local: Array[Row], idx: Integer, separator: String) = {
    val cells = local.map { row =>
      row(idx) match {
        case s: String => s
        case _         => "<NA>"
      }
    }

    val joined = cells.mkString(separator)
    if (joined.isEmpty) joined else joined + separator
  }

  /**
   * Extracts an array-of-strings column from already-collected rows;
   * elements that are not a `String` become "<NA>". `separator` is not
   * used by this method.
   */
  def collectImplStringArrArr(local: Array[Row], idx: Integer, separator: String): Array[Array[String]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case s: String => s
        case _         => "<NA>"
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, converting each
   * `java.math.BigDecimal` with `doubleValue`; any other value
   * (including null) becomes `NaN`.
   */
  def collectImplDecimal(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case dec: java.math.BigDecimal => dec.doubleValue
        case _                         => scala.Double.NaN
      }
    }
  }

  /**
   * Extracts an array-of-decimals column from already-collected rows,
   * converting each `java.math.BigDecimal` with `doubleValue`; other
   * elements become `NaN`.
   */
  def collectImplDecimalArrArr(local: Array[Row], idx: Integer): Array[Array[Double]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case dec: java.math.BigDecimal => dec.doubleValue
        case _                         => scala.Double.NaN
      }.toArray
    }
  }

  /**
   * Extracts a vector-like column from already-collected rows.
   *
   * A null cell yields an empty array, a `Seq` is copied to an array, and
   * any other object is converted by reflectively invoking the `toArray`
   * method declared on its class.
   */
  def collectImplVector(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case null        => Array.empty
        case seq: Seq[_] => seq.asInstanceOf[Seq[Any]].toArray
        // reflection keeps this independent of the concrete vector class
        case other       => other.getClass.getDeclaredMethod("toArray").invoke(other)
      }
    }
  }

  /**
   * Extracts column `idx` from already-collected rows, wrapping each
   * `String` cell in a `StructTypeAsJSON`.
   *
   * Cells that are not a `String` (including null) are passed through
   * unchanged. Previously such a cell was replaced by a copy of the whole
   * column, which nested the full column inside a single cell and cost
   * O(rows) work per affected cell.
   */
  def collectImplJSON(local: Array[Row], idx: Integer) = {
    local.map { row =>
      row(idx) match {
        case json: String => new StructTypeAsJSON(json)
        // the cast keeps the element type a reference type, as it was
        // before; Row cells are already boxed so this is a no-op
        case other        => other.asInstanceOf[AnyRef]
      }
    }
  }

  /**
   * Extracts column `idx` from already-collected rows as timestamps,
   * yielding null for any cell whose lookup throws.
   */
  def collectImplTimestamp(local: Array[Row], idx: Integer) = {
    local.map { row =>
      val attempt = Try(row.getAs[java.sql.Timestamp](idx))
      attempt.toOption.orNull
    }
  }

  /**
   * Extracts an array-of-timestamps column from already-collected rows;
   * elements that are not a `java.sql.Timestamp` become null.
   */
  def collectImplTimestampArrArr(local: Array[Row], idx: Integer): Array[Array[java.sql.Timestamp]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case ts: java.sql.Timestamp => ts
        case _                      => null
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows as dates,
   * yielding null for any cell whose lookup throws.
   */
  def collectImplDate(local: Array[Row], idx: Integer) = {
    local.map { row =>
      val attempt = Try(row.getAs[java.sql.Date](idx))
      attempt.toOption.orNull
    }
  }

  /**
   * Extracts an array-of-dates column from already-collected rows;
   * elements that are not a `java.sql.Date` become null.
   */
  def collectImplDateArrArr(local: Array[Row], idx: Integer): Array[Array[java.sql.Date]] = {
    local.map { row =>
      val items = row(idx).asInstanceOf[scala.collection.mutable.WrappedArray[_]]
      items.map {
        case day: java.sql.Date => day
        case _                  => null
      }.toArray
    }
  }

  /**
   * Extracts column `idx` from already-collected rows without converting
   * the values.
   */
  def collectImplDefault(local: Array[Row], idx: Integer) = {
    local.map { row => row(idx) }
  }

  /**
   * Extracts column `idx` from already-collected rows, choosing the
   * converter from the column's type name.
   *
   * `colType` is matched, in order, against literal scalar type names,
   * the two regular expressions below, `StructTypeAsJSON.DType`, literal
   * `ArrayType(<elem>,true|false)` names and "NullType". Anything else
   * falls through to `collectImplDefault`. `separator` is only forwarded
   * to the string-producing converters. The result type is inferred from
   * the branches.
   */
  def collectImpl(local: Array[Row], idx: Integer, colType: String, separator: String) = {
    // as extractor patterns these must match the whole type name
    val ReDecimalType = "(DecimalType.*)".r
    val ReVectorType  = "(.*VectorUDT.*)".r

    colType match {
      case "BooleanType"          => collectImplBoolean(local, idx)
      case "IntegerType"          => collectImplInteger(local, idx)
      case "DoubleType"           => collectImplDouble(local, idx)
      case "StringType"           => collectImplString(local, idx, separator)
      case "LongType"             => collectImplLong(local, idx)

      case "ByteType"             => collectImplByte(local, idx)
      case "FloatType"            => collectImplFloat(local, idx)
      case "ShortType"            => collectImplShort(local, idx)
      case "Decimal"              => collectImplForceString(local, idx, separator)

      case "TimestampType"        => collectImplTimestamp(local, idx)
      case "CalendarIntervalType" => collectImplForceString(local, idx, separator)
      case "DateType"             => collectImplDate(local, idx)

      case ReDecimalType(_)       => collectImplDecimal(local, idx)
      case ReVectorType(_)        => collectImplVector(local, idx)
      case StructTypeAsJSON.DType => collectImplJSON(local, idx)

      // NOTE(review): ReDecimalType suggests decimal type names carry a
      // suffix (e.g. precision/scale). If so, the literal
      // "ArrayType(DecimalType,...)" cases below may never match and such
      // columns would use collectImplDefault -- confirm.
      case "ArrayType(BooleanType,true)"           => collectImplBooleanArrArr(local, idx)
      case "ArrayType(IntegerType,true)"           => collectImplIntegerArrArr(local, idx)
      case "ArrayType(DoubleType,true)"            => collectImplDoubleArrArr(local, idx)
      case "ArrayType(StringType,true)"            => collectImplStringArrArr(local, idx, separator)
      case "ArrayType(LongType,true)"              => collectImplLongArrArr(local, idx)
      case "ArrayType(ByteType,true)"              => collectImplByteArrArr(local, idx)
      case "ArrayType(FloatType,true)"             => collectImplFloatArrArr(local, idx)
      case "ArrayType(ShortType,true)"             => collectImplShortArrArr(local, idx)
      case "ArrayType(DecimalType,true)"           => collectImplDecimalArrArr(local, idx)
      case "ArrayType(TimestampType,true)"         => collectImplTimestampArrArr(local, idx)
      case "ArrayType(CalendarIntervalType,true)"  => collectImplForceStringArrArr(local, idx, separator)
      case "ArrayType(DateType,true)"              => collectImplDateArrArr(local, idx)

      case "ArrayType(BooleanType,false)"          => collectImplBooleanArrArr(local, idx)
      case "ArrayType(IntegerType,false)"          => collectImplIntegerArrArr(local, idx)
      case "ArrayType(DoubleType,false)"           => collectImplDoubleArrArr(local, idx)
      case "ArrayType(StringType,false)"           => collectImplStringArrArr(local, idx, separator)
      case "ArrayType(LongType,false)"             => collectImplLongArrArr(local, idx)
      case "ArrayType(ByteType,false)"             => collectImplByteArrArr(local, idx)
      case "ArrayType(FloatType,false)"            => collectImplFloatArrArr(local, idx)
      case "ArrayType(ShortType,false)"            => collectImplShortArrArr(local, idx)
      case "ArrayType(DecimalType,false)"          => collectImplDecimalArrArr(local, idx)
      case "ArrayType(TimestampType,false)"        => collectImplTimestampArrArr(local, idx)
      case "ArrayType(CalendarIntervalType,false)" => collectImplForceStringArrArr(local, idx, separator)
      case "ArrayType(DateType,false)"             => collectImplDateArrArr(local, idx)

      case "NullType"             => collectImplForceString(local, idx, separator)

      case _                      => collectImplDefault(local, idx)
    }
  }

  /**
   * Converts already-collected rows into one collected column per entry
   * of `dtypes`, in order.
   *
   * The second element of each `dtypes` pair is the type name handed to
   * `collectImpl`; the first element is not used here.
   */
  def collectArray(local: Array[Row], dtypes: Array[(String, String)], separator: String): Array[_] = {
    val columns = for (i <- dtypes.indices) yield collectImpl(local, i, dtypes(i)._2, separator)
    columns.toArray
  }

  /**
   * Collects a whole DataFrame to the driver as an array of columns.
   *
   * The DataFrame is first passed through
   * `DFCollectionUtils.prepareDataFrameForCollection`, which returns the
   * DataFrame to collect together with the type names that drive the
   * per-column conversion in `collectArray`.
   */
  def collect(df: DataFrame, separator: String): Array[_] = {
    val (prepared, dtypes) = DFCollectionUtils.prepareDataFrameForCollection(df)
    collectArray(prepared.collect(), dtypes, separator)
  }

  /**
   * Appends new columns extracted from an array column.
   *
   * For each position `i`, element `indices(i)` of `column` is selected
   * and named `names(i)`. All existing columns are kept and the new ones
   * are appended after them.
   */
  def separateColumnArray(df: DataFrame,
                          column: String,
                          names: Array[String],
                          indices: Array[Int]) =
  {
    val source = df.apply(column)
    val existing = df.columns.map(df.apply(_))

    // one extraction expression per requested output name
    val extracted = names.indices.map { i =>
      source.getItem(indices(i)).as(names(i))
    }

    df.select((existing ++ extracted): _*)
  }

  /**
   * Appends new columns extracted from a vector column.
   *
   * For each position `i`, element `indices(i)` of `column` is extracted
   * through a UDF and named `names(i)`. All existing columns are kept and
   * the new ones are appended after them.
   */
  def separateColumnVector(df: DataFrame,
                           column: String,
                           names: Array[String],
                           indices: Array[Int]) =
  {
    val source = df.apply(column)
    val existing = df.columns.map(df.apply(_))

    // The UDF takes 'Any' and uses reflection so that this compiles
    // cleanly against different Spark versions.
    val elementAt = udf {
      (vector: Any, position: Int) => {
        val converted = vector.getClass.getDeclaredMethod("toArray").invoke(vector)
        val values = converted.asInstanceOf[Array[Double]]
        values(position)
      }
    }

    // one extraction expression per requested output name
    val extracted = names.indices.map { i =>
      elementAt(source, lit(indices(i))).as(names(i))
    }

    df.select((existing ++ extracted): _*)
  }

  /**
   * Appends new columns extracted from the fields of a struct column.
   *
   * Output names are `names` when `intoIsSet` is true, otherwise the
   * struct's own field names. For each output position `i`, the struct
   * field at `indices(i)` is selected and given the i-th output name.
   * All existing columns are kept and the new ones are appended.
   */
  def separateColumnStruct(df: DataFrame,
                          column: String,
                          names: Array[String],
                          indices: Array[Int],
                          intoIsSet: Boolean) =
  {
    val source = df.apply(column)
    val existing = df.columns.map(df.apply(_))

    // field names of the struct; a non-struct data type fails the match
    val fieldNames: Array[String] = df
      .select(column)
      .schema
      .fields
      .flatMap(field => field.dataType match { case struct: StructType => struct.fields })
      .map(_.name)

    val outNames: Array[String] = if (intoIsSet) names else fieldNames

    // one extraction expression per output name
    val extracted = outNames.indices.map { i =>
      source.getItem(fieldNames(indices(i))).as(outNames(i))
    }

    df.select((existing ++ extracted): _*)
  }

  /**
   * Splits `column` into new columns, delegating on the column's type name.
   *
   * "array", "vector" and "struct" are supported; any other type name
   * raises `IllegalArgumentException`. `intoIsSet` is only forwarded to
   * the struct variant.
   */
  def separateColumn(df: DataFrame,
                     column: String,
                     names: Array[String],
                     indices: Array[Int],
                     intoIsSet: Boolean) =
  {
    // the result is unused; the lookup is kept so that resolving the
    // column still happens before the schema is inspected
    df.apply(column)

    val schema = df.schema
    val field = schema.apply(schema.fieldIndex(column))
    val typeName = field.dataType.typeName

    typeName match {
      case "array"  => separateColumnArray(df, column, names, indices)
      case "vector" => separateColumnVector(df, column, names, indices)
      case "struct" => separateColumnStruct(df, column, names, indices, intoIsSet)
      case _        =>
        throw new IllegalArgumentException("unhandled type '" + typeName + "'")
    }
  }

  /**
   * Builds an RDD of rows from an array whose entries are themselves
   * arrays of cell values, distributed over `partitions` partitions.
   */
  def createDataFrame(sc: SparkContext, rows: Array[_], partitions: Int): RDD[Row] = {
    val records = rows.map { item =>
      Row.fromSeq(item.asInstanceOf[Array[_]])
    }

    sc.parallelize(records, partitions)
  }

  /**
   * Builds an RDD of rows from delimited text lines.
   *
   * Each line is split on `separator` (interpreted by `String.split`,
   * keeping trailing empty fields) and field `i` is parsed according to
   * `columns(i)`: "integer", "double", "logical", "timestamp" (seconds)
   * or "date" (days). A value that fails to parse becomes null. Any other
   * column type keeps the text, with "NA" mapped to null.
   */
  def createDataFrameFromText(
    sc: SparkContext,
    rows: Array[String],
    columns: Array[String],
    partitions: Int,
    separator: String): RDD[Row] = {

    val records = rows.map { line =>
      val fields = line.split(separator, -1)

      val typed = fields.zipWithIndex.map { case (raw, i) =>
        columns(i) match {
          case "integer"   => Try(raw.toInt).getOrElse(null)
          case "double"    => Try(raw.toDouble).getOrElse(null)
          case "logical"   => Try(raw.toBoolean).getOrElse(null)
          case "timestamp" => Try(new java.sql.Timestamp(raw.toLong * 1000)).getOrElse(null)
          case "date"      => Try(new java.sql.Date(raw.toLong * 86400000)).getOrElse(null)
          case _           => if (raw == "NA") null else raw
        }
      }

      Row.fromSeq(typed)
    }

    sc.parallelize(records, partitions)
  }

  /**
   * Reports whether the class called `name` can be loaded with
   * `Class.forName`.
   */
  def classExists(name: String): Boolean = {
    scala.util.Try(Class.forName(name)) match {
      case scala.util.Success(_) => true
      case scala.util.Failure(_) => false
    }
  }

  /**
   * Builds an RDD of rows from a delimited text file on the driver.
   *
   * The file at `path` is read fully into memory, distributed over
   * `partitions` partitions, and each line is split on `separator`
   * (interpreted by `String.split`, keeping trailing empty fields).
   * Field `i` is parsed according to `columns(i)`: "integer", "double",
   * "logical", "timestamp" (seconds) or "date" (days). A value that fails
   * to parse becomes null. Any other column type keeps the text, with
   * "NA" mapped to null.
   */
  def createDataFrameFromCsv(
    sc: SparkContext,
    path: String,
    columns: Array[String],
    partitions: Int,
    separator: String): RDD[Row] = {

    // Materialise every line before closing, so the file handle is
    // released even if reading fails part-way through.
    val source = scala.io.Source.fromFile(path)
    val lines = try source.getLines().toIndexedSeq finally source.close()

    val rddRows: RDD[String] = sc.parallelize(lines, partitions)

    val data: RDD[Row] = rddRows.map(o => {
      val r = o.split(separator, -1)
      val typed = (Array.range(0, r.length)).map(idx => {
        val column = columns(idx)
        val value = r(idx)

        column match {
          case "integer"   => Try(value.toInt).getOrElse(null)
          case "double"    => Try(value.toDouble).getOrElse(null)
          case "logical"   => Try(value.toBoolean).getOrElse(null)
          case "timestamp" => Try(new java.sql.Timestamp(value.toLong * 1000)).getOrElse(null)
          case "date"      => Try(new java.sql.Date(value.toLong * 86400000)).getOrElse(null)
          case _           => if (value == "NA") null else value
        }
      })

      org.apache.spark.sql.Row.fromSeq(typed)
    })

    data
  }

  /**
   * Utilities for performing mutations
   */

  /**
   * Appends a double column named `id` holding a sequential index.
   *
   * The index comes from `zipWithIndex` on the underlying RDD, offset so
   * that the first row receives `from`.
   */
  def addSequentialIndex(
    df: DataFrame,
    from: Int,
    id: String) : DataFrame = {
      val offset = from.toDouble
      val indexed = df.rdd.zipWithIndex.map { case (row, position) =>
        Row.fromSeq(row.toSeq :+ (position.toDouble + offset))
      }
      val schema = df.schema.add(id, "double")

      df.sqlContext.createDataFrame(indexed, schema)
  }


  /**
   * Returns the value of the double column `id` in the last row of `df`.
   *
   * Only the final row of each non-empty partition is sent to the driver,
   * and the last of those is the result. This still copes with trailing
   * empty partitions, without collecting every row of the earlier
   * partitions.
   *
   * Throws `NoSuchElementException` when `df` has no rows.
   */
  def getLastIndex(df: DataFrame, id: String) : Double = {
    val partitionTails = df.select(id).rdd.mapPartitions { iter =>
      // keep at most one row per partition: its last one
      if (iter.hasNext) Iterator.single(iter.reduceLeft((_, later) => later))
      else Iterator.empty
    }

    partitionTails.collect().last.getDouble(0)
  }

  /**
   * Returns the wrapped string, or the empty string for `None`.
   */
  def unboxString(x: Option[String]) = x.getOrElse("")

  /**
   * Lists the class of `obj` followed by each of its superclasses, from
   * most to least specific.
   *
   * Simple class names are returned when `simpleName` is true, otherwise
   * fully-qualified names.
   */
  def getAncestry(obj: AnyRef, simpleName: Boolean = true): Array[String] = {
    // walk up the superclass chain until it runs out
    val lineage = Iterator
      .iterate[Class[_]](obj.getClass)(_.getSuperclass)
      .takeWhile(_ != null)

    val describe: Class[_] => String =
      if (simpleName) cls => cls.getSimpleName else cls => cls.getName

    lineage.map(describe).toArray
  }

  /**
   * Checks whether a server socket can be bound to `port` on
   * `inetAddress`.
   *
   * The probe socket is closed again straight away; failures while
   * closing are ignored.
   */
  def portIsAvailable(port: Int, inetAddress: InetAddress) = {
    Try(new ServerSocket(port, 1, inetAddress)) match {
      case scala.util.Success(probe) =>
        Try(probe.close())
        true
      case scala.util.Failure(_) =>
        false
    }
  }

  /**
   * Finds the first available port above `port` on `inetAddress`.
   *
   * Candidates start at `port + 1`. A candidate is only returned when it
   * lies fewer than 100 above `port`; otherwise the search gives up and
   * returns 0.
   */
  def nextPort(port: Int, inetAddress: InetAddress) = {
    @scala.annotation.tailrec
    def search(candidate: Int): Int =
      if (portIsAvailable(candidate, inetAddress) || candidate - port >= 100) candidate
      else search(candidate + 1)

    val found = search(port + 1)

    // give up after 100 port searches
    if (found - port < 100) found else 0
  }

  /**
   * Schema with a single non-nullable integer field named "id".
   */
  def buildStructTypeForIntegerField(): StructType = {
    val idField = StructField("id", IntegerType, nullable = false)
    StructType(Array(idField))
  }

  /**
   * Schema with a single non-nullable long field named "id".
   */
  def buildStructTypeForLongField(): StructType = {
    val idField = StructField("id", LongType, nullable = false)
    StructType(Array(idField))
  }

  /**
   * Wraps each long of `rdd` in a single-field row.
   */
  def mapRddLongToRddRow(rdd: RDD[Long]): RDD[Row] = {
    rdd.map { value => Row(value) }
  }

  /**
   * Narrows each long of `rdd` with `toInt` and wraps it in a
   * single-field row.
   */
  def mapRddIntegerToRddRow(rdd: RDD[Long]): RDD[Row] = {
    rdd.map { value =>
      val narrowed = value.toInt
      Row(narrowed)
    }
  }

  /**
   * Reads `inputPath` with `wholeTextFiles` and turns each resulting pair
   * into a two-field row, keeping the pair's element order.
   */
  def readWholeFiles(sc: SparkContext, inputPath: String): RDD[Row] = {
    val files = sc.wholeTextFiles(inputPath)
    files.map { case (first, second) => Row(first, second) }
  }

  /**
   * Unions a sequence of row RDDs into one RDD using `context.union`.
   */
  def unionRdd(context: org.apache.spark.SparkContext, rdds: Seq[org.apache.spark.rdd.RDD[org.apache.spark.sql.Row]]):
    org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = {
    val merged = context.union(rdds)
    merged
  }

  /**
   * Takes up to `size` rows from `iter` and converts them to collected
   * columns, using the type names from `df.dtypes`.
   */
  def collectIter(iter: Iterator[Row], size: Integer, df: DataFrame, separator: String): Array[_] = {
    val batch = iter.take(size).toArray
    collectArray(batch, df.dtypes, separator)
  }
}



