import java.sql.{Connection, DriverManager, PreparedStatement, ResultSet}
import scala.util.{Try, Success, Failure, Using}
import scala.collection.mutable.ListBuffer
/** One row of the `users` table, as read by [[UserRepository]].
  *
  * @param id         value of the `id` column
  * @param name       value of the `name` column
  * @param email      value of the `email` column
  * @param department value of the `department` column
  */
case class User(
  id: Int,
  name: String,
  email: String,
  department: String
)
/** Data-access methods for the `users` table, run over a caller-supplied JDBC connection.
  *
  * Caller-supplied values are always bound as `PreparedStatement` parameters. The only text
  * spliced into SQL is either a constant defined in this class or a value picked from one of its
  * fixed whitelists (sort columns/directions, advanced-search field names).
  *
  * Statements and result sets opened here are closed before each method returns; `connection`
  * itself is never closed. Apart from [[batchUpdateStatuses]], which reports errors as a
  * `Failure`, JDBC errors surface as a `RuntimeException("Database error: ...")` whose cause is
  * the original exception; invalid arguments raise `IllegalArgumentException`.
  *
  * @param connection open JDBC connection used for every statement
  */
class UserRepository(connection: Connection) {

  // ---- SQL fragments and whitelists (fixed text, safe to splice into queries) ----

  private val selectUsers = "SELECT id, name, email, department FROM users"

  /** Row cap applied to every list query. */
  private val maxRows = 100

  // Shared by updateUserStatus and batchUpdateStatuses so both also stamp updated_at.
  private val updateStatusSql =
    "UPDATE users SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE email = ?"

  private val validStatuses = Set("active", "inactive", "suspended", "pending")
  private val validDepartments = Set("IT", "HR", "Finance", "Marketing", "Sales")
  private val allowedSearchFields = Set("name", "department", "status", "email")
  private val sortDirections = Set("ASC", "DESC")

  // Public sort keys -> column names; "created" is exposed under a different name than its column.
  private val sortColumns = Map(
    "name" -> "name",
    "email" -> "email",
    "department" -> "department",
    "created" -> "created_at"
  )

  // ---- Validation ----

  /** Parses `userId` as a positive `Int`.
    *
    * @throws IllegalArgumentException if `userId` is null, not an integer, or not positive
    */
  private def validateUserId(userId: String): Int =
    Try(userId.toInt).toOption
      .filter(_ > 0)
      .getOrElse(throw new IllegalArgumentException(s"Invalid user ID: $userId"))

  /** Trims `email` and checks it is non-empty, contains `@`, and is at most 100 characters.
    *
    * The checks run on the trimmed value, i.e. on exactly what is returned and later bound.
    *
    * @return the trimmed address
    * @throws IllegalArgumentException if `email` is null or fails the checks
    */
  private def validateEmail(email: String): String = {
    val trimmed = Option(email).fold("")(_.trim)
    require(trimmed.nonEmpty && trimmed.contains("@") && trimmed.length <= 100,
      "Invalid email format")
    trimmed
  }

  // ---- JDBC plumbing ----

  /** Prepares `sql`, runs `body` on the statement, and closes the statement afterwards.
    *
    * @throws RuntimeException with message `Database error: ...` wrapping any non-fatal failure
    *         raised while preparing the statement or inside `body`
    */
  private def withStatement[A](sql: String)(body: PreparedStatement => A): A =
    Using(connection.prepareStatement(sql))(body) match {
      case Success(result) => result
      case Failure(ex)     => throw new RuntimeException(s"Database error: ${ex.getMessage}", ex)
    }

  /** Maps the row `rs` is currently positioned on to a [[User]]. */
  private def readUser(rs: ResultSet): User =
    User(rs.getInt("id"), rs.getString("name"), rs.getString("email"), rs.getString("department"))

  /** Executes the already-bound `pstmt` and reads every returned row, closing the result set. */
  private def readAllUsers(pstmt: PreparedStatement): List[User] =
    Using.resource(pstmt.executeQuery()) { rs =>
      val users = List.newBuilder[User]
      while (rs.next()) users += readUser(rs)
      users.result()
    }

  /** Runs a read query, binding `params` positionally (1-based) as strings. */
  private def queryUsers(sql: String, params: Seq[String] = Nil): List[User] =
    withStatement(sql) { pstmt =>
      params.zipWithIndex.foreach { case (param, index) => pstmt.setString(index + 1, param) }
      readAllUsers(pstmt)
    }

  /** Builds `SELECT ... FROM users [WHERE c1 AND c2 ...] LIMIT maxRows`.
    *
    * `conditions` are spliced verbatim, so callers must only pass fixed or whitelisted text.
    */
  private def selectWhere(conditions: Seq[String]): String = {
    val where = if (conditions.isEmpty) "" else conditions.mkString(" WHERE ", " AND ", "")
    s"$selectUsers$where LIMIT $maxRows"
  }

  /** Escapes LIKE metacharacters with `!` for use with `ESCAPE '!'` (`!` itself goes first). */
  private def escapeLike(raw: String): String =
    raw.replace("!", "!!").replace("%", "!%").replace("_", "!_")

  /** Runs `body` inside a transaction.
    *
    * Auto-commit is switched off, the work is committed on success or rolled back on any
    * non-fatal failure (which is rethrown, with a rollback failure attached as suppressed rather
    * than replacing it), and the previous auto-commit mode is restored in every case.
    */
  private def inTransaction[A](body: => A): A = {
    import scala.util.control.NonFatal
    val previousAutoCommit = connection.getAutoCommit
    connection.setAutoCommit(false)
    try {
      val result = body
      connection.commit()
      result
    } catch {
      case NonFatal(ex) =>
        try connection.rollback()
        catch { case NonFatal(rollbackEx) => ex.addSuppressed(rollbackEx) }
        throw ex
    } finally {
      connection.setAutoCommit(previousAutoCommit)
    }
  }

  // ---- Public API ----

  /** Looks up a single user by id.
    *
    * @param userIdStr decimal string holding a positive id
    * @return the user with that id, or `None` if there is none
    * @throws IllegalArgumentException if `userIdStr` is not a positive integer
    * @throws RuntimeException wrapping the JDBC error if the query fails
    */
  def getUserById(userIdStr: String): Option[User] = {
    val userId = validateUserId(userIdStr)
    withStatement(s"$selectUsers WHERE id = ?") { pstmt =>
      pstmt.setInt(1, userId)
      Using.resource(pstmt.executeQuery()) { rs =>
        if (rs.next()) Some(readUser(rs)) else None
      }
    }
  }

  /** Case-insensitive substring search on name, optionally restricted to one department.
    *
    * `%`, `_` and `!` in `name` match literally, so user input cannot act as a LIKE wildcard.
    * NOTE(review): `ILIKE` is a vendor extension rather than standard SQL; confirm the target
    * database supports `ILIKE ... ESCAPE`.
    *
    * @param name substring to search for (at most 100 characters); empty matches every name
    * @param department one of the whitelisted departments, or empty for any department
    * @return up to 100 matching users
    * @throws IllegalArgumentException if `name` is too long or `department` is not whitelisted
    * @throws RuntimeException wrapping the JDBC error if the query fails
    */
  def searchUsers(name: String, department: String): List[User] = {
    require(name.length <= 100, "Search name too long")
    require(department.isEmpty || validDepartments.contains(department),
      s"Invalid department: $department")

    // (condition, bound value) pairs; keeping them together guarantees placeholder order.
    val filters = Seq(
      Option.when(name.nonEmpty)("name ILIKE ? ESCAPE '!'" -> s"%${escapeLike(name)}%"),
      Option.when(department.nonEmpty)("department = ?" -> department)
    ).flatten
    queryUsers(selectWhere(filters.map(_._1)), filters.map(_._2))
  }

  /** Sets the status of the user with the given email and stamps `updated_at`.
    *
    * @param userEmail address of the user to update; surrounding whitespace is ignored
    * @param newStatus one of `active`, `inactive`, `suspended`, `pending`
    * @return `true` if at least one row was updated
    * @throws IllegalArgumentException if the email or status is invalid
    * @throws RuntimeException wrapping the JDBC error if the update fails
    */
  def updateUserStatus(userEmail: String, newStatus: String): Boolean = {
    val email = validateEmail(userEmail)
    require(validStatuses.contains(newStatus), s"Invalid status: $newStatus")
    withStatement(updateStatusSql) { pstmt =>
      pstmt.setString(1, newStatus)
      pstmt.setString(2, email)
      pstmt.executeUpdate() > 0
    }
  }

  /** Lists users ordered by a whitelisted column.
    *
    * @param sortColumn one of `name`, `email`, `department`, `created`
    * @param direction `ASC` or `DESC`, case-insensitive
    * @return up to 100 users; rows with equal sort values are ordered by `id`
    * @throws IllegalArgumentException if the column or direction is not whitelisted
    * @throws RuntimeException wrapping the JDBC error if the query fails
    */
  def getUsersSorted(sortColumn: String, direction: String): List[User] = {
    val column = sortColumns.getOrElse(sortColumn,
      throw new IllegalArgumentException(s"Invalid sort column: $sortColumn"))
    val dir = direction.toUpperCase(java.util.Locale.ROOT)
    require(sortDirections.contains(dir), s"Invalid sort direction: $direction")
    // Safe to splice: column and dir are whitelist members. `id` breaks ties so rows sharing a
    // sort value come back in a stable order under the LIMIT.
    queryUsers(s"$selectUsers ORDER BY $column $dir, id LIMIT $maxRows")
  }

  /** Equality search over whitelisted fields, combined with AND.
    *
    * `status` may be filtered on even though it is not part of the returned [[User]].
    *
    * @param filters field name -> required value; keys must be `name`, `department`, `status`
    *                or `email`; an empty map returns up to 100 users unfiltered
    * @return up to 100 matching users
    * @throws IllegalArgumentException if any key is not whitelisted
    * @throws RuntimeException wrapping the JDBC error if the query fails
    */
  def advancedSearch(filters: Map[String, String]): List[User] = {
    filters.keys.foreach { field =>
      require(allowedSearchFields.contains(field), s"Invalid search field: $field")
    }
    // Snapshot the entries once so condition order and bind order come from the same sequence.
    val entries = filters.toSeq
    // Safe to splice: every field name was checked against the whitelist above.
    queryUsers(selectWhere(entries.map { case (field, _) => s"$field = ?" }), entries.map(_._2))
  }

  /** Applies several status updates atomically.
    *
    * Every entry is validated before the connection is touched. The updates are then sent as one
    * JDBC batch and committed together; on failure the transaction is rolled back. The
    * connection's previous auto-commit mode is restored afterwards. An empty list succeeds with
    * 0 without touching the connection.
    *
    * @param updates `(email, status)` pairs; statuses must be `active`, `inactive`, `suspended`
    *                or `pending`
    * @return `Success` with the total rows reported as updated (driver sentinels such as
    *         `Statement.SUCCESS_NO_INFO` contribute nothing), or `Failure` carrying the
    *         validation or JDBC error
    */
  def batchUpdateStatuses(updates: List[(String, String)]): Try[Int] = Try {
    // Validate up front so a bad entry never leaves auto-commit disabled or a half-built batch.
    val validated = updates.map { case (email, status) =>
      val validatedEmail = validateEmail(email)
      require(validStatuses.contains(status), s"Invalid status: $status")
      (validatedEmail, status)
    }
    if (validated.isEmpty) 0
    else inTransaction {
      Using.resource(connection.prepareStatement(updateStatusSql)) { pstmt =>
        validated.foreach { case (email, status) =>
          pstmt.setString(1, status)
          pstmt.setString(2, email)
          pstmt.addBatch()
        }
        // Negative entries are driver sentinels (e.g. SUCCESS_NO_INFO = -2), not row counts.
        pstmt.executeBatch().iterator.filter(_ >= 0).sum
      }
    }
  }
}