import numpy as np
from numpy import fft
from scipy.linalg import dft
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import tensorflow as tf
import time  # 已移至顶部
import sys  # 已移至顶部
import warnings  # 新增:用于忽略警告

# --- 忽略所有警告 ---
warnings.filterwarnings('ignore')

# --- Matplotlib 中文显示配置 ---
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # 使用微软雅黑字体
# 解决负号显示问题
plt.rcParams['axes.unicode_minus'] = False  # 处理负号显示异常


# 预处理复数数据,将复数分解为实部和虚部
def preprocess_complex_data(data):
    data_real = np.real(data)  # 提取实部
    data_imag = np.imag(data)  # 提取虚部
    data_combined = np.stack((data_real, data_imag), axis=-1)  # 合并为实部和虚部两个通道
    return data_combined


# 后处理复数数据,将实部和虚部重新组合为复数
def postprocess_complex_data(data):
    data = data.reshape(64, 2)  # 重塑为64x2的数组
    return data[:, 0] + 1j * data[:, 1]  # 重组为复数形式


# 16-QAM映射表,定义16种符号对应的复数点
mapping_table = {
    (0, 0, 0, 0): -3 - 3j,
    (0, 0, 0, 1): -3 - 1j,
    (0, 0, 1, 0): -3 + 3j,
    (0, 0, 1, 1): -3 + 1j,
    (0, 1, 0, 0): -1 - 3j,
    (0, 1, 0, 1): -1 - 1j,
    (0, 1, 1, 0): -1 + 3j,
    (0, 1, 1, 1): -1 + 1j,
    (1, 0, 0, 0): 3 - 3j,
    (1, 0, 0, 1): 3 - 1j,
    (1, 0, 1, 0): 3 + 3j,
    (1, 0, 1, 1): 3 + 1j,
    (1, 1, 0, 0): 1 - 3j,
    (1, 1, 0, 1): 1 - 1j,
    (1, 1, 1, 0): 1 + 3j,
    (1, 1, 1, 1): 1 + 1j
}

# 解映射表,反转映射表以便解调
demapping_table = {v: k for k, v in mapping_table.items()}


# 生成DFT矩阵,用于频域变换
def DFT_matrix(N):
    i, j = np.meshgrid(np.arange(N), np.arange(N))  # 生成网格
    omega = np.exp(-2 * np.pi * 1j / N)  # 计算DFT的旋转因子
    W = np.power(omega, i * j) / np.sqrt(N)  # 生成DFT矩阵
    return W


# 最小二乘(LS)信道估计
def LS_estimator(OFDM_freq, pilotValue, pilotCarriers):
    H_LS = OFDM_freq[pilotCarriers] / pilotValue  # 在导频位置进行LS估计
    return H_LS


# 最小均方误差(MMSE)信道估计(修复版)
def MMSE_estimator(OFDM_freq, pilotValue, N, P, sigma2, avg_power, L, pilotCarriers):
    """
    改进的MMSE估计器
    avg_power: 信道抽头的平均功率
    L: 信道抽头数
    """
    # 导频位置的LS估计
    H_LS = OFDM_freq[pilotCarriers] / pilotValue

    # 构造信道的频域协方差矩阵
    # 基于时域功率延迟谱构造
    R_hh = np.zeros((N, N), dtype=complex)
    for i in range(N):
        for j in range(N):
            for l in range(L):
                R_hh[i, j] += avg_power[l] * np.exp(-2j * np.pi * l * (i - j) / N)

    # 提取导频位置的协方差矩阵
    R_HH = R_hh[np.ix_(pilotCarriers, pilotCarriers)]

    # 噪声协方差矩阵
    pilot_power = np.abs(pilotValue) ** 2
    R_nn = (sigma2 / pilot_power) * np.eye(P)

    # MMSE估计公式: H_MMSE = R_HH * (R_HH + R_nn)^(-1) * H_LS
    try:
        H_MMSE = R_HH @ np.linalg.inv(R_HH + R_nn) @ H_LS
    except:
        # 如果矩阵奇异,使用伪逆
        H_MMSE = R_HH @ np.linalg.pinv(R_HH + R_nn) @ H_LS

    return H_MMSE


# 16-QAM解调函数
def Demapping(QAM):
    constellation = np.array([x for x in demapping_table.keys()])  # 星座点数组
    dists = abs(QAM.reshape((-1, 1)) - constellation.reshape((1, -1)))  # 计算接收点到星座点的距离
    const_index = dists.argmin(axis=1)  # 找到最近的星座点索引
    hardDecision = constellation[const_index]  # 获取对应的星座点
    return np.vstack([demapping_table[C] for C in hardDecision]), hardDecision  # 返回解调后的比特和硬判决点


# OFDM仿真函数
# 完整的、修正后的 OFDM_Simulation 函数

def OFDM_Simulation(interpolation='Linear', estimator='LS', plot_H=False):
    M, N = 16, 64  # 调制阶数和子载波数
    m = int(np.log2(M))  # 每符号的比特数
    CP = 16  # 循环前缀长度
    P = N // 8  # 导频数量
    L = 3  # 信道抽头数
    pilotValue = 3 + 3j  # 导频值
    SNR_db = np.arange(0, 31, 5)  # SNR范围(0到30dB,步长5dB)
    BER = np.zeros_like(SNR_db, dtype='float32')  # 初始化BER数组
    MSE = np.zeros_like(SNR_db, dtype='float32')  # 初始化MSE数组

    allCarriers = np.arange(N)  # 所有子载波索引
    pilotCarriers = allCarriers[::N // P]  # 每隔N/P个子载波插入一个导频
    pilotCarriers = np.hstack([pilotCarriers, np.array([allCarriers[-1]])])  # 最后一个子载波也作为导频
    P = P + 1  # 更新导频数量
    effective_N = N - P  # 有效数据子载波数
    dataCarriers = np.delete(allCarriers, pilotCarriers)  # 数据子载波索引

    # CNN模型加载(根据SNR选择合适的模型)
    cnn_models = {}
    if estimator == 'CNN':
        try:
            print(f"        正在加载CNN模型...", end='')
            sys.stdout.flush()
            cnn_models[10] = tf.keras.models.load_model('model_weights_10.h5', compile=False)
            cnn_models[20] = tf.keras.models.load_model('model_weights_20_new.h5', compile=False)
            cnn_models[30] = tf.keras.models.load_model('model_weights_30.h5', compile=False)
            print(" ✓")
        except Exception as e:
            # 警告已被 warnings.filterwarnings('ignore') 抑制
            print(f" ✗ 警告: CNN模型加载失败 ({str(e)})")
            print("        将跳过CNN估计器")
            return SNR_db, BER * np.nan, MSE * np.nan

    for idx, snr_db in enumerate(SNR_db):
        print(f"        SNR={snr_db}dB [{idx + 1}/{len(SNR_db)}]", end='')
        sys.stdout.flush()
        blocks = 100  # 仿真块数
        error = 0  # 初始化误比特计数
        index = snr_db // 5  # SNR索引(步长为5)

        for t in range(blocks):
            info_bits = np.random.randint(2, size=(1, effective_N * m))  # 生成随机信息比特
            info_bits_blocks = info_bits.reshape((-1, m))  # 重塑为每组m位
            data = np.zeros((effective_N, 1), dtype='complex64')  # 初始化数据数组
            for i in range(effective_N):
                data[i] = mapping_table[tuple(info_bits_blocks[i])]  # 映射到16-QAM符号

            OFDM_data = np.zeros((N, 1), dtype='complex64')  # 初始化OFDM数据
            OFDM_data[pilotCarriers] = pilotValue  # 插入导频
            OFDM_data[dataCarriers] = data  # 插入数据

            OFDM_time = np.fft.ifft(OFDM_data.ravel()).reshape(-1, 1)  # IFFT变换到时域
            cp = OFDM_time[-CP:]  # 提取循环前缀
            OFDM_time_CP = np.vstack([cp, OFDM_time])  # 添加循环前缀

            # 信道模型
            avg_power = np.array([0.3, 0.8, 0.2])  # 信道平均功率
            channelResponse = np.array([1, 0, 0.3 + 0.3j])  # 固定信道响应
            channelResponse = channelResponse.reshape(1, -1)  # 重塑信道响应
            H_exact = np.fft.fft(channelResponse, N).reshape(-1, 1)  # 频域信道响应

            output = np.convolve(OFDM_time_CP.ravel(), channelResponse.ravel())[:N + CP]  # 信道卷积
            signal_power = np.mean(abs(output ** 2))  # 计算信号功率
            sigma2 = signal_power * 10 ** (-snr_db / 10)  # 计算噪声功率
            noise = np.sqrt(sigma2 / 2) * (
                    np.random.randn(*output.shape) + 1j * np.random.randn(*output.shape))  # 生成复数噪声
            OFDM_RX = output + noise  # 添加噪声

            OFDM_RX_noCP = OFDM_RX[CP:(CP + N)]  # 移除循环前缀
            OFDM_freq = np.fft.fft(OFDM_RX_noCP)  # FFT变换到频域

            # 信道估计
            if estimator == 'LS':
                H_est_pilots = LS_estimator(OFDM_freq, pilotValue, pilotCarriers)  # LS估计
            elif estimator == 'MMSE':
                H_est_pilots = MMSE_estimator(OFDM_freq, pilotValue, N, P, sigma2, avg_power, L,
                                              pilotCarriers)  # MMSE估计

            #
            # --- 逻辑修正部分 (已修复) ---
            #
            if estimator == 'exact':
                Hest = H_exact  # 直接使用真实信道
            elif estimator == 'CNN':
                # 根据SNR选择合适的模型
                if snr_db <= 15:
                    selected_model = cnn_models[10]
                elif snr_db <= 25:
                    selected_model = cnn_models[20]
                else:
                    selected_model = cnn_models[30]

                data_preprocessed = preprocess_complex_data(OFDM_freq)  # 预处理数据
                data_preprocessed = data_preprocessed.reshape(1, 64, 2)  # 重塑为CNN输入格式
                estimate = selected_model.predict(data_preprocessed, verbose=0)  # CNN预测
                Hest = postprocess_complex_data(estimate)  # 后处理得到估计信道

            # 对于LS和MMSE估计器,需要进行插值
            elif estimator in ['LS', 'MMSE']:
                if interpolation == 'Linear':
                    Hest_abs = interp1d(pilotCarriers, abs(H_est_pilots), kind='linear')(allCarriers)  # 线性插值幅度
                    Hest_phase = interp1d(pilotCarriers, np.angle(H_est_pilots), kind='linear')(allCarriers)  # 线性插值相位
                elif interpolation == 'Quadratic':
                    Hest_abs = interp1d(pilotCarriers, abs(H_est_pilots), kind='quadratic')(allCarriers)  # 二次插值幅度
                    Hest_phase = interp1d(pilotCarriers, np.angle(H_est_pilots), kind='quadratic')(
                        allCarriers)  # 二次插值相位
                elif interpolation == 'Cubic':
                    Hest_abs = interp1d(pilotCarriers, abs(H_est_pilots), kind='cubic')(allCarriers)  # 三次插值幅度
                    Hest_phase = interp1d(pilotCarriers, np.angle(H_est_pilots), kind='cubic')(allCarriers)  # 三次插值相位

主要对比:不同插值方法下LS、MMSE、CNN算法对比

对比性能:误码率(BER)和均方误差(MSE)

包含参考文献

代码一键运行即可

python版本3.10 Tensorflow版本2.10

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐