小编典典

Creating a custom Transformer in PySpark ML

python

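I'm new to Spark SQL DataFrames and ML (PySpark). How can I create a custom tokenizer that, for example, removes stop words and uses some libraries from nltk? Can I extend the default one?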



2020-12-20

1 answer


Can I extend the default one?

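Not really. The default Tokenizer is a subclass of pyspark.ml.wrapper.JavaTransformer and, like the other transformers and estimators in pyspark.ml.feature, delegates the actual processing to its Scala counterpart. Since you want to use Python, you should extend pyspark.ml.pipeline.Transformer directly.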

import nltk

from pyspark import keyword_only  ## < 2.0 -> pyspark.ml.util.keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
# Available in PySpark >= 2.3.0 
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable  
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

class NLTKWordPunctTokenizer(
        Transformer, HasInputCol, HasOutputCol,
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    stopwords = Param(Params._dummy(), "stopwords", "stopwords",
                      typeConverter=TypeConverters.toListString)


    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, stopwords=None):
        super(NLTKWordPunctTokenizer, self).__init__()
        # Params.__init__ copies the class-level `stopwords` Param to this
        # instance, so it is not re-created here (re-creating it would
        # drop its toListString type converter).
        self._setDefault(stopwords=[])
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, stopwords=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setStopwords(self, value):
        return self._set(stopwords=list(value))

    def getStopwords(self):
        return self.getOrDefault(self.stopwords)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _transform(self, dataset):
        stopwords = set(self.getStopwords())

        def f(s):
            # Pass nulls through instead of failing inside the tokenizer
            if s is None:
                return None
            tokens = nltk.tokenize.wordpunct_tokenize(s)
            return [t for t in tokens if t.lower() not in stopwords]

        t = ArrayType(StringType())
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))
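The per-row function that `_transform` wraps in a UDF can be checked in plain Python before wiring it into Spark. This is a minimal sketch, assuming that nltk's `wordpunct_tokenize` splits on the same pattern as nltk's `WordPunctTokenizer` (`\w+|[^\w\s]+`), so the regex below is a stand-in for nltk rather than nltk itself; `make_tokenize_fn` is a hypothetical helper name. The sketch also returns None for null inputs.

```python
import re

# Hypothetical stand-in for nltk.tokenize.wordpunct_tokenize, assuming it
# splits on the same pattern as nltk's WordPunctTokenizer: runs of word
# characters, or runs of characters that are neither word nor whitespace.
_WORDPUNCT = re.compile(r"\w+|[^\w\s]+")


def wordpunct_tokenize(s):
    return _WORDPUNCT.findall(s)


def make_tokenize_fn(stopwords):
    """Build a per-row function like the one the UDF in _transform wraps."""
    stopwords = set(stopwords)

    def f(s):
        if s is None:
            return None
        tokens = wordpunct_tokenize(s)
        # Tokens are lowercased before the lookup, so stopwords must be lowercase
        return [t for t in tokens if t.lower() not in stopwords]

    return f


f = make_tokenize_fn(["i", "about", "are"])
print(f("Hi I heard about Spark"))  # ['Hi', 'heard', 'Spark']
print(f("Logistic regression models are neat"))
```

Because only tokens are lowercased, a user-supplied stopword list with mixed case would silently miss matches; nltk's English list is already lowercase.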

Usage example (data from ML - Features):

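from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# The stopword corpus must be downloaded once; nltk must also be
# installed on the executors, since the UDF runs there.
nltk.download('stopwords')

sentenceDataFrame = spark.createDataFrame([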
  (0, "Hi I heard about Spark"),
  (0, "I wish Java could use case classes"),
  (1, "Logistic regression models are neat")
], ["label", "sentence"])

tokenizer = NLTKWordPunctTokenizer(
    inputCol="sentence", outputCol="words",  
    stopwords=nltk.corpus.stopwords.words('english'))

tokenizer.transform(sentenceDataFrame).show()

For a custom Python Estimator, see How to Roll a Custom Estimator in PySpark mllib.

This answer depends on internal APIs and is compatible with Spark 2.0.3, 2.1.1, 2.2.0 or later (SPARK-19348). For code compatible with earlier Spark versions, see revision 8.
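Since the version note hinges on `keyword_only`, here is a rough sketch of how a decorator of that kind behaves: it rejects positional arguments and records only the keyword arguments the caller actually passed in `self._input_kwargs`, which is why `__init__` and `setParams` above read that attribute. As I understand SPARK-19348, older versions stored the recorded kwargs on the decorated function (hence the older `self.__init__._input_kwargs` idiom), so they were shared across instances and threads. This is modeled on the behavior, not copied from PySpark; `Demo` is a hypothetical class.

```python
import functools


def keyword_only(func):
    """Sketch of a keyword_only-style decorator (modeled on PySpark's behavior)."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError("Method %s forces keyword arguments." % func.__name__)
        # Stored on the instance, so separate objects do not share kwargs
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper


class Demo:
    """Hypothetical class standing in for a Transformer."""

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, stopwords=None):
        # Only explicitly passed arguments are recorded, so the None
        # defaults never overwrite previously set Param values
        return dict(self._input_kwargs)


d = Demo()
print(d.setParams(inputCol="sentence"))  # {'inputCol': 'sentence'}
```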
