Skip to content

Commit

Permalink
Miscellaneous code updates (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
srowen authored Dec 7, 2020
1 parent e583370 commit 2b4aca0
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 26 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,13 @@ from pyspark.sql import SparkSession
from pyspark.sql.types import *

spark = SparkSession.builder.getOrCreate()
customSchema = StructType([ \
StructField("_id", StringType(), True), \
StructField("author", StringType(), True), \
StructField("description", StringType(), True), \
StructField("genre", StringType(), True), \
StructField("price", DoubleType(), True), \
StructField("publish_date", StringType(), True), \
customSchema = StructType([
StructField("_id", StringType(), True),
StructField("author", StringType(), True),
StructField("description", StringType(), True),
StructField("genre", StringType(), True),
StructField("price", DoubleType(), True),
StructField("publish_date", StringType(), True),
StructField("title", StringType(), True)])

df = spark.read \
Expand Down
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ val sparkVersion = sys.props.get("spark.testVersion").getOrElse("2.4.7")
autoScalaLibrary := false

libraryDependencies ++= Seq(
"commons-io" % "commons-io" % "2.7",
"commons-io" % "commons-io" % "2.8.0",
"org.glassfish.jaxb" % "txw2" % "2.3.3",
"org.apache.ws.xmlschema" % "xmlschema-core" % "2.2.5",
"org.slf4j" % "slf4j-api" % "1.7.25" % Provided,
"org.scalatest" %% "scalatest" % "3.2.2" % Test,
"org.scalatest" %% "scalatest" % "3.2.3" % Test,
"com.novocode" % "junit-interface" % "0.11" % Test,
"org.apache.spark" %% "spark-core" % sparkVersion % Provided,
"org.apache.spark" %% "spark-sql" % sparkVersion % Provided,
Expand Down
4 changes: 2 additions & 2 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ libraryDependencies += "org.scalariform" %% "scalariform" % "0.2.10"

addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.3")

addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.1.1")

addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.6.1")

addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.0")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.1")
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private[xml] object StaxXmlParserUtils {
*/
def skipUntil(parser: XMLEventReader, eventType: Int): XMLEvent = {
var event = parser.peek
while(parser.hasNext && event.getEventType != eventType) {
while (parser.hasNext && event.getEventType != eventType) {
event = parser.nextEvent
}
event
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ private[xml] object InferSchema {
* 3. Replace any remaining null fields with string, the top type
*/
def infer(xml: RDD[String], options: XmlOptions): StructType = {
val schemaData = if (options.samplingRatio > 0.99) {
xml
} else {
val schemaData = if (options.samplingRatio < 1.0) {
xml.sample(withReplacement = false, options.samplingRatio, 1)
} else {
xml
}
// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
Expand Down
15 changes: 12 additions & 3 deletions src/test/scala/com/databricks/spark/xml/XmlSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
package com.databricks.spark.xml

import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
import java.nio.file.{Files, Path}
import java.nio.file.{Files, Path, Paths}
import java.sql.{Date, Timestamp}
import java.util.TimeZone

Expand Down Expand Up @@ -836,7 +836,7 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll {
}

private[this] def testNextedElementFromFile(xmlFile: String) = {
val lines = Source.fromFile(xmlFile).getLines.toList
val lines = getLines(Paths.get(xmlFile)).toList
val firstExpected = lines(2).trim
val lastExpected = lines(3).trim
val config = new Configuration(spark.sparkContext.hadoopConfiguration)
Expand Down Expand Up @@ -1282,7 +1282,7 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll {

val xmlFile =
Files.list(xmlPath).iterator.asScala.filter(_.getFileName.toString.startsWith("part-")).next
val firstLine = Source.fromFile(xmlFile.toFile).getLines.next
val firstLine = getLines(xmlFile).head
assert(firstLine === "<root foo=\"bar\" bing=\"baz\">")
}

Expand Down Expand Up @@ -1310,4 +1310,13 @@ final class XmlSuite extends AnyFunSuite with BeforeAndAfterAll {
assert(map.contains("M2"))
}

/**
 * Reads all lines from the file at the given path.
 *
 * Ensures the underlying [[scala.io.Source]] is closed even if reading
 * fails, avoiding the file-handle leak that inline `Source.fromFile(...).getLines`
 * calls would cause.
 *
 * @param path path of the file to read (decoded with the default charset)
 * @return all lines of the file, in order, fully materialized as a List
 */
private def getLines(path: Path): Seq[String] = {
val source = Source.fromFile(path.toFile)
try {
// toList forces full materialization before the source is closed;
// returning the lazy iterator here would fail after close()
source.getLines.toList
} finally {
source.close()
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class XSDToSchemaSuite extends AnyFunSuite {
private val resDir = "src/test/resources"

test("Basic parsing") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/basket.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/basket.xsd"))
val expectedSchema = buildSchema(
field("basket",
struct(
Expand All @@ -39,7 +39,7 @@ class XSDToSchemaSuite extends AnyFunSuite {
}

test("Relative path parsing") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/include-example/first.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/include-example/first.xsd"))
val expectedSchema = buildSchema(
field("basket",
struct(
Expand All @@ -50,7 +50,7 @@ class XSDToSchemaSuite extends AnyFunSuite {
}

test("Test schema types and attributes") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/catalog.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/catalog.xsd"))
val expectedSchema = buildSchema(
field("catalog",
struct(
Expand All @@ -73,26 +73,26 @@ class XSDToSchemaSuite extends AnyFunSuite {
}

test("Test xs:choice nullability") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/choice.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/choice.xsd"))
val expectedSchema = buildSchema(
field("el", struct(field("foo"), field("bar"), field("baz")), nullable = false))
assert(expectedSchema === parsedSchema)
}

test("Two root elements") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/twoelements.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/twoelements.xsd"))
val expectedSchema = buildSchema(field("bar", nullable = false), field("foo", nullable = false))
assert(expectedSchema === parsedSchema)
}

test("xs:any schema") {
val parsedSchema = XSDToSchema.read(Paths.get(s"${resDir}/xsany.xsd"))
val parsedSchema = XSDToSchema.read(Paths.get(s"$resDir/xsany.xsd"))
val expectedSchema = buildSchema(
field("root",
struct(
field("foo",
struct(
field("xs_any", nullable = true)),
field("xs_any")),
nullable = false),
field("bar",
struct(
Expand All @@ -104,7 +104,7 @@ class XSDToSchemaSuite extends AnyFunSuite {
nullable = false),
field("bing",
struct(
field("xs_any", nullable = true)),
field("xs_any")),
nullable = false)),
nullable = false))
assert(expectedSchema === parsedSchema)
Expand Down

0 comments on commit 2b4aca0

Please sign in to comment.