From 2b4aca073d7625aca3812858eeff43c0baf3d5aa Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 7 Dec 2020 10:11:03 -0600 Subject: [PATCH] Miscellaneous code updates (#506) --- README.md | 14 +++++++------- build.sbt | 4 ++-- project/plugins.sbt | 4 ++-- .../spark/xml/parsers/StaxXmlParserUtils.scala | 2 +- .../databricks/spark/xml/util/InferSchema.scala | 6 +++--- .../com/databricks/spark/xml/XmlSuite.scala | 15 ++++++++++++--- .../spark/xml/util/XSDToSchemaSuite.scala | 16 ++++++++-------- 7 files changed, 35 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index b5258d7c..85d6a8ec 100644 --- a/README.md +++ b/README.md @@ -383,13 +383,13 @@ from pyspark.sql import SparkSession from pyspark.sql.types import * spark = SparkSession.builder.getOrCreate() -customSchema = StructType([ \ - StructField("_id", StringType(), True), \ - StructField("author", StringType(), True), \ - StructField("description", StringType(), True), \ - StructField("genre", StringType(), True), \ - StructField("price", DoubleType(), True), \ - StructField("publish_date", StringType(), True), \ +customSchema = StructType([ + StructField("_id", StringType(), True), + StructField("author", StringType(), True), + StructField("description", StringType(), True), + StructField("genre", StringType(), True), + StructField("price", DoubleType(), True), + StructField("publish_date", StringType(), True), StructField("title", StringType(), True)]) df = spark.read \ diff --git a/build.sbt b/build.sbt index a4a637a4..cd2e987b 100755 --- a/build.sbt +++ b/build.sbt @@ -16,11 +16,11 @@ val sparkVersion = sys.props.get("spark.testVersion").getOrElse("2.4.7") autoScalaLibrary := false libraryDependencies ++= Seq( - "commons-io" % "commons-io" % "2.7", + "commons-io" % "commons-io" % "2.8.0", "org.glassfish.jaxb" % "txw2" % "2.3.3", "org.apache.ws.xmlschema" % "xmlschema-core" % "2.2.5", "org.slf4j" % "slf4j-api" % "1.7.25" % Provided, - "org.scalatest" %% "scalatest" % "3.2.2" % Test, + "org.scalatest" %% "scalatest" % "3.2.3" % Test, "com.novocode" % "junit-interface" % "0.11" % Test, "org.apache.spark" %% "spark-core" % sparkVersion % Provided, "org.apache.spark" %% "spark-sql" % sparkVersion % Provided, diff --git a/project/plugins.sbt b/project/plugins.sbt index 6831f88a..5b29740f 100755 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -6,8 +6,8 @@ libraryDependencies += "org.scalariform" %% "scalariform" % "0.2.10" addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.3") -addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") +addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.1.1") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.1") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.0") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.1") diff --git a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala index 0c223d4d..60c659b0 100644 --- a/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala +++ b/src/main/scala/com/databricks/spark/xml/parsers/StaxXmlParserUtils.scala @@ -61,7 +61,7 @@ private[xml] object StaxXmlParserUtils { */ def skipUntil(parser: XMLEventReader, eventType: Int): XMLEvent = { var event = parser.peek - while(parser.hasNext && event.getEventType != eventType) { + while (parser.hasNext && event.getEventType != eventType) { event = parser.nextEvent } event diff --git a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala index a4a4bfa3..caa84819 100644 --- a/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala +++ b/src/main/scala/com/databricks/spark/xml/util/InferSchema.scala @@ -68,10 +68,10 @@ private[xml] object InferSchema { * 3. Replace any remaining null fields with string, the top type */ def infer(xml: RDD[String], options: XmlOptions): StructType = { - val schemaData = if (options.samplingRatio > 0.99) { - xml - } else { + val schemaData = if (options.samplingRatio < 1.0) { xml.sample(withReplacement = false, options.samplingRatio, 1) + } else { + xml } // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => diff --git a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala index 45c4a523..73325305 100755 --- a/src/test/scala/com/databricks/spark/xml/XmlSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/XmlSuite.scala @@ -16,7 +16,7 @@ package com.databricks.spark.xml import java.nio.charset.{StandardCharsets, UnsupportedCharsetException} -import java.nio.file.{Files, Path} +import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} import java.util.TimeZone @@ -836,7 +836,7 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll { } private[this] def testNextedElementFromFile(xmlFile: String) = { - val lines = Source.fromFile(xmlFile).getLines.toList + val lines = getLines(Paths.get(xmlFile)).toList val firstExpected = lines(2).trim val lastExpected = lines(3).trim val config = new Configuration(spark.sparkContext.hadoopConfiguration) @@ -1282,7 +1282,7 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll { val xmlFile = Files.list(xmlPath).iterator.asScala.filter(_.getFileName.toString.startsWith("part-")).next - val firstLine = Source.fromFile(xmlFile.toFile).getLines.next + val firstLine = getLines(xmlFile).head assert(firstLine === "") } @@ -1310,4 +1310,13 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll { assert(map.contains("M2")) } + private def getLines(path: Path): Seq[String] = { + val source = Source.fromFile(path.toFile) + try { + source.getLines.toList + } finally { + source.close() + } + } + } diff --git a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala index f5a71bea..2ade2f63 100644 --- a/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala +++ b/src/test/scala/com/databricks/spark/xml/util/XSDToSchemaSuite.scala @@ -28,7 +28,7 @@ class XSDToSchemaSuite extends AnyFunSuite { private val resDir = "src/test/resources" test("Basic parsing") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/basket.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/basket.xsd")) val expectedSchema = buildSchema( field("basket", struct( @@ -39,7 +39,7 @@ class XSDToSchemaSuite extends AnyFunSuite { } test("Relative path parsing") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/include-example/first.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/include-example/first.xsd")) val expectedSchema = buildSchema( field("basket", struct( @@ -50,7 +50,7 @@ class XSDToSchemaSuite extends AnyFunSuite { } test("Test schema types and attributes") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/catalog.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/catalog.xsd")) val expectedSchema = buildSchema( field("catalog", struct( @@ -73,26 +73,26 @@ class XSDToSchemaSuite extends AnyFunSuite { } test("Test xs:choice nullability") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/choice.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/choice.xsd")) val expectedSchema = buildSchema( field("el", struct(field("foo"), field("bar"), field("baz")), nullable = false)) assert(expectedSchema === parsedSchema) } test("Two root elements") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/twoelements.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/twoelements.xsd")) val expectedSchema = buildSchema(field("bar", nullable = false), field("foo", nullable = false)) assert(expectedSchema === parsedSchema) } test("xs:any schema") { - val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/xsany.xsd")) + val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/xsany.xsd")) val expectedSchema = buildSchema( field("root", struct( field("foo", struct( - field("xs_any", nullable = true)), + field("xs_any")), nullable = false), field("bar", struct( @@ -104,7 +104,7 @@ class XSDToSchemaSuite extends AnyFunSuite { nullable = false), field("bing", struct( - field("xs_any", nullable = true)), + field("xs_any")), nullable = false)), nullable = false)) assert(expectedSchema === parsedSchema)