PySpark Groupby Explained with Examples

In this article, we will explore GroupBy in PySpark in depth and understand its usage through practical examples.

Introduction to GroupBy in PySpark

GroupBy in PySpark is an operation that collects rows sharing the same values in one or more specified columns into groups, so that an aggregate function can then summarize each group. It’s similar to the SQL GROUP BY clause and is crucial in data analysis for summarizing data.

The GroupBy operation can be used with aggregate functions like count(), avg(), min(), max(), and sum() to derive meaningful insights from the data. The syntax for the GroupBy operation is as follows:

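df.groupBy("GroupingColumn").agg({"ValueColumn": "aggregationFunction"})

Here, df is the DataFrame, GroupingColumn is the column whose values define the groups, ValueColumn is the column to aggregate, and aggregationFunction is the name of the aggregate function to apply to each group (for example "sum", "avg", "min", "max", or "count").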

Let’s take a closer look at the different aggregate functions you can use with GroupBy in PySpark.


Count()

This function returns the number of rows in each group, which is useful for identifying how many times each value occurs in a particular column.



Sum()

This function returns the sum of all values in a group. Combined with GroupBy, it totals a certain column’s values for each group.


Max() and Min()

These functions return the maximum and minimum values in a group, respectively. Combined with GroupBy, they find the highest or lowest value in a certain column for each group.


Avg() and Mean()

These functions return the average of all values in a group; mean() is an alias for avg(). Combined with GroupBy, they calculate the average of a certain column’s values for each group.


Practical Example

To better understand the concept of GroupBy in PySpark, let’s consider a practical example. Assume we have a DataFrame representing sales data from a multinational company, which includes columns such as Region, Product, and Sales:

from pyspark.sql import SparkSession

# Create (or reuse) a Spark session
spark = SparkSession.builder.appName('GroupByExample').getOrCreate()

# Sample sales data: (Region, Product, Sales)
data = [("North America", "Apple", 1000),
        ("North America", "Banana", 1500),
        ("Europe", "Apple", 800),
        ("Europe", "Banana", 1200),
        ("Asia", "Apple", 600),
        ("Asia", "Banana", 700)]

columns = ["Region", "Product", "Sales"]

# Build the DataFrame and print it
df = spark.createDataFrame(data, columns)
df.show()

This script will output:

+-------------+-------+-----+
|       Region|Product|Sales|
+-------------+-------+-----+
|North America|  Apple| 1000|
|North America| Banana| 1500|
|       Europe|  Apple|  800|
|       Europe| Banana| 1200|
|         Asia|  Apple|  600|
|         Asia| Banana|  700|
+-------------+-------+-----+

Now, let’s perform a GroupBy operation on the Region column and use the sum() function to calculate the total sales in each region.


This will give us:

+-------------+----------+
|       Region|sum(Sales)|
+-------------+----------+
|         Asia|      1300|
|North America|      2500|
|       Europe|      2000|
+-------------+----------+

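Here, we can see the total sales for each region. This sort of analysis is useful in understanding how sales differ across regions. The row order may vary between runs, because groupBy() does not guarantee any ordering; append orderBy("Region") if you need sorted output.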

Another example is finding the average sales per product.



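This gives the following averages (rounded to two decimal places):

+-------+----------+
|Product|avg(Sales)|
+-------+----------+
| Banana|   1133.33|
|  Apple|    800.00|
+-------+----------+

From the output, we can see that bananas sell more on average than apples.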

PySpark also allows you to group by multiple columns at the same time. This is useful when you need to aggregate over each combination of values in those columns. The syntax for this operation is simple:

df.groupBy(["ColumnName1", "ColumnName2"]).agg({"AnotherColumnName": "aggregationFunction"})

You simply pass a list of column names to the groupBy() function.

Let’s expand on our previous sales data example to illustrate this. Assume we now also have a ‘Year’ column, and we want to analyze the total sales per year in each region:

data = [("North America", "Apple", 1000, "2022"), 
        ("North America", "Banana", 1500, "2022"), 
        ("Europe", "Apple", 800, "2022"), 
        ("Europe", "Banana", 1200, "2022"), 
        ("Asia", "Apple", 600, "2022"), 
        ("Asia", "Banana", 700, "2022"),
        ("North America", "Apple", 1100, "2023"), 
        ("North America", "Banana", 1600, "2023"), 
        ("Europe", "Apple", 850, "2023"), 
        ("Europe", "Banana", 1300, "2023"), 
        ("Asia", "Apple", 650, "2023"), 
        ("Asia", "Banana", 750, "2023")]

columns = ["Region", "Product", "Sales", "Year"]

df = spark.createDataFrame(data, columns)

Now, we can group the data by both Region and Year and calculate the total sales:

df.groupBy(["Region", "Year"]).sum("Sales").show()

The output will look like this:

+-------------+----+----------+
|       Region|Year|sum(Sales)|
+-------------+----+----------+
|         Asia|2023|      1400|
|North America|2023|      2700|
|       Europe|2023|      2150|
|         Asia|2022|      1300|
|North America|2022|      2500|
|       Europe|2022|      2000|
+-------------+----+----------+

This output gives us the total sales for each region for each year, allowing us to understand how sales evolve over time in different regions.

Similarly, we could find the average sales of each product per year, across all regions:

df.groupBy(["Product", "Year"]).avg("Sales").show()

In summary, grouping by multiple columns allows for more detailed data analysis by considering multiple factors at once. It’s a simple but powerful operation that can enhance the level of insights you derive from your data.


Conclusion

GroupBy is a powerful PySpark operation for performing aggregations over grouped data. It’s especially useful in data analysis, where we often need to summarize data to derive meaningful insights. We’ve looked at the aggregate functions count(), sum(), min(), max(), avg(), and mean(), and applied them to sample sales data, grouping by a single column and by multiple columns.

Mastering GroupBy and other data transformations in PySpark is a crucial skill when dealing with big data. The key is to practice different scenarios and use cases, which will also help you understand when to use which aggregation function.
