8000 Groupby apply spark 20396 wrap refactor by icexelloss · Pull Request #5 · icexelloss/spark · GitHub

Groupby apply spark 20396 wrap refactor #5

Open

wants to merge 2 commits into base: groupby-apply-SPARK-20396
58 changes: 52 additions & 6 deletions python/pyspark/sql/functions.py
@@ -2030,9 +2030,10 @@ def map_values(col):


# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
def _wrap_function(sc, user_func, wrap_user_func, returnType):
func = wrap_user_func(user_func, returnType) if wrap_user_func else user_func
command = (func, returnType)

pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
@@ -2044,7 +2045,7 @@ class UserDefinedFunction(object):

.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None, vectorized=False):
def __init__(self, func, returnType, name=None, vectorized=False, wrap_user_func=None):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2059,6 +2060,13 @@ def __init__(self, func, returnType, name=None, vectorized=False):
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.vectorized = vectorized
# This is an optional field that allows wrapping the user function to adapt
# its input/output to the format the serializer expects.
# This is mainly used for vectorized udfs to adapt the serializer output
# to the user-defined function input, e.g., list(pandas.Series) -> pd.DataFrame.
# This also allows verifying the user output before it gets passed
# to the serializer, e.g., checking that the output has the same length as the input.
self.wrap_user_func = wrap_user_func

@property
def returnType(self):
@@ -2087,7 +2095,7 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

wrapped_func = _wrap_function(sc, self.func, self.returnType)
wrapped_func = _wrap_function(sc, self.func, self.wrap_user_func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.vectorized)
@@ -2100,7 +2108,10 @@ def __call__(self, *cols):

def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
Wrap this udf with a function and attach docstring from func.

The function returned should have the same public fields as this udf and
have a udf field that points back to this udf object.
"""

# It is possible for a callable instance without __name__ attribute or/and
@@ -2115,17 +2126,49 @@ def _wrapped(self):
def wrapper(*args):
return self(*args)


wrapper.__name__ = self._name
wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
else self.func.__class__.__module__)

wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.vectorized = self.vectorized
wrapper.wrap_user_func = self.wrap_user_func

wrapper.udf = self

return wrapper


class VectorizedUserDefinedFunction(UserDefinedFunction):

def _default_wrap_user_func(func, returnType):
"""Default wrapper for vectorized user-defined function.
"""
def verify_result_length(*a):
result = func(*a)
if not hasattr(result, "__len__"):
raise TypeError("Return type of pandas_udf should be a Pandas.Series")
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
return result

return verify_result_length

def __init__(self, func, returnType, name=None, wrap_user_func=_default_wrap_user_func):
super(VectorizedUserDefinedFunction, self).__init__(
func=func, returnType=returnType, name=name,
vectorized=True, wrap_user_func=wrap_user_func)

def with_wrap_user_func(self, new_wrap_user_func):
"""Change the wrap_user_func and returns a new object
"""
return VectorizedUserDefinedFunction(
self.func, returnType=self.returnType, name=self._name,
wrap_user_func=new_wrap_user_func)

def _create_udf(f, returnType, vectorized):

def _udf(f, returnType=StringType(), vectorized=vectorized):
@@ -2137,7 +2180,10 @@ def _udf(f, returnType=StringType(), vectorized=vectorized):
"0-arg pandas_udfs are not supported. "
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
udf_obj = VectorizedUserDefinedFunction(f, returnType)
else:
udf_obj = UserDefinedFunction(f, returnType)

return udf_obj._wrapped()

# decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
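
To make the new wrap_user_func hook concrete, here is a minimal standalone sketch of the pattern it enables; the column names, the user_func, and the adapted helper are hypothetical and not part of this diff:

import pandas as pd

def user_func(pdf):
    # hypothetical user function that operates on a whole pandas.DataFrame
    return pdf.assign(v=pdf['v'] - pdf['v'].mean())

def wrap_user_func(f, return_type):
    def adapted(*series):
        # the serializer hands the udf one pandas.Series per column;
        # adapt that to the pandas.DataFrame the user function expects
        pdf = pd.concat(series, axis=1, keys=['id', 'v'])
        result = f(pdf)
        # verify the user output here, before it reaches the serializer
        if not isinstance(result, pd.DataFrame):
            raise TypeError("Return type of pandas_udf should be a Pandas.DataFrame")
        return result
    return adapted

adapted = wrap_user_func(user_func, return_type=None)  # return_type unused in this sketch
out = adapted(pd.Series([1, 1]), pd.Series([1.0, 2.0]))

This is the same composition _wrap_function now performs (func = wrap_user_func(user_func, returnType) if wrap_user_func else user_func) before pickling the command.
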
30 changes: 19 additions & 11 deletions python/pyspark/sql/group.py
@@ -241,19 +241,27 @@ def apply(self, udf):

df = self._df
func = udf.func
returnType = udf.returnType

# The python executors expect the function to take a list of pd.Series as input,
# so we create a wrapper function that turns that into a pd.DataFrame before
# passing it down to the user function
return_type = udf.returnType
columns = df.columns

def wrapped(*cols):
import pandas as pd
return func(pd.concat(cols, axis=1, keys=columns))

wrapped_udf_obj = pandas_udf(wrapped, returnType)
udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
def wrap_user_func(f, return_type):
# Verify the return type and number of columns in result
def verify_result_type(*series):
import pandas as pd
result = f(pd.concat(series, axis=1, keys=columns))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of pandas_udf should be a Pandas.DataFrame")
if not len(result.columns) == len(return_type):
raise RuntimeError(
"Number of columns of the returned Pandas.DataFrame " \
"doesn't match specified schema. " \
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
return result

return verify_result_type

wrapped_udf = udf.udf.with_wrap_user_func(wrap_user_func)
udf_column = wrapped_udf(*[df[col] for col in columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.sql_ctx)

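
For context, this is roughly how the refactored groupby-apply path is exercised from the user's side; the DataFrame, schema, and subtract_mean function are illustrative only, and the exact pandas_udf spelling may differ on this branch (released Spark versions additionally take a PandasUDFType.GROUPED_MAP function type):

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import StructType, StructField, LongType, DoubleType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])

@pandas_udf(schema)
def subtract_mean(pdf):
    # receives one whole group as a pandas.DataFrame,
    # thanks to the wrap_user_func installed by GroupedData.apply
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
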
37 changes: 6 additions & 31 deletions python/pyspark/worker.py
@@ -74,47 +74,22 @@ def wrap_udf(f, return_type):


def wrap_pandas_udf(f, return_type):
# If the return_type is a StructType, it indicates this is a groupby apply udf,
# otherwise, it's a vectorized column udf.
# We can distinguish these two by return type because in groupby apply, we always specify
# returnType as a StructType, and in vectorized column udf, StructType is not supported.
#
# TODO: This logic is a bit hacky and might not work for future pandas udfs. Need refactoring.
if isinstance(return_type, StructType):
arrow_return_types = [to_arrow_type(field.dataType) for field in return_type]

# Verify the return type and number of columns in result
def verify_result_type(*a):
def attach_datatype(*a):
import pandas as pd
result = f(*a)
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be a "
"Pandas.DataFrame")
if not len(result.columns) == len(arrow_return_types):
raise RuntimeError(
"Number of columns of the returned Pandas.DataFrame " \
"doesn't match specified schema. " \
"Expected: {} Actual: {}".format(len(arrow_return_types), len(result.columns)))

return [(result[result.columns[i]], arrow_return_types[i])
for i in range(len(arrow_return_types))]

return verify_result_type

return attach_datatype
else:
arrow_return_type = to_arrow_type(return_type)

def verify_result_length(*a):
def attach_datatype(*a):
result = f(*a)
if not hasattr(result, "__len__"):
raise TypeError("Return type of the user-defined functon should be a "
"Pandas.Series")
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
return result

return lambda *a: (verify_result_length(*a), arrow_return_type)
return (result, arrow_return_type)

return attach_datatype


def read_single_udf(pickleSer, infile, eval_type):
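
After this change the worker no longer verifies the result shape; that responsibility moves to the driver-side wrap_user_func above, and the worker-side wrapper only pairs the raw result with its arrow type. A standalone sketch of that behavior (the names here are illustrative, not the actual worker code path):

import pandas as pd
import pyarrow as pa

def wrap_pandas_udf_sketch(f, arrow_return_type):
    def attach_datatype(*a):
        # no length or type checks here; just attach the arrow type
        return (f(*a), arrow_return_type)
    return attach_datatype

plus_one = wrap_pandas_udf_sketch(lambda s: s + 1, pa.float64())
result, arrow_type = plus_one(pd.Series([1.0, 2.0]))
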