【Numba】正确使用numba,让你的python代码原地起飞!
摘要: Python在计算密集型任务中性能较差,Numba提供了简洁高效的JIT编译解决方案。只需添加装饰器即可将Python代码编译为机器码,实现接近C语言的性能。文章介绍了Numba的安装、基础语法(@jit/@njit装饰器)、编译选项(缓存、并行计算等)以及支持的数据类型。测试显示Numba版本比纯Python快数十倍,且无需修改现有代码,是提升Python数值计算性能的理想工具。
前言
Python 因其简洁优雅的语法和丰富的生态系统而广受欢迎,但在计算密集型任务中,Python 的执行速度往往成为瓶颈。虽然我们可以使用 C/C++ 扩展或者 Cython 来提升性能,但这些方案的学习成本和开发复杂度都比较高。
Numba 的出现改变了这一切!它是一个针对 Python 的即时(JIT)编译器,能够将 Python 代码直接编译为机器码,实现接近 C 语言的执行速度。最关键的是,你几乎不需要修改现有的 Python 代码,只需要添加一个装饰器就能获得 10-100 倍的性能提升!
本文将全面介绍 Numba 的使用方法,从基础概念到高级技巧,帮助你掌握这个强大的性能优化工具。
1. Numba 简介和安装
1.1 什么是 Numba?
Numba 是一个开源的 JIT 编译器,它使用 LLVM 编译器库将 Python 函数编译为优化的机器码。Numba 专门针对 NumPy 数组和数值计算进行了优化,能够显著提升数值计算代码的性能。
Numba 的主要特点:
- 易于使用:只需添加装饰器,无需重写代码
- 高性能:能够实现接近 C 语言的执行速度
- 兼容性好:支持大部分 NumPy 功能和 Python 语法
- 自动优化:自动进行循环优化、向量化等
- 并行支持:支持 CPU 和 GPU 并行计算
1.2 安装 Numba
使用 pip 安装 Numba:
# 注意如果要使用numba 建议使用 python3.9或3.10
pip install numba
# 下面的版本实测不会产生依赖冲突
# numba==0.56.4
# numpy==1.23.5
# llvmlite==0.39.1
1.3 第一个 Numba 程序
让我们从一个简单的例子开始,体验 Numba 的威力:
import numpy as np
import time
from numba import jit
def python_sum(arr):
"""传统 Python 实现"""
total = 0.0
for i in range(len(arr)):
total += arr[i]
return total
@jit
def numba_sum(arr):
"""Numba 优化版本"""
total = 0.0
for i in range(len(arr)):
total += arr[i]
return total
# 测试性能
if __name__ == "__main__":
# 生成测试数据
data = np.random.random(1000000)
# 预热 Numba 函数
_ = numba_sum(data[:100])
# 测试 Python 版本
start_time = time.time()
result_python = python_sum(data)
python_time = time.time() - start_time
# 测试 Numba 版本
start_time = time.time()
result_numba = numba_sum(data)
numba_time = time.time() - start_time
print(f"Python 版本耗时: {python_time:.4f}秒")
print(f"Numba 版本耗时: {numba_time:.4f}秒")
print(f"性能提升: {python_time/(numba_time if numba_time > 0 else 0.0001):.1f}倍")
print(f"结果一致: {abs(result_python - result_numba) < 1e-10}")
运行这个例子,你会发现 Numba 版本比 Python 版本快了几十甚至上百倍!这就是 Numba 的魅力所在。
2. Numba 基础语法和装饰器
2.1 @jit 装饰器
@jit 是 Numba 最基本的装饰器,它会在函数首次调用时将 Python 代码编译为机器码。
from numba import jit
import numpy as np
@jit
def calculate_distance(x1, y1, x2, y2):
"""计算两点距离"""
return np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
# 使用示例
distance = calculate_distance(0, 0, 3, 4)
print(f"距离: {distance}") # 输出: 5.0
2.2 @njit 装饰器
@njit 是 @jit(nopython=True) 的简写,它强制 Numba 完全脱离 Python 解释器运行,通常能获得更好的性能:
import numba as nb
import numpy as np
import time
def pure_python_sum(arr: np.ndarray) -> float:
"""纯Python版本的数组求和"""
total = 0.0
for i in range(len(arr)):
total += arr[i]
return total
@nb.njit
def numba_sum(arr: np.ndarray) -> float:
"""Numba加速版本的数组求和"""
total = 0.0
for i in range(len(arr)):
total += arr[i]
return total
test_array = np.random.random(10000000) * 100
start_time = time.time()
result_python = pure_python_sum(test_array)
python_time = time.time() - start_time
start_time = time.time()
result_numba = numba_sum(test_array)
numba_time = time.time() - start_time
start_time = time.time()
result_numba2 = numba_sum(test_array)
numba_time2 = time.time() - start_time
print(f"纯Python耗时: {python_time:.6f}秒")
print(f"Numba首次调用(含编译): {numba_time:.6f}秒")
print(f"Numba第二次调用: {numba_time2:.6f}秒")
print(f"性能提升: {python_time / (numba_time2 if numba_time2 > 0 else 0.0001):.1f}倍")
print(f"结果一致性: {np.isclose(result_python, result_numba)}")
print(f"结果一致性: {np.isclose(result_python, result_numba2)}")

2.3 编译模式和选项
Numba 提供多种编译模式和选项来控制编译行为:
from numba import njit, prange
import numpy as np
# 缓存编译结果,避免重复编译
@njit(cache=True)
def cached_function(x):
return x * x + 2 * x + 1
# 指定函数签名,提前编译
@njit("float64[:](float64[:])")
def typed_function(x):
return np.sin(x) + np.cos(x)
# 启用并行计算(注意:使用numba并行计算, 循环需要使用prange, 原始的range不支持并行)
@njit(parallel=True)
def parallel_function(arr):
result = np.zeros_like(arr)
for i in prange(len(arr)):
result[i] = arr[i] ** 2 + arr[i] ** 0.5
return result
# 错误处理模式
@njit(error_model='numpy')
def error_safe_function(x):
return np.sqrt(x) # 对负数返回 NaN 而不是抛出异常
# 使用示例
x = np.linspace(-10, 10, 1000)
y1 = cached_function(x)
y2 = typed_function(x)
y3 = parallel_function(np.abs(x))
y4 = error_safe_function(x) # 包含负数,会产生 NaN
print(f"缓存函数结果: {y1[:5]}")
print(f"类型化函数结果: {y2[:5]}")
print(f"并行函数结果: {y3[:5]}")
print(f"错误安全函数结果: {y4[:5]}")
3. Numba 支持的数据类型和操作
3.1 支持的数据类型
Numba 支持大部分 NumPy 数据类型和 Python 基本类型:
from numba import njit, types
import numpy as np
@njit
def data_types_demo():
"""演示 Numba 支持的数据类型"""
# 基本数值类型
int_val = 42
float_val = 3.14
complex_val = 1.0 + 2.0j
bool_val = True
# NumPy 数组
int_array = np.array([1, 2, 3], dtype=np.int32)
float_array = np.array([1.0, 2.0, 3.0], dtype=np.float64)
bool_array = np.array([True, False, True])
# 多维数组
matrix = np.zeros((3, 3), dtype=np.float32)
matrix[0, 0] = 1.0
matrix[1, 1] = 1.0
matrix[2, 2] = 1.0
trace_val = 0.0
for i in range(matrix.shape[0]):
trace_val += matrix[i, i]
# 元组
coordinates = (1.0, 2.0, 3.0)
return (int_val, float_val, complex_val, bool_val,
int_array.sum(), float_array.mean(), bool_array.sum(),
trace_val, coordinates[0])
# 显式类型声明
@njit("Tuple((int64, float64, complex128))(float64[:])")
def explicit_types(arr):
"""使用显式类型声明"""
total = arr.sum()
mean = arr.mean()
complex_result = total + 1j * mean
return int(total), mean, complex_result
# 测试
result = data_types_demo()
print(f"数据类型演示结果: {result}")
arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
typed_result = explicit_types(arr)
print(f"显式类型结果: {typed_result}")
3.2 NumPy 函数支持
Numba 支持大量的 NumPy 函数和操作:
from numba import njit
import numpy as np
@njit
def numpy_functions_demo(x, y):
"""演示 Numba 支持的 NumPy 函数"""
# 数学函数
sin_x = np.sin(x)
cos_x = np.cos(x)
exp_x = np.exp(x)
log_x = np.log(np.abs(x) + 1e-10) # 避免 log(0)
# 统计函数
mean_val = np.mean(x)
std_val = np.std(x)
min_val = np.min(x)
max_val = np.max(x)
# 数组操作
sorted_x = np.sort(x)
unique_x = np.unique(x.astype(np.int32))
# 线性代数(部分支持)
dot_product = np.dot(x, y)
# 逻辑操作
mask = x > 0
positive_x = x[mask]
# 数组创建
zeros = np.zeros(10)
ones = np.ones(5)
arange = np.arange(0, 10, 2)
return {
'sin_mean': np.mean(sin_x),
'cos_std': np.std(cos_x),
'exp_max': np.max(exp_x),
'log_min': np.min(log_x),
'stats': (mean_val, std_val, min_val, max_val),
'sorted_first': sorted_x[0],
'unique_count': len(unique_x),
'dot_product': dot_product,
'positive_count': len(positive_x),
'created_arrays': (len(zeros), len(ones), len(arange))
}
# 注意:Numba 不支持返回字典,这里仅作演示
# 实际使用时应该返回元组或数组
@njit
def numpy_functions_practical(x, y):
"""实用的 NumPy 函数演示"""
# 数学运算
result1 = np.sqrt(x**2 + y**2)
result2 = np.arctan2(y, x)
# 统计分析
correlation = np.corrcoef(x, y)[0, 1]
# 数组处理
combined = np.concatenate((x, y))
reshaped = combined.reshape(-1, 1)
return result1.mean(), result2.std(), correlation, reshaped.shape[0]
# 测试
x = np.random.randn(1000)
y = np.random.randn(1000)
mean_dist, std_angle, corr, total_len = numpy_functions_practical(x, y)
print(f"平均距离: {mean_dist:.4f}")
print(f"角度标准差: {std_angle:.4f}")
print(f"相关系数: {corr:.4f}")
print(f"总长度: {total_len}")
4. 控制流和循环优化
4.1 循环优化
Numba 对循环进行了特殊优化,能够自动向量化和并行化循环:
from numba import njit, prange
import numpy as np
import time
@njit
def sequential_loop(arr):
"""顺序循环"""
result = np.zeros_like(arr)
for i in range(len(arr)):
result[i] = arr[i] ** 2 + np.sin(arr[i]) + np.cos(arr[i])
return result
@njit(parallel=True)
def parallel_loop(arr):
"""并行循环"""
result = np.zeros_like(arr)
for i in prange(len(arr)): # 使用 prange 启用并行
result[i] = arr[i] ** 2 + np.sin(arr[i]) + np.cos(arr[i])
return result
@njit
def nested_loops_optimization(matrix):
"""嵌套循环优化"""
rows, cols = matrix.shape
result = np.zeros_like(matrix)
# Numba 会自动优化这种嵌套循环
for i in range(rows):
for j in range(cols):
# 复杂的计算
val = matrix[i, j]
result[i, j] = val**3 - 2*val**2 + val + 1
return result
@njit
def loop_with_conditions(arr, threshold):
"""带条件的循环优化"""
count = 0
total = 0.0
for i in range(len(arr)):
if arr[i] > threshold:
total += arr[i]
count += 1
elif arr[i] < -threshold:
total -= arr[i]
count += 1
return total / count if count > 0 else 0.0
# 性能测试
def test_loop_performance():
"""测试循环性能"""
data = np.random.randn(1000000)
matrix = np.random.randn(1000, 1000)
# 预热
_ = sequential_loop(data[:100])
_ = parallel_loop(data[:100])
# 顺序循环测试
start_time = time.time()
result1 = sequential_loop(data)
seq_time = time.time() - start_time
# 并行循环测试
start_time = time.time()
result2 = parallel_loop(data)
par_time = time.time() - start_time
# 嵌套循环测试
start_time = time.time()
result3 = nested_loops_optimization(matrix)
nested_time = time.time() - start_time
# 条件循环测试
start_time = time.time()
result4 = loop_with_conditions(data, 0.5)
cond_time = time.time() - start_time
print(f"顺序循环耗时: {seq_time:.4f}秒")
print(f"并行循环耗时: {par_time:.4f}秒")
print(f"并行加速比: {seq_time/par_time:.2f}x")
print(f"嵌套循环耗时: {nested_time:.4f}秒")
print(f"条件循环耗时: {cond_time:.4f}秒")
print(f"条件循环结果: {result4:.4f}")
# 验证结果一致性
print(f"结果一致性: {np.allclose(result1, result2)}")
# 运行性能测试
test_loop_performance()
4.2 条件语句优化
Numba 能够高效处理条件语句和分支预测:
from numba import njit
import numpy as np
@njit
def conditional_optimization(x, y):
"""条件语句优化示例"""
result = np.zeros_like(x)
for i in range(len(x)):
if x[i] > 0 and y[i] > 0:
# 第一象限
result[i] = x[i] + y[i]
elif x[i] < 0 and y[i] > 0:
# 第二象限
result[i] = -x[i] + y[i]
elif x[i] < 0 and y[i] < 0:
# 第三象限
result[i] = -x[i] - y[i]
else:
# 第四象限
result[i] = x[i] - y[i]
return result
@njit
def vectorized_conditions(x, y):
"""向量化条件处理"""
# 使用 NumPy 的 where 函数进行向量化条件处理
quad1 = (x > 0) & (y > 0)
quad2 = (x < 0) & (y > 0)
quad3 = (x < 0) & (y < 0)
quad4 = ~(quad1 | quad2 | quad3)
result = np.zeros_like(x)
result = np.where(quad1, x + y, result)
result = np.where(quad2, -x + y, result)
result = np.where(quad3, -x - y, result)
result = np.where(quad4, x - y, result)
return result
@njit
def complex_branching(data, mode):
"""复杂分支逻辑"""
n = len(data)
result = np.zeros(n)
for i in range(n):
val = data[i]
if mode == 1:
if val > 0:
result[i] = np.sqrt(val)
else:
result[i] = 0
elif mode == 2:
if val > 1:
result[i] = np.log(val)
elif val > 0:
result[i] = val
else:
result[i] = -val
else:
result[i] = np.abs(val)
return result
# 测试条件语句优化
x = np.random.randn(100000)
y = np.random.randn(100000)
result1 = conditional_optimization(x, y)
result2 = vectorized_conditions(x, y)
print(f"条件优化结果一致性: {np.allclose(result1, result2)}")
# 测试复杂分支
data = np.random.randn(10000)
for mode in [1, 2, 3]:
result = complex_branching(data, mode)
print(f"模式 {mode} 处理完成,平均值: {np.mean(result):.4f}")
5. 并行计算和prange
5.1 基础并行计算
Numba 提供了简单易用的并行计算功能,通过 prange 可以轻松实现多线程并行:
from numba import njit, prange
import numpy as np
import time
@njit
def serial_computation(data):
"""串行计算"""
result = np.zeros_like(data)
for i in range(len(data)):
# 模拟复杂计算
temp = data[i]
for j in range(100):
temp = temp * 0.99 + 0.01 * np.sin(temp)
result[i] = temp
return result
@njit(parallel=True)
def parallel_computation(data):
"""并行计算"""
result = np.zeros_like(data)
for i in prange(len(data)): # 使用 prange 实现并行
# 相同的复杂计算
temp = data[i]
for j in range(100):
temp = temp * 0.99 + 0.01 * np.sin(temp)
result[i] = temp
return result
@njit(parallel=True)
def parallel_matrix_operations(matrix):
"""并行矩阵操作"""
rows, cols = matrix.shape
result = np.zeros_like(matrix)
# 并行处理每一行
for i in prange(rows):
for j in range(cols):
# 对每个元素进行复杂变换
val = matrix[i, j]
result[i, j] = np.exp(-val**2) * np.cos(val) + np.sin(val)
return result
@njit(parallel=True)
def parallel_reduction(data):
"""并行归约操作"""
n = len(data)
# 计算平方和
sum_squares = 0.0
for i in prange(n):
sum_squares += data[i] ** 2
return sum_squares
# 性能对比测试
def benchmark_parallel():
"""并行计算性能基准测试"""
data = np.random.randn(10000)
matrix = np.random.randn(500, 500)
# 预热函数
_ = serial_computation(data[:100])
_ = parallel_computation(data[:100])
print("=" * 50)
print("并行计算性能测试")
print("=" * 50)
# 测试一维数组处理
start = time.time()
result_serial = serial_computation(data)
serial_time = time.time() - start
start = time.time()
result_parallel = parallel_computation(data)
parallel_time = time.time() - start
print(f"一维数组处理:")
print(f" 串行耗时: {serial_time:.4f}秒")
print(f" 并行耗时: {parallel_time:.4f}秒")
print(f" 加速比: {serial_time/parallel_time:.2f}x")
print(f" 结果一致: {np.allclose(result_serial, result_parallel)}")
# 测试矩阵操作
start = time.time()
matrix_result = parallel_matrix_operations(matrix)
matrix_time = time.time() - start
print(f"\n矩阵操作:")
print(f" 并行耗时: {matrix_time:.4f}秒")
print(f" 处理速度: {matrix.size/matrix_time/1000:.1f}K 元素/秒")
# 测试归约操作
start = time.time()
sum_result = parallel_reduction(data)
reduction_time = time.time() - start
# 验证结果
expected_sum = np.sum(data**2)
print(f"\n归约操作:")
print(f" 并行耗时: {reduction_time:.6f}秒")
print(f" 结果验证: {abs(sum_result - expected_sum) < 1e-10}")
# 运行基准测试
benchmark_parallel()
6. JitClass - 类的编译优化
6.1 JitClass 基础
Numba 允许使用 @jitclass 装饰器来编译类,实现高性能的面向对象编程:
from numba import jitclass, njit, types
import numpy as np
# 定义类的数据结构
spec = [
('value', types.float64),
('data', types.float64[:]),
('size', types.int64)
]
@jitclass(spec)
class FastArray:
"""高性能数组类"""
def __init__(self, size):
self.size = size
self.data = np.zeros(size)
self.value = 0.0
def set_value(self, val):
"""设置标量值"""
self.value = val
def fill(self, val):
"""填充数组"""
for i in range(self.size):
self.data[i] = val
def add_scalar(self, val):
"""数组加标量"""
for i in range(self.size):
self.data[i] += val
def multiply_scalar(self, val):
"""数组乘标量"""
for i in range(self.size):
self.data[i] *= val
def dot_product(self, other):
"""计算与另一个 FastArray 的点积"""
if self.size != other.size:
return -1.0 # 错误标志
result = 0.0
for i in range(self.size):
result += self.data[i] * other.data[i]
return result
def norm(self):
"""计算向量的模长"""
sum_squares = 0.0
for i in range(self.size):
sum_squares += self.data[i] * self.data[i]
return np.sqrt(sum_squares)
def normalize(self):
"""归一化向量"""
norm_val = self.norm()
if norm_val > 1e-10:
for i in range(self.size):
self.data[i] /= norm_val
# 使用 JitClass 的函数
@njit
def vector_operations(size):
"""演示 JitClass 的使用"""
# 创建两个向量
vec1 = FastArray(size)
vec2 = FastArray(size)
# 初始化向量
vec1.fill(1.0)
vec2.fill(2.0)
# 执行操作
vec1.add_scalar(0.5) # vec1 现在是 [1.5, 1.5, ...]
vec2.multiply_scalar(1.5) # vec2 现在是 [3.0, 3.0, ...]
# 计算点积
dot = vec1.dot_product(vec2)
# 计算模长
norm1 = vec1.norm()
norm2 = vec2.norm()
# 归一化
vec1.normalize()
vec2.normalize()
# 归一化后的模长
norm1_after = vec1.norm()
norm2_after = vec2.norm()
return dot, norm1, norm2, norm1_after, norm2_after
# 测试 JitClass
size = 10000
dot, norm1, norm2, norm1_after, norm2_after = vector_operations(size)
print(f"向量尺寸: {size}")
print(f"点积: {dot}")
print(f"归一化前模长: vec1={norm1:.6f}, vec2={norm2:.6f}")
print(f"归一化后模长: vec1={norm1_after:.6f}, vec2={norm2_after:.6f}")
结语
Numba 是一个强大的 Python 性能优化工具,它让我们能够在保持 Python 简洁性的同时获得接近 C 语言的执行速度。经历大量尝试之后总结下面的“教训”:
- 优先考虑使用@nb.njit装饰器进行JIT编译(或者使用@nb.jit(nopython=True)
- 尽量使用原生数据类型(int, float等)
- 避免使用Python容器(list/dict),优先使用NumPy数组
- 循环密集型计算性能提升最明显, 将复杂算法分解为多个简单函数
- 不要在Numba函数中使用print()调试, 建议先用纯Python版本验证逻辑
- 适合数值计算,不适合字符串/对象操作
- 使用@jitclass必须显式声明所有属性的类型, 构造函数参数类型必须明确
- jitclass类方法中不能调用不支持的Python函数
Numba 让 Python 在数值计算领域真正实现了"原地起飞",希望本文能帮助你更好地掌握这个强大的工具,在你的项目中发挥其威力!
相关链接:
示例代码仓库:
本文所有示例代码都经过测试验证,你可以直接复制运行。建议在博文直接运行查看结果,或者也可以复制代码到本地逐个测试这些示例,以加深理解。
更多推荐
所有评论(0)