Speed up your Python code 60x

Rust
Author

Adam Cseresznye

Published

June 7, 2024

Photo by NASA on Unsplash

Ever felt the need to turbo-boost your Python code? Tired of the seemingly endless wait for your iterations or simulations to complete? Allow me to present the rustimport library. In its most basic form, rustimport lets you call compiled, optimized Rust functions from within your Python code. This way, you can delegate the heavy-duty tasks to Rust instead of relying on Python, a language known for its versatility but also for its slower execution speed. With rustimport, you can enjoy the best of both worlds: the simplicity of Python and the performance of Rust.

Let me give you an example: imagine you’re tackling one of the problems from Project Euler, specifically problem #50. Here’s what the problem statement looks like:

The prime \(41\), can be written as the sum of six consecutive primes: \[41 = 2 + 3 + 5 + 7 + 11 + 13.\] This is the longest sum of consecutive primes that adds to a prime below one-hundred. The longest sum of consecutive primes below one-thousand that adds to a prime, contains \(21\) terms, and is equal to \(953\). Which prime, below one-million, can be written as the sum of the most consecutive primes?

While there are many ways to solve this problem, one approach (though not the fastest) is as follows:

import cProfile
import pstats
import io


def is_prime(n):
    if n <= 1:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def generate_primes(target):
    primes = [n for n in range(2, target + 1) if is_prime(n)]
    return primes


def longest_sum_of_consecutive_primes(primes, target):
    max_length = 0
    max_prime = 0

    for i in range(len(primes)):
        for j in range(i + max_length, len(primes)):
            sum_of_primes = sum(primes[i:j])
            if sum_of_primes > target:
                break
            if is_prime(sum_of_primes) and (j - i) > max_length:
                max_length = j - i
                max_prime = sum_of_primes

    return max_prime, max_length


def main():
    TARGET = 1_000_000
    primes = generate_primes(TARGET)
    max_prime, max_length = longest_sum_of_consecutive_primes(primes, TARGET)
    print(f"Prime: {max_prime}, Length: {max_length}")
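Before timing anything, it is worth checking that the approach reproduces the two smaller cases quoted in the problem statement. The sketch below restates the three functions from the listing so that it runs on its own; the solve helper and the variable names small and medium are illustrative additions, not part of the original script.

```python
def is_prime(n):
    """Trial division, as in the listing above."""
    if n <= 1:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def generate_primes(target):
    return [n for n in range(2, target + 1) if is_prime(n)]


def longest_sum_of_consecutive_primes(primes, target):
    max_length = 0
    max_prime = 0
    for i in range(len(primes)):
        # Start at the current best length: shorter windows cannot improve it.
        for j in range(i + max_length, len(primes)):
            sum_of_primes = sum(primes[i:j])
            if sum_of_primes > target:
                break
            if is_prime(sum_of_primes) and (j - i) > max_length:
                max_length = j - i
                max_prime = sum_of_primes
    return max_prime, max_length


def solve(target):
    # Hypothetical helper that chains the two steps for a given limit.
    return longest_sum_of_consecutive_primes(generate_primes(target), target)


small = solve(100)     # problem statement: 41, six terms
medium = solve(1_000)  # problem statement: 953, 21 terms
print(small, medium)   # (41, 6) (953, 21)
```

Both results match the figures given in the problem statement, so the same code can be trusted with the one-million limit.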

Let’s take a quick look at how the code performs:

Note: To evaluate the execution time and identify which functions take longer to run, we use cProfile. The statistics generated are then formatted using the pstats module. The complete code for all the examples can be found here.
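For readers who have not used cProfile before, here is a minimal sketch of how a report like the ones in this post could be produced. It is an assumption about the setup rather than the exact harness behind the numbers shown here; profile_call and demo are hypothetical names, and demo is only a stand-in for main().

```python
import cProfile
import io
import pstats


def profile_call(func, *args, **kwargs):
    """Run func under cProfile and return (result, formatted report)."""
    profiler = cProfile.Profile()
    profiler.enable()
    result = func(*args, **kwargs)
    profiler.disable()

    # pstats writes to the stream it is given, so capture the report in memory.
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("cumulative").print_stats()
    return result, buffer.getvalue()


def demo():
    # Stand-in workload; in this post the profiled function would be main().
    return sum(i * i for i in range(10_000))


result, report = profile_call(demo)
print(report)
```

Sorting by "cumulative" is what gives the "Ordered by: cumulative time" header, which puts the functions responsible for the most total time at the top.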

    1157601 function calls in 5.517 seconds

       Ordered by: cumulative time

       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
            1    0.000    0.000    5.517    5.517 PE50_v1.py:36(main)
            1    0.000    0.000    5.025    5.025 PE50_v1.py:17(generate_primes)
            1    0.134    0.134    5.025    5.025 PE50_v1.py:18(<listcomp>)
      1000570    4.894    0.000    4.894    0.000 PE50_v1.py:7(is_prime)
            1    0.288    0.288    0.491    0.491 PE50_v1.py:21(longest_sum_of_consecutive_primes)
        78526    0.197    0.000    0.197    0.000 {built-in method builtins.sum}
        78499    0.004    0.000    0.004    0.000 {built-in method builtins.len}
            1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
            1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

As you can see, this straightforward task required over 5.5 seconds to execute. The primary function, PE50_v1.py:36(main), was invoked once and took a total of 5.517 seconds. The majority of this time (5.025 seconds) was spent in PE50_v1.py:17(generate_primes), which itself spent most of its time calling PE50_v1.py:7(is_prime), a function invoked over a million times in total. The function PE50_v1.py:21(longest_sum_of_consecutive_primes) came second with 0.491 seconds of cumulative time, about 0.2 seconds of which went to the built-in sum, called some 78,500 times.

Replacing the is_prime function

From the profiling results, it’s clear that is_prime is a prime candidate (pun intended) for replacement with a Rust equivalent. Let’s explore how we can achieve this.

Firstly, we need to install the rustimport library. This can be done using pip: pip install rustimport. Following this, we can create a Rust file, which we’ll name rs_extension.rs for this example. Now, let’s write some Rust code.

// rustimport:pyo3

use pyo3::prelude::*;

#[pyfunction]
fn is_prime(n: u32) -> bool {
    if n <= 1 {
        return false;
    }
    let mut i = 2;
    while i * i <= n {
        if n % i == 0 {
            return false;
        }
        i += 1;
    }
    true
}

Upon comparing the Python is_prime function with its Rust counterpart, you’ll notice that they are nearly identical. The key difference comes from Rust’s static, strong type system, which requires us to specify the types of the input variables and the return value. Not to worry, it’s simple; we expect an unsigned 32-bit integer (u32, non-negative integers only) and return a boolean value indicating whether or not the integer is prime. Let’s see what kind of speedup we get.
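On the Python side, using the Rust function comes down to an import. The sketch below assumes rs_extension.rs sits next to the script and that rustimport's import hook compiles it on first import. The pure-Python fallback and the BACKEND variable are illustrative additions so that the snippet still runs when the extension is unavailable; they are not something rustimport provides.

```python
try:
    # Registers rustimport's import hook, which compiles rs_extension.rs
    # the first time the module is imported.
    import rustimport.import_hook  # noqa: F401
    from rs_extension import is_prime

    BACKEND = "rust"
except ImportError:
    # Illustrative fallback (an assumption, not part of rustimport):
    # the same trial division in plain Python.
    def is_prime(n):
        if n <= 1:
            return False
        i = 2
        while i * i <= n:
            if n % i == 0:
                return False
            i += 1
        return True

    BACKEND = "python"


print(f"is_prime backend: {BACKEND}")
print(is_prime(41), is_prime(42))
```

Because both versions share the same name and signature, the rest of the script does not need to know which one it is calling.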

Warning

If your Rust code isn’t performing as expected, it could be because it was compiled in debug mode, which doesn’t enable optimizations. By running the python -m rustimport build --release command, you can build an optimized extension that is significantly faster. Give it a try!

         1157601 function calls in 0.641 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.641    0.641 PE50_v1.py:30(main)
        1    0.197    0.197    0.391    0.391 PE50_v1.py:14(longest_sum_of_consecutive_primes)
        1    0.000    0.000    0.249    0.249 PE50_v1.py:9(generate_primes)
        1    0.089    0.089    0.249    0.249 PE50_v1.py:10(<listcomp>)
    78526    0.191    0.000    0.191    0.000 {built-in method builtins.sum}
  1000570    0.160    0.000    0.160    0.000 {built-in method rs_extension.is_prime}
    78499    0.004    0.000    0.004    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

🤯🤯🤯🤯 Our execution time went from 5.517 seconds to 0.641 seconds 🚀. That’s a whopping 8.6 times speed up, achieved by replacing just one function. The is_prime function now takes a total of 0.16 seconds to execute. But why stop here? We can push the boundaries even further.

Consider the longest_sum_of_consecutive_primes function. It performs a substantial amount of work, and if we could harness the power of Rust for this function, we could potentially slash the execution time even more. Let’s give it a try.

Replacing the longest_sum_of_consecutive_primes function

#[pyfunction]
fn longest_sum_of_consecutive_primes(primes: Vec<u32>, target: u32) -> (u32, usize) {
    let mut max_length: usize = 0;
    let mut max_prime: u32 = 0;

    for i in 0..primes.len() {
        for j in i + max_length..primes.len() {
            let sum_of_primes: u32 = primes[i..j].iter().sum();
            if sum_of_primes > target {
                break;
            }
            if is_prime(sum_of_primes) && (j - i) > max_length {
                max_length = j - i;
                max_prime = sum_of_primes;
            }
        }
    }
    (max_prime, max_length)
}

This Rust function looks a bit different from our original Python function, but it’s still possible to understand what’s going on even if you haven’t seen Rust code before. In this case, we’re supplying the function with a vector of prime numbers, each represented as a u32 value, and a target number, which is 1,000,000 in our case. The function then returns the solution to our problem: the number (max_prime) that can be expressed as the sum of the greatest number of consecutive primes, and the actual count of these consecutive primes (max_length). This is the result from our second optimization:

         1000005 function calls in 0.252 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.252    0.252 PE50_v2.py:14(main)
        1    0.000    0.000    0.248    0.248 PE50_v2.py:9(generate_primes)
        1    0.088    0.088    0.248    0.248 PE50_v2.py:10(<listcomp>)
   999999    0.160    0.000    0.160    0.000 {built-in method rs_extension.is_prime}
        1    0.004    0.004    0.004    0.004 {built-in method rs_extension.longest_sum_of_consecutive_primes}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

🤯🤯🤯🤯 Our execution time has now been reduced to just 0.252 seconds, a further improvement of 2.5 times. In total, we’ve achieved a staggering 21.9 times speedup compared to our original Python implementation. The longest_sum_of_consecutive_primes function, which previously took 0.391 seconds to execute, now completes in a mere 0.004 seconds. This is a significant achievement!

As you might have guessed, we’re not stopping here 😈. We still have one more Python function left to convert to Rust 🦀. Let’s do that and see how much more we can optimize our code!

Replacing the generate_primes function

#[pyfunction]
fn generate_primes(target: u32) -> Vec<u32> {
    (2..=target).filter(|&n| is_prime(n)).collect::<Vec<u32>>()
}

This Rust code, while seemingly simple, is quite powerful. It walks the inclusive range from 2 to target and keeps only the numbers for which is_prime returns true, collecting them into a vector. Given that Rust is heavily inspired by OCaml and Haskell, you may notice the influence of functional programming paradigms throughout the language (and in this example).
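If the iterator chain looks unfamiliar, it may help to see the same filter-then-collect idea in Python. This is only an analogy for reading the Rust code, not a replacement for it; is_prime is restated here so that the snippet runs on its own.

```python
def is_prime(n):
    """Trial division, mirroring the Rust is_prime."""
    if n <= 1:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def generate_primes(target):
    # range(2, target + 1) plays the role of Rust's inclusive 2..=target,
    # filter(...) mirrors .filter(|&n| is_prime(n)),
    # and list(...) mirrors .collect::<Vec<u32>>().
    return list(filter(is_prime, range(2, target + 1)))


print(generate_primes(30))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

In both languages the pipeline is lazy until the final step: nothing is computed until list(...) in Python, or collect in Rust, asks for the values.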

         5 function calls in 0.092 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.092    0.092 PE50_v3.py:9(main)
        1    0.088    0.088    0.088    0.088 {built-in method rs_extension.generate_primes}
        1    0.004    0.004    0.004    0.004 {built-in method rs_extension.longest_sum_of_consecutive_primes}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.print}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

😲😵‍💫🤯💥💥💥 And there you have it! 🎉 Rust did it again. The execution time has been slashed from 0.252 seconds to a mere 0.092 seconds, making it 2.7 times faster. The generate_primes function, which previously took 0.248 seconds, now completes in just 0.088 seconds. And thus our code went from taking 5.517 seconds to just 0.092 seconds, a speedup of roughly 60 times.
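The speedup figures quoted throughout this post are simple ratios of the profiled run times. Here is a small sketch that recomputes them from the four measurements; the timings dictionary and the rows list are illustrative names.

```python
# Total run times (in seconds) reported by cProfile at each step.
timings = {
    "original": 5.517,
    "rewrite is_prime": 0.641,
    "rewrite longest_sum_of_consecutive_primes": 0.252,
    "rewrite generate_primes": 0.092,
}

baseline = timings["original"]
previous = baseline
rows = []
for step, seconds in timings.items():
    vs_previous = round(previous / seconds, 1)  # gain from this step alone
    vs_baseline = round(baseline / seconds, 1)  # cumulative gain
    rows.append((step, vs_previous, vs_baseline))
    previous = seconds

for step, vs_previous, vs_baseline in rows:
    print(f"{step:<45} {vs_previous:>5}x vs previous  {vs_baseline:>5}x vs original")
```

The per-step gains come out at 8.6x, 2.5x and 2.7x, and the cumulative gains at 8.6x, 21.9x and 60.0x.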

Here is our final iteration:

import cProfile
import pstats
import io

# import rustimport.import_hook
import rs_extension


def main():
    TARGET = 1_000_000
    primes = rs_extension.generate_primes(TARGET)
    max_prime, max_length = rs_extension.longest_sum_of_consecutive_primes(
        primes, TARGET
    )
    print(f"Prime: {max_prime}, Length: {max_length}")

The following code plots the execution time after each step:

import pandas as pd
import plotly.express as px

df = pd.DataFrame(
    {
        "step": [
            "original",
            "rewrite is_prime",
            "rewrite longest_sum_of_consecutive_primes",
            "rewrite generate_primes",
        ],
        "execution_time": [5.517, 0.641, 0.252, 0.092],
    }
)
fig = px.bar(
    df,
    x="execution_time",
    y="step",
    color="step",
    title="Turbocharging Performance: A 60x Speed Boost with Rust!",
    orientation="h",
    template="plotly_white",
)
fig.update_yaxes(title="", visible=True)
fig.update_xaxes(title="Execution time (s)")
fig.update_layout(showlegend=False)


fig.show()

And that’s a wrap, everyone! In this post, we’ve shown you how to use cProfile to profile your Python code, helping you pinpoint potential bottlenecks. We’ve also showcased the power of the rustimport library, which lets you call highly optimized Rust code straight from your Python script with very little setup.

Have you had the chance to use these tools in your projects? Do you believe your work could benefit from Rust? I’d love to hear your thoughts. Until next time, happy coding!