Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ package za.co.absa.standardization.udf
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.DataType
import za.co.absa.standardization.config.StandardizationConfig
import za.co.absa.standardization.config.{
BasicErrorCodesConfig,
ErrorCodesConfig,
StandardizationConfig
}
import za.co.absa.standardization.types.parsers.NumericParser
import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException

Expand All @@ -39,9 +43,25 @@ object UDFBuilder {
val vColumnNameForError = columnNameForError
val vDefaultValue = defaultValue
val vColumnNullable = columnNullable
val vStdConfig = stdConfig
val vErrorCodes = BasicErrorCodesConfig(
stdConfig.errorCodes.castError,
stdConfig.errorCodes.nullError,
stdConfig.errorCodes.typeError,
stdConfig.errorCodes.schemaError
)

udf[UDFResult[T], String](numericParserToTyped(_, sourceDataType, targetDataType, vParser, vColumnNullable, vColumnNameForError, vStdConfig, vDefaultValue))
udf[UDFResult[T], String](
numericParserToTyped(
_,
sourceDataType,
targetDataType,
vParser,
vColumnNullable,
vColumnNameForError,
vErrorCodes,
vDefaultValue
)
)
}

private def numericParserToTyped[T](input: String,
Expand All @@ -50,14 +70,14 @@ object UDFBuilder {
parser: NumericParser[T],
columnNullable: Boolean,
columnNameForError: String,
stdConfig: StandardizationConfig,
errorCodes: ErrorCodesConfig,
defaultValue: Option[T]): UDFResult[T] = {
val result = Option(input) match {
case Some(string) => parser.parse(string).map(Some(_))
case None if columnNullable => Success(None)
case None => Failure(nullException)
}
UDFResult.fromTry(result, columnNameForError, input, sourceDataType.typeName, targetDataType.typeName, None, stdConfig, defaultValue)
UDFResult.fromTry(result, columnNameForError, input, sourceDataType.typeName, targetDataType.typeName, None, errorCodes, defaultValue)
}

private val nullException = new NumericParserException("Null value on input for non-nullable field")
Expand Down
17 changes: 14 additions & 3 deletions src/main/scala/za/co/absa/standardization/udf/UDFResult.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package za.co.absa.standardization.udf

import za.co.absa.standardization.ErrorMessage
import za.co.absa.standardization.StandardizationErrorMessage
import za.co.absa.standardization.config.StandardizationConfig
import za.co.absa.standardization.config.{ErrorCodesConfig, StandardizationConfig}

import scala.util.{Failure, Success, Try}

Expand All @@ -38,11 +38,22 @@ object UDFResult {
pattern: Option[String],
stdConfig: StandardizationConfig,
defaultValue: Option[T] = None): UDFResult[T] = {
fromTry(result, columnName, rawValue, sourceType, targetType, pattern, stdConfig.errorCodes, defaultValue)
}

def fromTry[T](result: Try[Option[T]],
columnName: String,
rawValue: String,
sourceType: String,
targetType: String,
pattern: Option[String],
errorCodes: ErrorCodesConfig,
defaultValue: Option[T] = None): UDFResult[T] = {
result match {
case Success(success) => UDFResult.success(success)
case Failure(_) if Option(rawValue).isEmpty => UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdNullErr(columnName)(stdConfig.errorCodes)))
case Failure(_) if Option(rawValue).isEmpty => UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdNullErr(columnName)(errorCodes)))
case Failure(_) =>
UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdCastErr(columnName, rawValue, sourceType, targetType, pattern)(stdConfig.errorCodes)))
UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdCastErr(columnName, rawValue, sourceType, targetType, pattern)(errorCodes)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.scalatest.funsuite.AnyFunSuite
import za.co.absa.standardization.RecordIdGeneration.IdType.NoId
import za.co.absa.standardization.config.{BasicMetadataColumnsConfig, BasicStandardizationConfig, StandardizationConfig}
import za.co.absa.standardization.config.{
BasicMetadataColumnsConfig,
BasicStandardizationConfig,
DefaultStandardizationConfig,
StandardizationConfig
}
import za.co.absa.standardization.schema.MetadataKeys
import za.co.absa.standardization.types.TypedStructField._
import za.co.absa.standardization.types.parsers.IntegralParser.{PatternIntegralParser, RadixIntegralParser}
Expand Down Expand Up @@ -144,4 +149,38 @@ class UDFBuilderSuite extends AnyFunSuite {
ois.readObject().asInstanceOf[UserDefinedFunction]
}

test("Serialization and deserialization of stringUdfViaNumericParser with default config") {
val fieldName = "test"
val field: StructField = StructField(fieldName, IntegerType, nullable = true, new MetadataBuilder()
.putString(MetadataKeys.Pattern, "000000")
.build)
val typedField = TypedStructField(field)

val numericTypeField = typedField.asInstanceOf[NumericTypeStructField[Int]]
val defaultValue: Option[Int] = typedField.defaultValueWithGlobal.get.map(_.asInstanceOf[Int])
val parser = numericTypeField.parser.get.asInstanceOf[PatternIntegralParser[Int]]
val udfFnc = UDFBuilder.stringUdfViaNumericParser(
StringType,
field.dataType,
parser,
numericTypeField.nullable,
fieldName,
DefaultStandardizationConfig,
defaultValue
)
//write
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(udfFnc)
oos.flush()
val serialized = baos.toByteArray
assert(serialized.nonEmpty)
//read
val ois = new ObjectInputStream(new ByteArrayInputStream(serialized)) {
override def resolveClass(desc: ObjectStreamClass): Class[_] =
Class.forName(desc.getName, false, loader)
}
ois.readObject().asInstanceOf[UserDefinedFunction]
}

}