package net.liftweb.mocks

import _root_.scala.collection.mutable.HashMap
import _root_.java.io.PrintWriter
import _root_.java.io.StringReader
import _root_.java.io.BufferedReader
import _root_.java.io.ByteArrayOutputStream
import _root_.java.io.ByteArrayInputStream
import _root_.java.io.FileInputStream
import _root_.java.io.InputStream
import _root_.java.io.StringBufferInputStream
import _root_.java.io.File
import _root_.java.util.Arrays
import _root_.java.util.Date
import _root_.java.util.Locale
import _root_.java.util.Vector
import _root_.javax.servlet._
import _root_.javax.servlet.http._

/**
 * An example of how to use these mock classes in your unit tests:
 *
 *   def testLiftCore = {
 *     val output = new ByteArrayOutputStream
 *     val outputStream = new MockServletOutputStream(output)
 *     val writer = new PrintWriter(outputStream)
 *
 *     val req = new MockHttpServletRequest
 *     req.method = "GET"
 *     req.path = "/"
 *     val res = new MockHttpServletResponse(writer, outputStream)
 *
 *     val filter = new LiftFilter
 *     filter.init(new MockFilterConfig(new MockServletContext("target/test1-1.0-SNAPSHOT")))
 *     filter.doFilter(req, res, new DoNothingFilterChain)
 *     assertTrue(output.toString.startsWith("<?xml"))
 *   }
 */



/**
 * A mock ServletContext, suitable for wrapping in a FilterConfig and handing
 * to LiftFilter.
 *
 * Only resource streams and logging do real work: resources are read from
 * files located under `target`, and log messages are printed to standard
 * output. Every other method returns a fixed placeholder.
 *
 * @param target the target directory where your template files live
 *
 * @author Steve Jenson (stevej@pobox.com)
 */
class MockServletContext(var target: String) extends ServletContext {
  // Builds a brand-new empty enumeration on every call.
  private def emptyEnumeration = new Vector[AnyRef]().elements

  // --- version information -------------------------------------------------
  def getMajorVersion = 2
  def getMinorVersion = 3

  // --- init parameters and attributes: nothing is ever stored --------------
  def getInitParameter(f: String) = null
  def getInitParameterNames = emptyEnumeration
  def getAttribute(f: String) = null
  def getAttributeNames = emptyEnumeration
  def setAttribute(name: String, o: Object): Unit = ()
  def removeAttribute(name: String): Unit = ()

  // --- lookups that this mock does not support -----------------------------
  def getContext(path: String) = this
  def getContextPath = null
  def getMimeType(file: String) = null
  def getRealPath(path: String) = null
  def getNamedDispatcher(name: String) = null
  def getRequestDispatcher(path: String) = null
  def getResource(path: String) = null
  def getResourcePaths(path: String) = null
  def getServerInfo = null
  def getServlet(name: String) = null
  def getServletContextName = null
  def getServletNames = emptyEnumeration
  def getServlets = emptyEnumeration

  /**
   * Opens the file found at `target + path`.
   *
   * @param path the resource path, appended verbatim to `target`
   * @return a stream over the file, or null when no such file exists
   */
  def getResourceAsStream(path: String) = {
    val resource = new File(target + path)
    if (!resource.exists) null else new FileInputStream(resource)
  }

  // --- logging -------------------------------------------------------------

  /** Prints the message to standard output with a fixed prefix. */
  def log(msg: String): Unit = println("MockServletContext.log: " + msg)

  /** Prints the throwable's stack trace, then logs the message. */
  def log(msg: String, t: Throwable): Unit = {
    t.printStackTrace()
    log(msg)
  }

  /** Prints the exception's stack trace, then logs the message. */
  def log(e: Exception, msg: String): Unit = {
    e.printStackTrace()
    log(msg)
  }
}


/**
 * A mock FilterConfig. Build one around a MockServletContext and hand it to
 * LiftFilter.init.
 *
 * @param servletContext the context returned from getServletContext
 */
class MockFilterConfig(servletContext: ServletContext) extends FilterConfig {
  /** The wrapped context, exactly as supplied to the constructor. */
  def getServletContext: ServletContext = servletContext

  /** Fixed filter name, as in lift's default web.xml. */
  def getFilterName: String = "LiftFilter"

  // No init parameters are configured on this mock.
  def getInitParameterNames = new Vector[AnyRef]().elements
  def getInitParameter(key: String) = null
}

/**
 * A FilterChain that terminates the chain: it does not touch the request or
 * the response and only prints a marker line to standard output.
 *
 * @author Steve Jenson (stevej@pobox.com)
 */
class DoNothingFilterChain extends FilterChain {
  /** Ignores both arguments and prints "doing nothing". */
  def doFilter(req: ServletRequest, res: ServletResponse): Unit =
    println("doing nothing")
}

/**
 * A mock ServletInputStream backed by any InputStream, for example a
 * ByteArrayInputStream.
 *
 * @param is the stream that single-byte reads are delegated to
 *
 * @author Steve Jenson (stevej@pobox.com)
 */
class MockServletInputStream(is: InputStream) extends ServletInputStream {
  /** Reads one byte from the wrapped stream and returns its result unchanged. */
  def read: Int = {
    val nextByte = is.read()
    nextByte
  }
}

/**
 * A Mock ServletOutputStream. Pass in any ol' OutputStream like a ByteArrayOutputStream.
 *
 * The constructor accepts the general OutputStream type, so existing callers
 * that pass a ByteArrayOutputStream keep working unchanged.
 *
 * @param os the stream that receives every byte written to this mock
 *
 * @author Steve Jenson (stevej@pobox.com)
 */
class MockServletOutputStream(os: _root_.java.io.OutputStream) extends ServletOutputStream {
  /** Forwards a single byte to the wrapped stream. */
  def write(b: Int) {
    os.write(b)
  }

  /**
   * Forwards the flush to the wrapped stream so that buffering streams
   * actually emit the bytes written so far.
   */
  override def flush() {
    os.flush()
  }
}

/**
 * A Mock HttpSession implementation.
 *
 * Values (the getValue/putValue/removeValue methods) and attributes (the
 * getAttribute/setAttribute/removeAttribute methods) are held in two
 * separate in-memory maps. Looking up a key that is not present yields null.
 *
 * @author Steve Jenson (stevej@pobox.com)
 */
class MockHttpSession extends HttpSession {
  // Backing store for getValue / putValue / removeValue.
  val values = new _root_.scala.collection.jcl.HashMap[String, Any](new _root_.java.util.HashMap)
  // Backing store for getAttribute / setAttribute / removeAttribute.
  val attr = new _root_.scala.collection.jcl.HashMap[String, Any](new _root_.java.util.HashMap)
  val sessionContext = null
  var maxii = 0
  // Typed explicitly: with an inferred type of Null this var could never be
  // assigned an actual context by a test.
  var servletContext: ServletContext = null
  var creationTime = System.currentTimeMillis

  /** Always false. */
  def isNew = false

  /** Does nothing; the stored values and attributes are left untouched. */
  def invalidate {}

  /**
   * @param key the name the value was stored under
   * @return the stored value, or null when nothing is stored under `key`
   */
  def getValue(key: String): Object = values.get(key) match {
    case Some(v) => v.asInstanceOf[Object]
    // null, not Nil: callers test the result against null or cast it.
    case None => null
  }
  def removeValue(key: String) = values -= key
  def putValue(key: String, value: Any) = values += (key -> value)

  /**
   * @param key the name the attribute was stored under
   * @return the stored attribute, or null when nothing is stored under `key`
   */
  def getAttribute(key: String): Object = attr.get(key) match {
    case Some(v) => v.asInstanceOf[Object]
    // null, not Nil: callers test the result against null or cast it.
    case None => null
  }
  def removeAttribute(key: String) = attr -= key
  def setAttribute(key: String, value: Any) = attr += (key -> value)

  /** The keys currently present in the value map. */
  def getValueNames: Array[String] = values.keySet.toArray

  /** An enumeration over a copy of the attribute map's current keys. */
  def getAttributeNames = new Vector[AnyRef](attr.underlying.keySet).elements
  def getSessionContext = sessionContext
  def getMaxInactiveInterval = maxii
  def setMaxInactiveInterval(i: Int) = maxii = i
  def getServletContext = servletContext

  /** Always 0; access times are not tracked. */
  def getLastAccessedTime = 0L

  /** Always null; this mock has no session id. */
  def getId = null
  def getCreationTime = creationTime
}