diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 1f0432909..67aefb392 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -3203,6 +3203,7 @@ impl fmt::Display for CreateTable { Some(HiveIOFormat::FileFormat { format }) if !self.external => { write!(f, " STORED AS {format}")? } + Some(HiveIOFormat::Using { format }) => write!(f, " USING {format}")?, _ => (), } if let Some(serde_properties) = serde_properties.as_ref() { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 63b3db644..c10b383eb 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -8658,6 +8658,15 @@ pub enum HiveIOFormat { /// The file format used for storage. format: FileFormat, }, + /// `USING ` syntax used by Spark SQL. + /// + /// Example: `CREATE TABLE t (i INT) USING PARQUET` + /// + /// See + Using { + /// The data source or format name, e.g. `parquet`, `delta`, `csv`. + format: Ident, + }, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Default)] diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 6439ef2c8..0e2a6158d 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -28,6 +28,7 @@ mod oracle; mod postgresql; mod redshift; mod snowflake; +mod spark; mod sqlite; use core::any::{Any, TypeId}; @@ -51,6 +52,7 @@ pub use self::postgresql::PostgreSqlDialect; pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::parse_snowflake_stage_name; pub use self::snowflake::SnowflakeDialect; +pub use self::spark::SparkSqlDialect; pub use self::sqlite::SQLiteDialect; /// Macro for streamlining the creation of derived `Dialect` objects. @@ -1727,6 +1729,42 @@ pub trait Dialect: Debug + Any { fn supports_xml_expressions(&self) -> bool { false } + + /// Returns true if the dialect supports `USING ` in `CREATE TABLE`. + /// + /// Example: + /// ```sql + /// CREATE TABLE t (i INT) USING PARQUET + /// ``` + /// + /// [Spark SQL](https://spark.apache.org/docs/latest/sql-ref-syntax-ddl-create-table-datasource.html) + fn supports_create_table_using(&self) -> bool { + false + } + + /// Returns true if the dialect treats `LONG` as an alias for `BIGINT`. + /// + /// Example: + /// ```sql + /// CREATE TABLE t (id LONG) + /// ``` + /// + /// [Spark SQL](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) + fn supports_long_type_as_bigint(&self) -> bool { + false + } + + /// Returns true if the dialect supports `MAP` angle-bracket syntax for the MAP data type. + /// + /// Example: + /// ```sql + /// CREATE TABLE t (m MAP) + /// ``` + /// + /// [Spark SQL](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) + fn supports_map_literal_with_angle_brackets(&self) -> bool { + false + } } /// Operators for which precedence must be defined. @@ -1801,6 +1839,7 @@ pub fn dialect_from_str(dialect_name: impl AsRef) -> Option Some(Box::new(AnsiDialect {})), "duckdb" => Some(Box::new(DuckDbDialect {})), "databricks" => Some(Box::new(DatabricksDialect {})), + "spark" | "sparksql" => Some(Box::new(SparkSqlDialect {})), "oracle" => Some(Box::new(OracleDialect {})), _ => None, } diff --git a/src/dialect/spark.rs b/src/dialect/spark.rs new file mode 100644 index 000000000..e14b4d033 --- /dev/null +++ b/src/dialect/spark.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(not(feature = "std"))] +use alloc::boxed::Box; + +use crate::ast::{BinaryOperator, Expr}; +use crate::dialect::Dialect; +use crate::keywords::Keyword; +use crate::parser::{Parser, ParserError}; + +/// A [`Dialect`] for [Apache Spark SQL](https://spark.apache.org/docs/latest/sql-ref.html). +/// +/// See . +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct SparkSqlDialect; + +impl Dialect for SparkSqlDialect { + // See https://spark.apache.org/docs/latest/sql-ref-identifier.html + fn is_delimited_identifier_start(&self, ch: char) -> bool { + matches!(ch, '`') + } + + fn is_identifier_start(&self, ch: char) -> bool { + matches!(ch, 'a'..='z' | 'A'..='Z' | '_') + } + + fn is_identifier_part(&self, ch: char) -> bool { + matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_') + } + + /// See + fn supports_filter_during_aggregation(&self) -> bool { + true + } + + /// See + fn supports_group_by_expr(&self) -> bool { + true + } + + /// See + fn supports_group_by_with_modifier(&self) -> bool { + true + } + + /// See + fn supports_lambda_functions(&self) -> bool { + true + } + + /// See + fn supports_select_wildcard_except(&self) -> bool { + true + } + + /// See + fn supports_struct_literal(&self) -> bool { + true + } + + fn supports_nested_comments(&self) -> bool { + true + } + + /// See + fn supports_create_table_using(&self) -> bool { + true + } + + /// `LONG` is an alias for `BIGINT` in Spark SQL. + /// + /// See + fn supports_long_type_as_bigint(&self) -> bool { + true + } + + /// See + fn supports_values_as_table_factor(&self) -> bool { + true + } + + fn require_interval_qualifier(&self) -> bool { + true + } + + fn supports_bang_not_operator(&self) -> bool { + true + } + + fn supports_select_item_multi_column_alias(&self) -> bool { + true + } + + fn supports_cte_without_as(&self) -> bool { + true + } + + /// See + fn supports_map_literal_with_angle_brackets(&self) -> bool { + true + } + + /// Parse the `DIV` keyword as integer division. + /// + /// Example: `SELECT 10 DIV 3` returns `3`. + /// + /// See + fn parse_infix( + &self, + parser: &mut Parser, + expr: &Expr, + _precedence: u8, + ) -> Option> { + if parser.parse_keyword(Keyword::DIV) { + let left = Box::new(expr.clone()); + let right = Box::new(match parser.parse_expr() { + Ok(expr) => expr, + Err(e) => return Some(Err(e)), + }); + Some(Ok(Expr::BinaryOp { + left, + op: BinaryOperator::MyIntegerDivide, + right, + })) + } else { + None + } + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a5526723b..34b87fc76 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -8311,6 +8311,7 @@ impl<'a> Parser<'a> { Keyword::STORED, Keyword::LOCATION, Keyword::WITH, + Keyword::USING, ]) { Some(Keyword::ROW) => { hive_format @@ -8350,6 +8351,16 @@ impl<'a> Parser<'a> { break; } } + Some(Keyword::USING) if self.dialect.supports_create_table_using() => { + let format = self.parse_identifier()?; + hive_format.get_or_insert_with(HiveFormat::default).storage = + Some(HiveIOFormat::Using { format }); + } + Some(Keyword::USING) => { + // USING is not a table format keyword in this dialect; put it back + self.prev_token(); + break; + } None => break, _ => break, } @@ -12475,6 +12486,9 @@ impl<'a> Parser<'a> { Keyword::TINYBLOB => Ok(DataType::TinyBlob), Keyword::MEDIUMBLOB => Ok(DataType::MediumBlob), Keyword::LONGBLOB => Ok(DataType::LongBlob), + Keyword::LONG if self.dialect.supports_long_type_as_bigint() => { + Ok(DataType::BigInt(None)) + } Keyword::BYTES => Ok(DataType::Bytes(self.parse_optional_precision()?)), Keyword::BIT => { if self.parse_keyword(Keyword::VARYING) { @@ -12609,8 +12623,7 @@ impl<'a> Parser<'a> { let field_defs = self.parse_duckdb_struct_type_def()?; Ok(DataType::Struct(field_defs, StructBracketKind::Parentheses)) } - Keyword::STRUCT if dialect_is!(dialect is BigQueryDialect | DatabricksDialect | GenericDialect) => - { + Keyword::STRUCT if self.dialect.supports_struct_literal() => { self.prev_token(); let (field_defs, _trailing_bracket) = self.parse_struct_type_def(Self::parse_struct_field_def)?; @@ -12631,6 +12644,17 @@ impl<'a> Parser<'a> { Keyword::LOWCARDINALITY if dialect_is!(dialect is ClickHouseDialect | GenericDialect) => { Ok(self.parse_sub_type(DataType::LowCardinality)?) } + Keyword::MAP if self.dialect.supports_map_literal_with_angle_brackets() => { + self.expect_token(&Token::Lt)?; + let key_data_type = self.parse_data_type()?; + self.expect_token(&Token::Comma)?; + let (value_data_type, _trailing_bracket) = self.parse_data_type_helper()?; + trailing_bracket = self.expect_closing_angle_bracket(_trailing_bracket)?; + Ok(DataType::Map( + Box::new(key_data_type), + Box::new(value_data_type), + )) + } Keyword::MAP if dialect_is!(dialect is ClickHouseDialect | GenericDialect) => { self.prev_token(); let (key_data_type, value_data_type) = self.parse_click_house_map_def()?; diff --git a/tests/sqlparser_spark.rs b/tests/sqlparser_spark.rs new file mode 100644 index 000000000..3e1886c1d --- /dev/null +++ b/tests/sqlparser_spark.rs @@ -0,0 +1,329 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![warn(clippy::all)] +//! Test SQL syntax specific to Apache Spark SQL. + +use sqlparser::ast::*; +use sqlparser::dialect::SparkSqlDialect; +use test_utils::*; + +#[macro_use] +mod test_utils; + +fn spark() -> TestedDialects { + TestedDialects::new(vec![Box::new(SparkSqlDialect {})]) +} + +// -------------------------------- +// CREATE TABLE USING +// -------------------------------- + +#[test] +fn test_create_table_using() { + let stmt = spark().verified_stmt("CREATE TABLE t (i INT, s STRING) USING parquet"); + match stmt { + Statement::CreateTable(ct) => { + assert_eq!(ct.name.to_string(), "t"); + assert_eq!(ct.columns.len(), 2); + assert_eq!( + ct.hive_formats.unwrap().storage, + Some(HiveIOFormat::Using { + format: Ident::new("parquet") + }) + ); + } + _ => panic!("Expected CreateTable"), + } +} + +#[test] +fn test_create_table_using_if_not_exists() { + spark().verified_stmt("CREATE TABLE IF NOT EXISTS t (i INT) USING delta"); +} + +#[test] +fn test_create_table_using_with_location() { + spark().verified_stmt("CREATE TABLE t (i INT) USING parquet LOCATION '/data/t'"); +} + +#[test] +fn test_create_table_multi_column() { + spark().verified_stmt( + "CREATE TABLE t (i INT, l BIGINT, f FLOAT, d DOUBLE, s STRING, b BOOLEAN) USING parquet", + ); +} + +#[test] +fn test_create_table_long_type() { + // LONG is an alias for BIGINT; round-trips as BIGINT + spark().one_statement_parses_to( + "CREATE TABLE t (id LONG, val LONG) USING parquet", + "CREATE TABLE t (id BIGINT, val BIGINT) USING parquet", + ); +} + +#[test] +fn test_create_table_array_type() { + spark().verified_stmt("CREATE TABLE t (arr ARRAY) USING parquet"); +} + +#[test] +fn test_create_table_map_type() { + // MAP parses and stores as DataType::Map (which displays as Map(K, V)) + spark() + .parse_sql_statements("CREATE TABLE t (m MAP) USING parquet") + .unwrap(); +} + +#[test] +fn test_create_table_struct_type() { + // STRUCT field definitions drop the colon separator on round-trip + spark().one_statement_parses_to( + "CREATE TABLE t (s STRUCT) USING parquet", + "CREATE TABLE t (s STRUCT) USING parquet", + ); +} + +#[test] +fn test_create_table_nested_types() { + // Nested types parse successfully + spark() + .parse_sql_statements( + "CREATE TABLE t (arr ARRAY>) USING parquet", + ) + .unwrap(); + spark() + .parse_sql_statements("CREATE TABLE t (m MAP, arr ARRAY) USING parquet") + .unwrap(); +} + +#[test] +fn test_create_table_decimal_type() { + spark() + .verified_stmt("CREATE TABLE t (grp STRING, d DECIMAL(10,2), flag BOOLEAN) USING parquet"); +} + +// -------------------------------- +// INSERT INTO +// -------------------------------- + +#[test] +fn test_insert_values() { + spark().verified_stmt( + "INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c'), (NULL, 'd'), (1, NULL), (NULL, NULL)", + ); +} + +#[test] +fn test_insert_values_multiline() { + // Multi-line whitespace is normalized to single-line on round-trip + spark().one_statement_parses_to( + "INSERT INTO t VALUES\n (1, 10, 'a'),\n (2, 20, 'a'),\n (3, 30, 'b')", + "INSERT INTO t VALUES (1, 10, 'a'), (2, 20, 'a'), (3, 30, 'b')", + ); +} + +// -------------------------------- +// Lambda expressions +// -------------------------------- + +#[test] +fn test_lambda_single_param() { + spark().verified_stmt("SELECT filter(arr, x -> x > 2) FROM t"); +} + +#[test] +fn test_lambda_two_params() { + spark().verified_stmt("SELECT filter(arr, (x, i) -> i > 0) FROM t"); +} + +#[test] +fn test_lambda_transform() { + spark().verified_stmt("SELECT transform(arr, x -> x * 2) FROM t"); +} + +// -------------------------------- +// DIV integer division +// -------------------------------- + +#[test] +fn test_div_operator() { + spark().one_statement_parses_to("SELECT c1 div c2 FROM t", "SELECT c1 DIV c2 FROM t"); +} + +#[test] +fn test_div_literal() { + spark().one_statement_parses_to("SELECT 10 div 3", "SELECT 10 DIV 3"); +} + +// -------------------------------- +// Struct support +// -------------------------------- + +#[test] +fn test_named_struct() { + spark().verified_stmt("SELECT named_struct('x', a, 'y', b, 'z', c) FROM t"); +} + +#[test] +fn test_struct_function() { + // Parses as a STRUCT literal; round-trips with uppercase STRUCT keyword + spark().one_statement_parses_to( + "SELECT struct(a, b, c) FROM t", + "SELECT STRUCT(a, b, c) FROM t", + ); +} + +// -------------------------------- +// Aggregate FILTER +// -------------------------------- + +#[test] +fn test_aggregate_filter() { + spark().verified_stmt( + "SELECT COUNT(*) FILTER (WHERE i > 0), SUM(val) FILTER (WHERE val IS NOT NULL) FROM t", + ); +} + +#[test] +fn test_aggregate_filter_with_group_by() { + spark().verified_stmt( + "SELECT grp, SUM(i) FILTER (WHERE flag = true) FROM t GROUP BY grp ORDER BY grp", + ); +} + +// -------------------------------- +// Window functions with IGNORE NULLS +// -------------------------------- + +#[test] +fn test_lag_ignore_nulls() { + spark().verified_stmt("SELECT LAG(val) IGNORE NULLS OVER (ORDER BY id) AS lag_val FROM t"); +} + +#[test] +fn test_lead_ignore_nulls() { + spark().verified_stmt( + "SELECT LEAD(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) AS lead_val FROM t", + ); +} + +#[test] +fn test_lag_with_offset_and_default() { + spark().verified_stmt("SELECT LAG(val, 2, -1) OVER (ORDER BY id) AS lag_val FROM t"); +} + +// -------------------------------- +// CASE WHEN +// -------------------------------- + +#[test] +fn test_case_when() { + spark().verified_stmt( + "SELECT CASE WHEN i = 1 THEN 'one' WHEN i = 2 THEN 'two' ELSE 'other' END FROM t", + ); +} + +#[test] +fn test_case_value() { + spark().verified_stmt("SELECT CASE i WHEN 1 THEN 'one' WHEN 2 THEN 'two' END FROM t"); +} + +// -------------------------------- +// CAST expressions +// -------------------------------- + +#[test] +fn test_cast_basic_types() { + // cast() lower-case round-trips as CAST() upper-case + spark().one_statement_parses_to( + "SELECT cast(i AS BIGINT), cast(i AS DOUBLE), cast(i AS STRING) FROM t", + "SELECT CAST(i AS BIGINT), CAST(i AS DOUBLE), CAST(i AS STRING) FROM t", + ); +} + +#[test] +fn test_cast_to_timestamp() { + spark().one_statement_parses_to( + "SELECT cast('2020-01-01' AS TIMESTAMP)", + "SELECT CAST('2020-01-01' AS TIMESTAMP)", + ); + spark().one_statement_parses_to( + "SELECT cast('2020-01-01T12:34:56' AS TIMESTAMP)", + "SELECT CAST('2020-01-01T12:34:56' AS TIMESTAMP)", + ); +} + +#[test] +fn test_cast_special_float_values() { + spark().one_statement_parses_to( + "SELECT cast('NaN' AS FLOAT), cast('Infinity' AS DOUBLE)", + "SELECT CAST('NaN' AS FLOAT), CAST('Infinity' AS DOUBLE)", + ); +} + +// -------------------------------- +// Aggregate functions +// -------------------------------- + +#[test] +fn test_count_aggregate() { + spark().verified_stmt("SELECT count(*), count(i), count(s) FROM t"); + spark().verified_stmt("SELECT grp, count(*), count(i) FROM t GROUP BY grp ORDER BY grp"); +} + +#[test] +fn test_sum_avg() { + spark().verified_stmt("SELECT avg(i), avg(l), avg(f), avg(d) FROM t"); +} + +#[test] +fn test_bit_aggregates() { + spark().verified_stmt("SELECT bit_and(i), bit_or(i), bit_xor(i) FROM t"); +} + +// -------------------------------- +// Arithmetic +// -------------------------------- + +#[test] +fn test_arithmetic_operators() { + spark().verified_stmt("SELECT a + b, a - b, a * b, a / b, a % b FROM t"); +} + +#[test] +fn test_unary_negative() { + spark().verified_stmt("SELECT negative(col1), -(col1) FROM t"); +} + +// -------------------------------- +// String operations +// -------------------------------- + +#[test] +fn test_like_pattern() { + spark().verified_stmt("SELECT s FROM t WHERE s LIKE 'foo%'"); +} + +#[test] +fn test_substring() { + spark().one_statement_parses_to( + "SELECT substring(s, 1, 3) FROM t", + "SELECT SUBSTRING(s, 1, 3) FROM t", + ); +}