package net.liftweb.mapper
/*
* Copyright 2006-2008 WorldWide Conferencing, LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions
* and limitations under the License.
*/
import _root_.java.sql.{Connection, ResultSet, Statement, PreparedStatement, Types, ResultSetMetaData}
import _root_.javax.sql.{ DataSource}
import _root_.javax.naming.{Context, InitialContext}
import _root_.scala.collection.mutable._
import _root_.net.liftweb.util._
import _root_.net.liftweb.http._
import Helpers._
object DB {
// Per-thread map of the connections this thread currently holds, keyed by identifier.
// Populated lazily by `info`; removed by `clearThread`.
private val threadStore = new ThreadLocal[HashMap[ConnectionIdentifier, ConnectionHolder]]
// Lazily resolved JNDI environment context ("java:/comp/env") used to look up DataSources.
private val envContext = FatLazy((new InitialContext).lookup("java:/comp/env").asInstanceOf[Context])
// When Full, passed to setQueryTimeout on every statement created by statement/exec/prepareStatement.
var queryTimeout: Box[Int] = Empty
// Functions notified with (query text, elapsed time) after each timed statement; see addLogFunc.
// NOTE(review): the time unit comes from Helpers.calcTime (presumably milliseconds) — confirm.
private var logFuncs: List[(String, Long) => Any] = Nil
/**
 * Registers a query log function. Registered functions are called in registration order.
 *
 * @param f receives the query text and its elapsed time
 * @return the full list of registered log functions, including `f` as the last element
 */
def addLogFunc(f: (String, Long) => Any): List[(String, Long) => Any] = {
  val updated = logFuncs ::: (f :: Nil)
  logFuncs = updated
  updated
}
/**
 * Can we get a JDBC connection from JNDI?
 *
 * Looks up the DataSource bound under `DefaultConnectionIdentifier.jndiName` and tries to
 * open a connection from it. The probe connection is closed immediately, so checking
 * availability does not leak a connection. Any exception during lookup or connect means
 * "not available". If the JNDI environment context had not been computed before this call,
 * it is reset afterwards.
 */
def jndiJdbcConnAvailable_? : Boolean = {
  val touchedEnv = envContext.calculated_?
  val ret = try {
    val conn = envContext.get.lookup(DefaultConnectionIdentifier.jndiName).asInstanceOf[DataSource].getConnection
    if (conn eq null) false
    else {
      // Only the ability to connect matters here; hand the connection straight back.
      try { conn.close } catch { case _: Exception => () }
      true
    }
  } catch {
    case e => false
  }
  if (!touchedEnv) envContext.reset
  ret
}
// var connectionManager: Box[ConnectionManager] = Empty
// ConnectionManagers registered by application code, keyed by the identifier they serve.
private val connectionManagers = new HashMap[ConnectionIdentifier, ConnectionManager]

/**
 * Registers `mgr` as the source of connections for `name`. A later registration for an
 * equal identifier replaces the earlier one.
 */
def defineConnectionManager(name: ConnectionIdentifier, mgr: ConnectionManager) {
  connectionManagers.update(name, mgr)
}
/**
 * Per-thread bookkeeping for one open connection.
 *
 * @param conn       the connection held by the current thread
 * @param cnt        number of outstanding acquisitions (incremented by getConnection,
 *                   decremented by releaseConnectionNamed; the connection is committed and
 *                   released when it drops from 1)
 * @param postCommit functions to run after the commit, stored newest-first (appendPostFunc
 *                   prepends) and executed in reverse, i.e. in the order they were appended
 */
case class ConnectionHolder(conn: SuperConnection, cnt: Int, postCommit: List[() => Unit])
/**
 * Returns this thread's connection map, creating and installing an empty one on first use.
 */
private def info : HashMap[ConnectionIdentifier, ConnectionHolder] = {
  val existing = threadStore.get
  if (existing ne null) existing
  else {
    val fresh = new HashMap[ConnectionIdentifier, ConnectionHolder]
    threadStore.set(fresh)
    fresh
  }
}
/**
 * Releases every connection held by the current thread and removes the thread-local
 * association.
 *
 * Each pass releases every held identifier once, which decrements or fully releases it.
 * The method recurses until no holders remain and then drops the thread-local map.
 * Exceptions from a release (e.g. a failed commit) propagate to the caller.
 */
def clearThread: Unit = {
  // Snapshot the keys first. releaseConnectionNamed removes entries from this same map,
  // and mutating a mutable HashMap while iterating a live view of its keys is unsafe.
  val names = info.keySet.toList
  if (names.isEmpty)
    threadStore.remove
  else {
    names.foreach(n => releaseConnectionNamed(n))
    clearThread
  }
}
/**
 * Opens a brand-new connection for `name` and switches it to manual commit.
 *
 * A ConnectionManager registered for `name` via defineConnectionManager is consulted first.
 * If none is registered, or it yields an empty Box, the JNDI DataSource bound under
 * `name.jndiName` is used instead (see openJndiConnection).
 */
private def newConnection(name : ConnectionIdentifier) : SuperConnection = {
  val fromManager = Box(connectionManagers.get(name)).flatMap { cm =>
    cm.newConnection(name).map(c => new SuperConnection(c, () => cm.releaseConnection(c)))
  }
  val ret = fromManager openOr openJndiConnection(name)
  ret.setAutoCommit(false)
  ret
}

/**
 * Opens a connection from the JNDI DataSource bound under `name.jndiName`. Releasing the
 * wrapper closes the raw connection.
 *
 * Failures during lookup are presumably absorbed by `tryo` into an empty Box. Either way,
 * the outcome is reported as a NullPointerException that names both places searched.
 */
private def openJndiConnection(name: ConnectionIdentifier): SuperConnection =
  Helpers.tryo {
    val uniqueId = if (Log.isDebugEnabled) Helpers.nextNum.toString else ""
    Log.debug("Connection ID "+uniqueId+" for JNDI connection "+name.jndiName+" opened")
    val conn = envContext.get.lookup(name.jndiName).asInstanceOf[DataSource].getConnection
    new SuperConnection(conn, () => {
      Log.debug("Connection ID "+uniqueId+" for JNDI connection "+name.jndiName+" closed")
      conn.close
    })
  } openOr {
    throw new NullPointerException("Looking for Connection Identifier "+name+" but failed to find either a JNDI data source "+
      "with the name "+name.jndiName+" or a lift connection manager with the correct name")
  }
/**
 * Builds a LoanWrapper, suitable for S.addAround(), that makes each complete HTTP request
 * transactional for DefaultConnectionIdentifier only.
 */
def buildLoanWrapper(): LoanWrapper = buildLoanWrapper(DefaultConnectionIdentifier :: Nil)
/**
 * Builds a LoanWrapper, suitable for S.addAround(), that makes each complete HTTP request
 * transactional for every identifier in `in`.
 *
 * The wrapped function runs while all listed connections are held; the identifiers are
 * acquired in list order and nested via `use`. When the outermost invocation of this
 * wrapper in a request finishes, clearThread releases everything the thread still holds.
 */
def buildLoanWrapper(in: List[ConnectionIdentifier]): LoanWrapper =
  new LoanWrapper {
    // Acquire the head identifier, then recurse on the rest; evaluate `f` once all are held.
    private def doWith[T](ids: List[ConnectionIdentifier], f: => T): T =
      if (ids.isEmpty) f
      else use(ids.head)(_ => doWith(ids.tail, f))

    // Nesting depth of this wrapper within the current request.
    private object DepthCnt extends RequestVar(0)

    def apply[T](f: => T): T =
      try {
        DepthCnt.update(_ + 1)
        doWith(in, f)
      } finally {
        DepthCnt.update(_ - 1)
        if (DepthCnt.is == 0) clearThread
      }
  }
/**
 * Closes the raw JDBC connection wrapped by `conn`.
 * NOTE(review): this bypasses `conn.releaseFunc` and appears unused in this object; the
 * release path actually used is releaseConnectionNamed.
 */
private def releaseConnection(conn : SuperConnection) : Unit = {
  conn.connection.close
}
/**
 * Acquires the connection for `name` on the current thread.
 *
 * If the thread already holds one, its acquisition count is incremented and the same
 * connection is returned. Otherwise a new connection is opened with count 1. Every
 * acquisition must be matched by releaseConnectionNamed.
 */
private def getConnection(name : ConnectionIdentifier): SuperConnection = {
  Log.trace("Acquiring connection "+name+" On thread "+Thread.currentThread)
  // Look up the thread-local map once; it is the same map for the whole call.
  val holders = info
  val ret = holders.get(name) match {
    case None => ConnectionHolder(newConnection(name), 1, Nil)
    case Some(ConnectionHolder(conn, cnt, post)) => ConnectionHolder(conn, cnt + 1, post)
  }
  holders(name) = ret
  Log.trace("Acquired connection "+name+" on thread "+Thread.currentThread+
    " count "+ret.cnt)
  ret.conn
}
/**
 * Releases one acquisition of the connection held for `name` on the current thread.
 *
 * - Last acquisition (count 1): the transaction is committed, the connection is released
 *   and removed from the thread map, and the post-commit functions run in append order.
 *   If the commit throws, the transaction is rolled back (best effort), the connection is
 *   still released and removed, the post-commit functions are skipped, and the exception
 *   propagates.
 * - Count above 1: the count is decremented.
 * - Nothing held for `name`: no-op.
 */
private def releaseConnectionNamed(name: ConnectionIdentifier) {
  Log.trace("Request to release connection: "+name+" on thread "+Thread.currentThread)
  val holders = info
  holders.get(name) match {
    case Some(ConnectionHolder(c, 1, post)) =>
      try {
        c.commit
      } catch {
        case e =>
          // Closing a connection with an open transaction is implementation-defined in JDBC,
          // so undo it explicitly before handing the connection back.
          tryo(c.rollback)
          throw e
      } finally {
        // Always give the connection back and forget it, even if the commit failed;
        // otherwise the connection would leak and stay bound to this thread.
        tryo(c.releaseFunc())
        holders -= name
      }
      post.reverse.foreach(f => tryo(f()))
      Log.trace("Released connection "+name+" on thread "+Thread.currentThread)
    case Some(ConnectionHolder(c, n, post)) =>
      Log.trace("Did not release connection: "+name+" on thread "+Thread.currentThread+" count "+(n - 1))
      holders(name) = ConnectionHolder(c, n - 1, post)
    case _ =>
      // no connection held under this name: nothing to release
  }
}
/**
 * Appends `func` to the functions run after the commit of the connection held for `name`.
 * Does nothing if the current thread holds no connection for `name`.
 */
def appendPostFunc(name: ConnectionIdentifier, func: () => Unit) {
  val holders = info
  holders.get(name).foreach { holder =>
    // Stored newest-first; reversed at release time so functions run in append order.
    holders(name) = ConnectionHolder(holder.conn, holder.cnt, func :: holder.postCommit)
  }
}
/** Reports `query` and its elapsed `time` to every registered log function, in order. */
private def runLogger(query: String, time: Long) {
  for (logFunc <- logFuncs) logFunc(query, time)
}
/**
 * Creates a plain JDBC Statement on `db` and passes it to `f`, closing the statement afterwards.
 *
 * The configured `queryTimeout` is applied to the statement. The elapsed time is reported to
 * the log functions together with the statement's `toString`.
 */
def statement[T](db : SuperConnection)(f : (Statement) => T) : T = {
  Helpers.calcTime {
    val st = db.createStatement
    try {
      // Inside the try so the statement is still closed if the driver rejects the timeout.
      queryTimeout.foreach(to => st.setQueryTimeout(to))
      (st.toString, f(st))
    } finally {
      st.close
    }
  } match {case (time, (query, res)) => runLogger(query, time); res}
}
/**
 * Runs `query` as a plain JDBC statement on `db` and hands the ResultSet to `f`.
 *
 * The query is reported to the log functions exactly once, with its SQL text and elapsed
 * time. This deliberately does not delegate to `statement`: that method logs on its own,
 * using the Statement's `toString`, which would report every query twice.
 */
def exec[T](db : SuperConnection, query : String)(f : (ResultSet) => T) : T = {
  Helpers.calcTime {
    val st = db.createStatement
    try {
      queryTimeout.foreach(to => st.setQueryTimeout(to))
      val rs = st.executeQuery(query)
      try {
        f(rs)
      } finally {
        rs.close
      }
    } finally {
      st.close
    }
  } match {case (time, res) => runLogger(query, time); res}
}
/**
 * Renders column `pos` of the current row of `rs` as a String, using the column's JDBC
 * type from `md`.
 *
 * SQL NULL is rendered as null for every column type. DECIMAL/NUMERIC values keep their
 * fractional digits (plain notation, no exponent). Column types not listed explicitly
 * fall back to `getObject(pos).toString`.
 */
private def asString(pos: Int, rs: ResultSet, md: ResultSetMetaData): String = {
  import _root_.java.sql.Types._

  // Object-valued getters return null for SQL NULL.
  def objectString(v: AnyRef): String = if (v eq null) null else v.toString

  // Primitive getters return 0/false for SQL NULL, so wasNull must be checked after the read.
  // The argument is evaluated (the column is read) before wasNull is consulted.
  def primitiveString(v: Any): String = if (rs.wasNull) null else v.toString

  md.getColumnType(pos) match {
    case ARRAY | BINARY | BLOB | DATALINK | DISTINCT | JAVA_OBJECT | LONGVARBINARY | NULL | OTHER | REF | STRUCT | VARBINARY =>
      objectString(rs.getObject(pos))
    case DECIMAL | NUMERIC =>
      // getLong would silently drop the fractional part.
      rs.getBigDecimal(pos) match {
        case null => null
        case bd => bd.toPlainString
      }
    case BIGINT | INTEGER | SMALLINT | TINYINT => primitiveString(rs.getLong(pos))
    case BIT | BOOLEAN => primitiveString(rs.getBoolean(pos))
    case VARCHAR | CHAR | CLOB | LONGVARCHAR => rs.getString(pos)
    case DATE | TIME | TIMESTAMP => objectString(rs.getTimestamp(pos))
    case DOUBLE | FLOAT | REAL => primitiveString(rs.getDouble(pos))
    case _ =>
      // e.g. NCHAR/NVARCHAR/SQLXML: previously a MatchError
      objectString(rs.getObject(pos))
  }
}
/**
 * Reads all remaining rows of `rs` into memory.
 *
 * @return the column names (from the result set metadata) and one List of rendered values
 *         per row, in result order
 */
def resultSetTo(rs: ResultSet): (List[String], List[List[String]]) = {
  val meta = rs.getMetaData
  val columns = (1 to meta.getColumnCount).toList
  val names = columns.map(col => meta.getColumnName(col))
  val rows = new ListBuffer[List[String]]()
  while (rs.next) rows += columns.map(col => asString(col, rs, meta))
  (names, rows.toList)
}
/**
 * Runs a parameterized query on the default connection and returns its column names and rows.
 *
 * Each element of `params` is bound to the matching `?` placeholder with a setter chosen by
 * its runtime type; anything unrecognized goes through setObject. `null` is bound as a
 * VARCHAR NULL.
 */
def runQuery(query: String, params: List[Any]): (List[String], List[List[String]]) = {
  use(DefaultConnectionIdentifier)(conn => prepareStatement(query, conn) {
    ps =>
    params.zipWithIndex.foreach {
      case (null, idx) => ps.setNull(idx + 1, Types.VARCHAR)
      case (i: Int, idx) => ps.setInt(idx + 1, i)
      case (l: Long, idx) => ps.setLong(idx + 1, l)
      case (d: Double, idx) => ps.setDouble(idx + 1, d)
      case (f: Float, idx) => ps.setFloat(idx + 1, f)
      // Timestamp and Time are subclasses of java.util.Date. Bind them with their own setters
      // so their time component is not truncated by the setDate branch below.
      case (t: _root_.java.sql.Timestamp, idx) => ps.setTimestamp(idx + 1, t)
      case (t: _root_.java.sql.Time, idx) => ps.setTime(idx + 1, t)
      case (d: _root_.java.util.Date, idx) => ps.setDate(idx + 1, new _root_.java.sql.Date(d.getTime))
      case (b: Boolean, idx) => ps.setBoolean(idx + 1, b)
      case (s: String, idx) => ps.setString(idx + 1, s)
      case (bn: _root_.java.math.BigDecimal, idx) => ps.setBigDecimal(idx + 1, bn)
      case (obj, idx) => ps.setObject(idx + 1, obj)
    }
    resultSetTo(ps.executeQuery)
  })
}
/** Runs an unparameterized query on the default connection; returns column names and rows. */
def runQuery(query: String): (List[String], List[List[String]]) =
  use(DefaultConnectionIdentifier) { conn =>
    exec(conn, query)(rs => resultSetTo(rs))
  }
/**
 * Rolls back the current transaction on the connection for `name`. The connection is
 * acquired for the call and released afterwards via `use`.
 */
def rollback(name: ConnectionIdentifier): Unit = use(name) { conn => conn.rollback }
/**
 * Executes {@code statement} and converts the {@code ResultSet} to model
 * instance {@code T} using {@code f}.
 *
 * The configured `queryTimeout` is applied first. Afterwards the ResultSet is closed,
 * then the statement itself. The statement is closed even if setting the timeout or
 * executing the query throws. The elapsed time is reported to the log functions together
 * with the statement's `toString`.
 */
def exec[T](statement : PreparedStatement)(f : (ResultSet) => T) : T = {
  Helpers.calcTime {
    try {
      queryTimeout.foreach(to => statement.setQueryTimeout(to))
      val rs = statement.executeQuery
      try {
        (statement.toString, f(rs))
      } finally {
        // Close the result set before its statement.
        rs.close
      }
    } finally {
      statement.close
    }
  } match {case (time, (query, res)) => runLogger(query, time); res}
}
/**
 * Prepares `statement` on `conn` and passes it to `f`, closing the PreparedStatement afterwards.
 *
 * The configured `queryTimeout` is applied. The elapsed time is reported to the log functions
 * together with the PreparedStatement's `toString`, captured before `f` runs.
 */
def prepareStatement[T](statement : String, conn: SuperConnection)(f : (PreparedStatement) => T) : T = {
  Helpers.calcTime {
    val st = conn.prepareStatement(statement)
    try {
      // Inside the try so the statement is still closed if the driver rejects the timeout.
      queryTimeout.foreach(to => st.setQueryTimeout(to))
      (st.toString, f(st))
    } finally {
      st.close
    }
  } match {case (time, (query, res)) => runLogger(query, time); res}
}
/**
 * Like the two-argument prepareStatement, but forwards `keys` to
 * `Connection.prepareStatement(sql, keys)` (e.g. to request auto-generated keys).
 */
def prepareStatement[T](statement : String, keys: Int, conn: SuperConnection)(f : (PreparedStatement) => T) : T = {
  Helpers.calcTime {
    val st = conn.prepareStatement(statement, keys)
    try {
      // Inside the try so the statement is still closed if the driver rejects the timeout.
      queryTimeout.foreach(to => st.setQueryTimeout(to))
      (st.toString, f(st))
    } finally {
      st.close
    }
  } match {case (time, (query, res)) => runLogger(query, time); res}
}
/**
 * Executes function {@code f} with the connection named {@code name}, releasing the
 * acquisition before returning.
 *
 * Nested calls for the same name on the same thread share one connection. The release runs
 * in a finally block, so it also happens when `f` throws; for the outermost acquisition that
 * release commits (see releaseConnectionNamed). An exception alone does not cause a rollback.
 */
def use[T](name : ConnectionIdentifier)(f : (SuperConnection) => T) : T = {
  val acquired = getConnection(name)
  try f(acquired)
  finally releaseConnectionNamed(name)
}
/**
 * Lower-case SQL/PL-SQL words treated as reserved. Membership is tested by exact,
 * case-sensitive string match.
 */
val reservedWords = _root_.scala.collection.immutable.HashSet.empty[String] ++
List(
  "abort", "accept", "access", "add", "admin", "after", "all", "allocate", "alter", "analyze",
  "and", "any", "archive", "archivelog", "array", "arraylen", "as", "asc", "assert", "assign",
  "at", "audit", "authorization", "avg", "backup", "base_table", "become", "before", "begin", "between",
  "binary_integer", "block", "body", "boolean", "by", "cache", "cancel", "cascade", "case", "change",
  "char", "character", "char_base", "check", "checkpoint", "close", "cluster", "clusters", "cobol", "colauth",
  "column", "columns", "comment", "commit", "compile", "compress", "connect", "constant", "constraint", "constraints",
  "contents", "continue", "controlfile", "count", "crash", "create", "current", "currval", "cursor", "cycle",
  "database", "data_base", "datafile", "date", "dba", "debugoff", "debugon", "dec", "decimal", "declare",
  "default", "definition", "delay", "delete", "delta", "desc", "digits", "disable", "dismount", "dispose",
  "distinct", "do", "double", "drop", "dump", "each", "else", "elsif", "enable", "end",
  "entry", "escape", "events", "except", "exception", "exception_init", "exceptions", "exclusive", "exec", "execute",
  "exists", "exit", "explain", "extent", "externally", "false", "fetch", "file", "float", "flush",
  "for", "force", "foreign", "form", "fortran", "found", "freelist", "freelists", "from", "function",
  "generic", "go", "goto", "grant", "group", "having", "identified", "if", "immediate", "in",
  "including", "increment", "index", "indexes", "indicator", "initial", "initrans", "insert", "instance", "int",
  "integer", "intersect", "into", "is", "key", "language", "layer", "level", "like", "limited",
  "link", "lists", "lock", "logfile", "long", "loop", "manage", "manual", "max", "maxdatafiles",
  "maxextents", "maxinstances", "maxlogfiles", "maxloghistory", "maxlogmembers", "maxtrans", "maxvalue", "min", "minextents", "minus",
  "minvalue", "mlslabel", "mod", "mode", "modify", "module", "mount", "natural", "new", "next",
  "nextval", "noarchivelog", "noaudit", "nocache", "nocompress", "nocycle", "nomaxvalue", "nominvalue", "none", "noorder",
  "noresetlogs", "normal", "nosort", "not", "notfound", "nowait", "null", "number", "number_base", "numeric",
  "of", "off", "offline", "old", "on", "online", "only", "open", "optimal", "option",
  "or", "order", "others", "out", "own", "package", "parallel", "partition", "pctfree", "pctincrease",
  "pctused", "plan", "pli", "positive", "pragma", "precision", "primary", "prior", "private", "privileges",
  "procedure", "profile", "public", "quota", "raise", "range", "raw", "read", "real", "record",
  "recover", "references", "referencing", "release", "remr", "rename", "resetlogs", "resource", "restricted", "return",
  "reuse", "reverse", "revoke", "role", "roles", "rollback", "row", "rowid", "rowlabel", "rownum",
  "rows", "rowtype", "run", "savepoint", "schema", "scn", "section", "segment", "select", "separate",
  "sequence", "session", "set", "share", "shared", "size", "smallint", "snapshot", "some", "sort",
  "space", "sql", "sqlbuf", "sqlcode", "sqlerrm", "sqlerror", "sqlstate", "start", "statement", "statement_id",
  "statistics", "stddev", "stop", "storage", "subtype", "successful", "sum", "switch", "synonym", "sysdate",
  "system", "tabauth", "table", "tables", "tablespace", "task", "temporary", "terminate", "then", "thread",
  "time", "to", "tracing", "transaction", "trigger", "triggers", "true", "truncate", "type", "uid",
  "under", "union", "unique", "unlimited", "until", "update", "use", "user", "using", "validate",
  "values", "varchar", "varchar2", "variance", "view", "views", "when", "whenever", "where", "while",
  "with", "work", "write", "xor")
}
/**
 * A JDBC connection paired with the function that gives it back to its source (a
 * ConnectionManager or a JNDI DataSource), plus driver-specific capability flags.
 *
 * @param connection  the underlying JDBC connection
 * @param releaseFunc invoked to release the connection once it is no longer needed
 */
class SuperConnection(val connection: Connection,val releaseFunc: () => Any) {
  lazy val brokenLimit_? = driverType.brokenLimit_?
  lazy val brokenAutogeneratedKeys_? = driverType.brokenAutogeneratedKeys_?
  lazy val wickedBrokenAutogeneratedKeys_? = driverType.wickedBrokenAutogeneratedKeys_?
  def createTablePostpend: String = driverType.createTablePostpend
  def supportsForeignKeys_? : Boolean = driverType.supportsForeignKeys_?

  /** Driver chosen from the connection's reported database product name (computed once). */
  lazy val driverType = calcDriver(connection.getMetaData.getDatabaseProductName)

  /**
   * Maps a JDBC database product name to its lift DriverType.
   *
   * @throws IllegalArgumentException if no known driver matches `name`
   */
  def calcDriver(name: String): DriverType = {
    name match {
      case DerbyDriver.name => DerbyDriver
      case MySqlDriver.name => MySqlDriver
      case PostgreSqlDriver.name => PostgreSqlDriver
      case H2Driver.name => H2Driver
      case SqlServerDriver.name => SqlServerDriver
      case OracleDriver.name => OracleDriver
      case MaxDbDriver.name => MaxDbDriver
      case other =>
        // Fail with an explicit message instead of an opaque MatchError.
        throw new IllegalArgumentException("Unsupported database product name '" + other +
          "': no matching lift DriverType")
    }
  }
}
object SuperConnection {
  /** Lets a SuperConnection be used wherever a plain JDBC Connection is expected. */
  implicit def superToConn(in: SuperConnection): Connection = {
    val underlying = in.connection
    underlying
  }
}
/**
 * Names a logical database connection. Identity is defined purely by `jndiName`: two
 * identifiers with the same JNDI name are equal and hash alike, so they share map entries.
 */
trait ConnectionIdentifier {
  /** JNDI name under which the DataSource for this connection is looked up. */
  def jndiName: String

  override def toString(): String = "ConnectionIdentifier(" + jndiName + ")"

  override def hashCode(): Int = jndiName.hashCode()

  override def equals(other: Any): Boolean =
    other match {
      case that: ConnectionIdentifier => this.jndiName == that.jndiName
      case _ => false
    }
}
/**
 * The connection used when no other identifier is specified, bound to JNDI name "lift" by default.
 * NOTE(review): jndiName is mutable and drives equals/hashCode, so change it only before the
 * identifier is used as a map key (e.g. during boot).
 */
case object DefaultConnectionIdentifier extends ConnectionIdentifier {
  var jndiName: String = "lift"
}