How can I create a Spark UDF to interpolate float positions to integers, and how can I write better logic than I have?

Below is my Spark DataFrame. I want to perform interpolation and write a Spark UDF for it, but I am not sure how to write better logic than what I have below and turn it into a UDF.

This is my current attempt at converting Position_float to an appropriate integer Position:

def dirty_fill(df, id_col, y_cols):
    from pyspark.sql import types as T
    # Three candidate integer positions per row: truncate(position_float + 0.5),
    # truncate(position_float - 0.5) and truncate(position_float).
    df = df.withColumn('position_plus', (df.position_float + 0.5).cast(T.IntegerType()))
    df = df.withColumn('position_minus', (df.position_float - 0.5).cast(T.IntegerType()))
    df = df.withColumn('position', df.position_float.cast(T.IntegerType()))
    # Stack the three variants under a single 'position' column and deduplicate.
    df1 = df.select([id_col, 'position_plus'] + y_cols).withColumnRenamed('position_plus', 'position')
    df2 = df.select([id_col, 'position_minus'] + y_cols).withColumnRenamed('position_minus', 'position')
    df3 = df.select([id_col, 'position'] + y_cols)
    df123 = df1.union(df2).union(df3).sort([id_col, 'position']).dropDuplicates([id_col, 'position'])
    return df123

y_cols = ['entry_temperature']
finish_mill_entry_filled = dirty_fill(finish_mill_entry, 'finish_mill_id', y_cols)

This is a sample of my dataframe:

| Finishing_mill_id  | Sample  | Position_float | Entry_Temp |
|--------------------|---------|----------------|------------|
| 2015418529         | 1       | 0.000000       | 1986.0     |
| 2015418529         | 2       | 2.192982       | 1997.0     |
| 2015418529         | 3       | 4.385965       | 2003.0     |
| 2018171498         | 445     | 495.535714     | 1643.0     |
| 2018171498         | 446     | 496.651786     | 1734.0     |
| 2018171498         | 447     | 497.767857     | 1748.0     |
| 2018171498         | 448     | 498.883929     | 1755.0     |

I need to interpolate the float positions onto integer positions.

What I want is:

| Finishing_mill_id  | Sample  | Position | Entry_Temp |
|--------------------|---------|----------|------------|
| 2015418529         | 1       | 0        | 1986.0     |
| 2015418529         | 2       | 1        | 1986.0     |
| 2015418529         | 3       | 2        | 1997.0     |
| 2015418529         | 4       | 3        | 1997.0     |
| 2015418529         | 5       | 4        | 2003.0     |
| 2018171498         | 445     | 496      | 1643.0     |
| 2018171498         | 446     | 497      | 1734.0     |
| 2018171498         | 447     | 498      | 1748.0     |
| 2018171498         | 448     | 499      | 1755.0     |

I need a Spark user-defined function to do this, and no data points should be missed: Position_float runs over the range 0-500, and every integer position in that range must get a value. I also need to rework my interpolation logic properly.

To make it a little clearer: say I have positions 0.000 and 2.19 but no data point in between; I still need a value for position 1.00 even though there is no data there, via a sort of linear interpolation. I hope that helps.
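
To illustrate with the first two rows of my sample, this is the kind of straight-line estimate I mean (a minimal plain-Python sketch, not Spark code; the lerp helper is only for illustration):

def lerp(x, x0, y0, x1, y1):
    # Straight-line estimate of y at x from two known samples (x0, y0) and (x1, y1).
    return y0 + (x - x0) * (y1 - y0) / (x1 - x0)

lerp(1.0, 0.0, 1986.0, 2.192982, 1997.0)  # ~1991.0: the value for position 1.00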

Answer

1. Window functions

You can use window functions to fill the gaps and interpolate the values.

Let's start with a sample dataframe:

import pyspark.sql.functions as psf
import pyspark.sql.types as pst
from pyspark.sql import Window
import numpy as np

# 20 random rows: position in [0, 100) with one decimal place, value in [100, 200);
# position_round is position rounded to the nearest integer.
df = spark.createDataFrame(
        [[float(t)/10., float(v)] for t, v in zip(np.random.randint(0, 1000, 20), np.random.randint(100, 200, 20))], 
        schema=pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position', 'value']])) \
    .withColumn('position_round', psf.round('position'))

        +--------+-----+--------------+
        |position|value|position_round|
        +--------+-----+--------------+
        |    68.5|121.0|          69.0|
        |    76.3|126.0|          76.0|
        |    88.3|150.0|          88.0|
        |    59.0|197.0|          59.0|
        |    20.7|119.0|          21.0|
        |     0.1|167.0|           0.0|
        |    20.1|177.0|          20.0|
        |    81.9|199.0|          82.0|
        |    63.6|163.0|          64.0|
        |    32.4|115.0|          32.0|
        |    43.6|130.0|          44.0|
        |    11.9|175.0|          12.0|
        |    68.2|176.0|          68.0|
        |    28.9|184.0|          29.0|
        |    46.3|199.0|          46.0|
        |     9.7|155.0|          10.0|
        |    57.8|163.0|          58.0|
        |    83.6|173.0|          84.0|
        |    16.2|169.0|          16.0|
        |    87.1|127.0|          87.0|
        +--------+-----+--------------+

In order to fill the gaps we'll create a range of integers:

start, end = df.agg(psf.min('position_round'), psf.max('position_round')).collect()[0]
# spark.range's end bound is exclusive, so add 1 to keep the maximum position.
pos_df = spark.range(start=int(start), end=int(end) + 1, step=1) \
    .withColumnRenamed('id', 'position_round')

Now we can join the two dataframes:

w1 = Window.orderBy('position_round')
w2 = Window.partitionBy('group').orderBy('position_round')

df_resample = df \
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) \
    .join(pos_df, on='position_round', how='right') \
    .withColumn('group', psf.sum((~psf.isnull('position')).cast('int')).over(w1)) \
    .select(
        '*', 
        (psf.row_number().over(w2) - 1).alias('i'), 
        psf.first(psf.col('next_position') - psf.col('position_round')).over(w2).alias('dx'), 
        psf.first('value').over(w2).alias('value0'), 
        psf.first(psf.col('next_value') - psf.col('value')).over(w2).alias('dy')) \
    .withColumn(
        'value_round', 
        # When dx is null (last data point) or zero, keep the row's own value.
        psf.when(psf.col('dx') > 0, psf.col('value0') + psf.col('i') * psf.col('dy') / psf.col('dx')) \
            .otherwise(psf.col('value')))

  • The first window function stores next_value and next_position so that dx and dy can be computed later.
  • We then tag each gap with a distinct group id so that the values can be interpolated within each linear segment.
  • Last but not least, we gather all the elements we need:
    • the length of the gap: dx
    • the delta in values: dy
    • the current row index within the gap: i

We can now compute value_round, the interpolated value at integer position position_round:

        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
        |position_round|position|value|next_position|next_value|group|  i|  dx|value0|   dy|value_round|
        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
        |             0|     0.1|167.0|         10.0|     155.0|    1|  0|10.0| 167.0|-12.0|      167.0|
        |             1|    null| null|         null|      null|    1|  1|10.0| 167.0|-12.0|      165.8|
        |             2|    null| null|         null|      null|    1|  2|10.0| 167.0|-12.0|      164.6|
        |             3|    null| null|         null|      null|    1|  3|10.0| 167.0|-12.0|      163.4|
        |             4|    null| null|         null|      null|    1|  4|10.0| 167.0|-12.0|      162.2|
        |             5|    null| null|         null|      null|    1|  5|10.0| 167.0|-12.0|      161.0|
        |             6|    null| null|         null|      null|    1|  6|10.0| 167.0|-12.0|      159.8|
        |             7|    null| null|         null|      null|    1|  7|10.0| 167.0|-12.0|      158.6|
        |             8|    null| null|         null|      null|    1|  8|10.0| 167.0|-12.0|      157.4|
        |             9|    null| null|         null|      null|    1|  9|10.0| 167.0|-12.0|      156.2|
        |            10|     9.7|155.0|         12.0|     175.0|    2|  0| 2.0| 155.0| 20.0|      155.0|
        |            11|    null| null|         null|      null|    2|  1| 2.0| 155.0| 20.0|      165.0|
        |            12|    11.9|175.0|         16.0|     169.0|    3|  0| 4.0| 175.0| -6.0|      175.0|
        |            13|    null| null|         null|      null|    3|  1| 4.0| 175.0| -6.0|      173.5|
        |            14|    null| null|         null|      null|    3|  2| 4.0| 175.0| -6.0|      172.0|
        |            15|    null| null|         null|      null|    3|  3| 4.0| 175.0| -6.0|      170.5|
        |            16|    16.2|169.0|         20.0|     177.0|    4|  0| 4.0| 169.0|  8.0|      169.0|
        |            17|    null| null|         null|      null|    4|  1| 4.0| 169.0|  8.0|      171.0|
        |            18|    null| null|         null|      null|    4|  2| 4.0| 169.0|  8.0|      173.0|
        |            19|    null| null|         null|      null|    4|  3| 4.0| 169.0|  8.0|      175.0|
        +--------------+--------+-----+-------------+----------+-----+---+----+------+-----+-----------+
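
As a quick sanity check, the value_round column of group 1 above can be reproduced with plain Python (value0 = 167.0, dy = -12.0, dx = 10.0, i running from 0 to 9):

# Plain-Python check of value0 + i * dy / dx for group 1:
[round(167.0 + i * -12.0 / 10.0, 1) for i in range(10)]
# [167.0, 165.8, 164.6, 163.4, 162.2, 161.0, 159.8, 158.6, 157.4, 156.2]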

2. UDF

If you don't want to use window functions, you can write a UDF that does the interpolation in Python and returns an array of (position, value) tuples:

def interpolate(pos, next_pos, value, next_value):
    # Zero-length gap, or last row of the window (no successor): keep the point as-is.
    if pos == next_pos or next_value is None:
        return [(pos, value)]
    # One (position, value) pair per integer step, up to but excluding next_pos.
    return [[pos + i, value + i * (next_value - value) / (next_pos - pos)] for i in range(int(next_pos - pos))]

interpolate_udf = psf.udf(
    interpolate, 
    pst.ArrayType(pst.StructType([pst.StructField(c, pst.FloatType()) for c in ['position_round', 'value_round']])))

Note that the tuples are of type StructType to make it easier to "flatten" the tuples into columns.
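
Since interpolate is an ordinary Python function, we can check it without Spark; called on the first sample point, it reproduces the first segment of the table above:

interpolate(0.0, 10.0, 167.0, 155.0)
# [[0.0, 167.0], [1.0, 165.8], [2.0, 164.6], ..., [9.0, 156.2]]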

w1 = Window.orderBy('position_round')
df_udf = df \
    .select(
        '*', 
        psf.lead('position_round', 1).over(w1).alias('next_position'), 
        psf.lead('value', 1).over(w1).alias('next_value')) \
    .withColumn('tmp', psf.explode(interpolate_udf('position_round', 'next_position', 'value', 'next_value'))) \
    .select('*', 'tmp.*').drop('tmp')

Here is what we get:

        +--------+-----+--------------+-------------+----------+--------------+-----------+
        |position|value|position_round|next_position|next_value|position_round|value_round|
        +--------+-----+--------------+-------------+----------+--------------+-----------+
        |     0.1|167.0|           0.0|         10.0|     155.0|           0.0|      167.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           1.0|      165.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           2.0|      164.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           3.0|      163.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           4.0|      162.2|
        |     0.1|167.0|           0.0|         10.0|     155.0|           5.0|      161.0|
        |     0.1|167.0|           0.0|         10.0|     155.0|           6.0|      159.8|
        |     0.1|167.0|           0.0|         10.0|     155.0|           7.0|      158.6|
        |     0.1|167.0|           0.0|         10.0|     155.0|           8.0|      157.4|
        |     0.1|167.0|           0.0|         10.0|     155.0|           9.0|      156.2|
        |     9.7|155.0|          10.0|         12.0|     175.0|          10.0|      155.0|
        |     9.7|155.0|          10.0|         12.0|     175.0|          11.0|      165.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          12.0|      175.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          13.0|      173.5|
        |    11.9|175.0|          12.0|         16.0|     169.0|          14.0|      172.0|
        |    11.9|175.0|          12.0|         16.0|     169.0|          15.0|      170.5|
        |    16.2|169.0|          16.0|         20.0|     177.0|          16.0|      169.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          17.0|      171.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          18.0|      173.0|
        |    16.2|169.0|          16.0|         20.0|     177.0|          19.0|      175.0|
        +--------+-----+--------------+-------------+----------+--------------+-----------+
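
To apply either approach to the dataframe in the question, the window should additionally be partitioned by the mill id so that interpolation never crosses from one finishing-mill id to the next. Here is a sketch for the UDF variant, assuming the column names used in the question's code (finish_mill_id, position_float, entry_temperature; adjust to your actual schema):

# Sketch only: column names below are taken from the question's snippet.
w = Window.partitionBy('finish_mill_id').orderBy('position_round')

finish_mill_filled = finish_mill_entry \
    .withColumn('position_round', psf.round('position_float')) \
    .select(
        '*', 
        psf.lead('position_round', 1).over(w).alias('next_position'), 
        psf.lead('entry_temperature', 1).over(w).alias('next_value')) \
    .withColumn('tmp', psf.explode(
        interpolate_udf('position_round', 'next_position', 'entry_temperature', 'next_value'))) \
    .select('finish_mill_id', 'tmp.position_round', 'tmp.value_round')

lead over a partitioned window returns null at the last row of each mill, and interpolate then keeps that point unchanged, so no rows are lost at group boundaries.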