# Source code for pipeasy_spark.convenience

from pyspark.ml.feature import (
    StringIndexer,
    OneHotEncoderEstimator,
    VectorAssembler,
    StandardScaler,
)
from pyspark.sql.types import NumericType, StringType

from .core import build_pipeline


def build_pipeline_by_dtypes(
    dataframe,
    exclude_columns=tuple(),
    string_transformers=tuple(),
    numeric_transformers=tuple(),
):
    """Build simple transformation pipeline (untrained) for the given dataframe.

    Columns are dispatched on the data type declared in the dataframe schema:
    numeric columns are mapped to ``numeric_transformers`` and string columns
    to ``string_transformers``. Columns of any other type, as well as columns
    named in ``exclude_columns``, are left out of the mapping handed to
    ``build_pipeline``.

    Parameters
    ----------
    dataframe: pyspark.sql.Dataframe
        only the schema of the dataframe is used, not actual data.
    exclude_columns: list of str (or a single str)
        name of columns for which we want no transformation to apply.
        A bare string is treated as one column name. Any iterable of names
        is accepted; it is consumed once.
    string_transformers: list of transformer instances
        The successive transformations that will be applied to string columns
        Each element is an instance of a pyspark.ml.feature transformer class.
    numeric_transformers: list of transformer instances
        The successive transformations that will be applied to numeric columns
        Each element is an instance of a pyspark.ml.feature transformer class.

    Returns
    -------
    pipeline: pyspark.ml.Pipeline instance (untrained)
        The value returned by ``build_pipeline`` for the column mapping.
    """
    if isinstance(exclude_columns, str):
        # ``"ab" in "abc"`` is a substring test, so a bare string would
        # silently exclude every column whose name is a substring of it.
        exclude_columns = (exclude_columns,)
    # Materialize once: gives exact-name matching, constant-time lookups, and
    # keeps a one-shot iterable from being exhausted by the first test.
    excluded = frozenset(exclude_columns)

    column_transformers = {}
    for field in dataframe.schema:
        column, dtype = field.name, field.dataType
        if column in excluded:
            continue
        if isinstance(dtype, NumericType):
            column_transformers[column] = numeric_transformers
        elif isinstance(dtype, StringType):
            column_transformers[column] = string_transformers
        # Any other data type: no entry is created for the column.
    return build_pipeline(column_transformers)
def build_default_pipeline(dataframe, exclude_columns=tuple()):
    """Build the default (untrained) transformation pipeline for a dataframe.

    By default numeric columns are processed with VectorAssembler +
    StandardScaler and string columns are processed with
    StringIndexer + OneHotEncoderEstimator. The actual construction is
    delegated to ``build_pipeline_by_dtypes``.

    Parameters
    ----------
    dataframe: pyspark.sql.Dataframe
        only the schema of the dataframe is used, not actual data.
    exclude_columns: list of str
        name of columns for which we want no transformation to apply.

    Returns
    -------
    pipeline: pyspark.ml.Pipeline instance (untrained)
    """
    # Fresh transformer instances are created on every call.
    string_stages = (StringIndexer(), OneHotEncoderEstimator())
    numeric_stages = (VectorAssembler(), StandardScaler())
    return build_pipeline_by_dtypes(
        dataframe,
        exclude_columns=exclude_columns,
        string_transformers=string_stages,
        numeric_transformers=numeric_stages,
    )