StructBERT相似度模型教程:中文文本标准化预处理脚本

你是不是遇到过这样的问题?手里有一堆中文文本数据,想要用StructBERT模型计算它们的相似度,但直接扔给模型效果总是不理想?有时候明明意思相近的句子,模型给出的相似度分数却很低;有时候完全不同的句子,分数反而挺高。

这很可能不是模型的问题,而是你的文本数据"不干净"。

今天我要分享一个专门为StructBERT相似度模型设计的中文文本标准化预处理脚本。这个脚本能帮你把乱七八糟的中文文本变得"规规矩矩",让模型更好地理解文本的真正含义,从而得到更准确的相似度计算结果。

1. 为什么需要中文文本标准化?

1.1 中文文本的特殊性

中文文本处理比英文复杂得多,主要因为:

  • 没有空格分隔:英文单词之间有空格,中文词语之间没有
  • 繁简体混用:同一个词可能有简体、繁体两种写法
  • 全角半角问题:标点符号、数字、字母都有全角和半角之分
  • 特殊字符干扰:表情符号、HTML标签、乱码字符等
  • 停用词影响:"的"、"了"、"在"这些词出现频率高但信息量低

1.2 StructBERT模型的特点

StructBERT文本相似度-中文-通用-large模型是在structbert-large-chinese预训练模型基础上,用多个中文数据集训练出来的。这个模型对文本质量比较敏感:

  • 训练数据经过了标准化处理
  • 模型期望输入的是"干净"的中文文本
  • 非标准文本会影响词向量表示的质量

1.3 预处理带来的好处

做好文本标准化预处理,能带来几个明显的好处:

  1. 提高相似度计算准确率:减少噪声干扰,让模型专注于文本的语义
  2. 提升计算效率:清理掉无用字符,减少模型计算负担
  3. 结果更稳定:相同的文本经过预处理后,每次计算的结果更一致
  4. 便于后续处理:标准化后的文本更适合做其他NLP任务

2. 预处理脚本完整实现

下面是我在实际项目中使用的完整预处理脚本,你可以直接复制使用:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
StructBERT中文文本标准化预处理脚本
专门为StructBERT文本相似度-中文-通用-large模型设计
"""

import re
import jieba
import zhconv
from typing import List, Union, Tuple
import unicodedata

class ChineseTextPreprocessor:
    """中文文本标准化预处理类"""
    
    def __init__(self, use_jieba: bool = True, custom_stopwords: List[str] = None):
        """
        初始化预处理器
        
        Args:
            use_jieba: 是否使用jieba分词(默认True)
            custom_stopwords: 自定义停用词列表
        """
        self.use_jieba = use_jieba
        
        # 基础停用词列表(针对中文相似度任务优化)
        self.stopwords = set([
            '的', '了', '在', '是', '我', '有', '和', '就', 
            '不', '人', '都', '一', '一个', '上', '也', '很', 
            '到', '说', '要', '去', '你', '会', '着', '没有', 
            '看', '好', '自己', '这', '那', '但', '把', '又', 
            '这', '那', '哪', '谁', '什么', '怎么', '为什么',
            '可以', '可能', '应该', '能够', '需要', '必须'
        ])
        
        # 添加自定义停用词
        if custom_stopwords:
            self.stopwords.update(custom_stopwords)
        
        # 编译常用正则表达式(提高效率)
        self.url_pattern = re.compile(r'https?://\S+|www\.\S+')
        self.email_pattern = re.compile(r'\S+@\S+\.\S+')
        self.html_pattern = re.compile(r'<[^>]+>')
        self.emoji_pattern = re.compile(
            "["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
            "]+", flags=re.UNICODE
        )
        
    def normalize_text(self, text: str) -> str:
        """
        文本标准化主函数
        
        Args:
            text: 原始文本
            
        Returns:
            标准化后的文本
        """
        if not isinstance(text, str) or not text.strip():
            return ""
        
        # 1. 转换为字符串并去除首尾空白
        text = str(text).strip()
        
        # 2. 统一换行符
        text = text.replace('\r\n', '\n').replace('\r', '\n')
        
        # 3. 移除URL链接
        text = self.url_pattern.sub(' ', text)
        
        # 4. 移除邮箱地址
        text = self.email_pattern.sub(' ', text)
        
        # 5. 移除HTML标签
        text = self.html_pattern.sub(' ', text)
        
        # 6. 移除表情符号
        text = self.emoji_pattern.sub(' ', text)
        
        # 7. 繁体转简体
        text = zhconv.convert(text, 'zh-cn')
        
        # 8. 全角转半角
        text = self._full_to_half(text)
        
        # 9. 标准化标点符号
        text = self._normalize_punctuation(text)
        
        # 10. 移除多余空白字符
        text = re.sub(r'\s+', ' ', text).strip()
        
        return text
    
    def _full_to_half(self, text: str) -> str:
        """全角字符转半角字符"""
        result = []
        for char in text:
            code = ord(char)
            # 全角空格直接转换
            if code == 0x3000:
                code = 0x20
            # 全角字符转半角
            elif 0xFF01 <= code <= 0xFF5E:
                code -= 0xFEE0
            result.append(chr(code))
        return ''.join(result)
    
    def _normalize_punctuation(self, text: str) -> str:
        """标准化中文标点符号"""
        # 中文标点转英文标点(根据相似度任务需求)
        punctuation_map = {
            ',': ',', '。': '.', '!': '!', '?': '?',
            ';': ';', ':': ':', '「': '"', '」': '"',
            '『': '"', '』': '"', '(': '(', ')': ')',
            '【': '[', '】': ']', '《': '<', '》': '>',
        }
        
        for cn_punc, en_punc in punctuation_map.items():
            text = text.replace(cn_punc, en_punc)
        
        return text
    
    def tokenize_and_filter(self, text: str, remove_stopwords: bool = True) -> str:
        """
        分词并过滤(可选)
        
        Args:
            text: 标准化后的文本
            remove_stopwords: 是否移除停用词
            
        Returns:
            处理后的文本
        """
        if not text:
            return ""
        
        if self.use_jieba:
            # 使用jieba精确模式分词
            words = jieba.lcut(text)
        else:
            # 按空格简单分割(适用于已分词的文本)
            words = text.split()
        
        if remove_stopwords:
            # 过滤停用词
            words = [word for word in words if word not in self.stopwords]
        
        # 重新组合为字符串
        return ' '.join(words)
    
    def preprocess_for_similarity(self, text1: str, text2: str) -> Tuple[str, str]:
        """
        专门为相似度计算准备的预处理
        
        Args:
            text1: 第一个文本
            text2: 第二个文本
            
        Returns:
            处理后的两个文本
        """
        # 标准化处理
        norm_text1 = self.normalize_text(text1)
        norm_text2 = self.normalize_text(text2)
        
        # 分词处理(StructBERT不需要分词,但可以用于分析)
        # 这里返回标准化后的文本,StructBERT会自己处理分词
        return norm_text1, norm_text2
    
    def batch_preprocess(self, texts: List[str]) -> List[str]:
        """
        批量预处理文本
        
        Args:
            texts: 文本列表
            
        Returns:
            处理后的文本列表
        """
        return [self.normalize_text(text) for text in texts]


# 使用示例
def example_usage():
    """使用示例"""
    
    # 初始化预处理器
    preprocessor = ChineseTextPreprocessor()
    
    # 示例文本(包含各种问题)
    raw_text1 = "今天天气很好,我们去公园玩吧!😊 网址:https://example.com"
    raw_text2 = "今天天氣很好,我們去公園玩吧!#開心"
    
    print("原始文本1:", raw_text1)
    print("原始文本2:", raw_text2)
    print()
    
    # 单个文本处理
    processed1 = preprocessor.normalize_text(raw_text1)
    processed2 = preprocessor.normalize_text(raw_text2)
    
    print("处理后文本1:", processed1)
    print("处理后文本2:", processed2)
    print()
    
    # 专门为相似度计算处理
    text1, text2 = preprocessor.preprocess_for_similarity(raw_text1, raw_text2)
    print("相似度计算专用处理:")
    print("文本1:", text1)
    print("文本2:", text2)
    print()
    
    # 批量处理
    texts = [raw_text1, raw_text2, "Hello World! 你好世界!"]
    processed_texts = preprocessor.batch_preprocess(texts)
    print("批量处理结果:")
    for i, (orig, proc) in enumerate(zip(texts, processed_texts)):
        print(f"{i+1}. 原始: {orig}")
        print(f"   处理: {proc}")


if __name__ == "__main__":
    # 运行示例
    example_usage()

3. 预处理步骤详解

3.1 文本清理:移除干扰内容

文本清理是预处理的第一步,目的是移除对语义理解没有帮助的噪声:

def clean_text(text: str) -> str:
    """文本清理函数"""
    # 移除URL(模型不需要知道链接)
    text = re.sub(r'https?://\S+|www\.\S+', ' ', text)
    
    # 移除邮箱地址
    text = re.sub(r'\S+@\S+\.\S+', ' ', text)
    
    # 移除HTML标签
    text = re.sub(r'<[^>]+>', ' ', text)
    
    # 移除表情符号(StructBERT不支持)
    emoji_pattern = re.compile(
        "["
        u"\U0001F600-\U0001F64F"  # 表情符号
        u"\U0001F300-\U0001F5FF"  # 符号和象形文字
        u"\U0001F680-\U0001F6FF"  # 交通和地图符号
        "]+", flags=re.UNICODE
    )
    text = emoji_pattern.sub(' ', text)
    
    return text

为什么这么做?

  • URL和邮箱不包含语义信息,反而可能干扰模型
  • HTML标签是格式信息,不是内容信息
  • 表情符号在文本相似度任务中通常不重要

3.2 字符标准化:统一文本格式

中文文本经常遇到字符格式不一致的问题:

def normalize_characters(text: str) -> str:
    """字符标准化"""
    # 繁体转简体(使用zhconv库)
    text = zhconv.convert(text, 'zh-cn')
    
    # 全角转半角
    result = []
    for char in text:
        code = ord(char)
        if code == 0x3000:  # 全角空格
            code = 0x20      # 半角空格
        elif 0xFF01 <= code <= 0xFF5E:  # 全角字符范围
            code -= 0xFEE0  # 转换为半角
        result.append(chr(code))
    
    return ''.join(result)

常见问题示例:

问题类型 示例 标准化后
繁简体混用 "學習中文很重要" "学习中文很重要"
全角半角混用 "Hello World!" "Hello World!"
空格不一致 "今天 天气 很好" "今天 天气 很好"

3.3 标点标准化:统一标点符号

中文标点符号的标准化对相似度计算很重要:

def normalize_punctuation(text: str) -> str:
    """标点符号标准化"""
    # 中文标点转英文标点
    punctuation_map = {
        ',': ',',  # 中文逗号转英文逗号
        '。': '.',  # 中文句号转英文句号
        '!': '!',  # 中文感叹号转英文感叹号
        '?': '?',  # 中文问号转英文问号
        ';': ';',  # 中文分号转英文分号
        ':': ':',  # 中文冒号转英文冒号
        '「': '"', '」': '"',  # 中文引号转英文引号
        '『': '"', '』': '"',
        '(': '(', ')': ')',  # 中文括号转英文括号
        '【': '[', '】': ']',
        '《': '<', '》': '>',
    }
    
    for cn_punc, en_punc in punctuation_map.items():
        text = text.replace(cn_punc, en_punc)
    
    return text

为什么需要标点标准化?

  • StructBERT在训练时使用的是标准标点
  • 统一的标点有助于模型更好地理解句子结构
  • 避免因为标点差异导致相似度计算偏差

3.4 停用词处理:根据任务决定

停用词处理需要根据具体任务来决定:

class StopwordHandler:
    """停用词处理器"""
    
    def __init__(self, task_type: str = "similarity"):
        self.task_type = task_type
        self.stopwords = self._load_stopwords()
    
    def _load_stopwords(self) -> set:
        """加载停用词"""
        base_stopwords = {
            '的', '了', '在', '是', '我', '有', '和', '就',
            '不', '人', '都', '一', '一个', '上', '也', '很',
        }
        
        if self.task_type == "similarity":
            # 相似度任务:保留更多停用词,因为它们可能影响语义
            return set()  # 不删除任何停用词
        elif self.task_type == "classification":
            # 分类任务:可以适当删除停用词
            return base_stopwords
        else:
            return base_stopwords
    
    def process(self, text: str) -> str:
        """处理停用词"""
        if not self.stopwords:
            return text
        
        words = text.split()
        filtered_words = [w for w in words if w not in self.stopwords]
        return ' '.join(filtered_words)

对于StructBERT相似度模型的建议:

  • 不要删除停用词:StructBERT能够理解停用词的语义作用
  • "的"、"了"、"在"这些词在中文中承载着重要的语法信息
  • 删除停用词可能改变句子的语义,影响相似度计算

4. 与StructBERT模型集成

4.1 完整的预处理流水线

下面是如何将预处理脚本与StructBERT模型集成的完整示例:

import torch
from sentence_transformers import SentenceTransformer
from typing import List, Tuple
import numpy as np

class StructBERTSimilarityPipeline:
    """StructBERT相似度计算流水线"""
    
    def __init__(self, model_path: str = None):
        """
        初始化流水线
        
        Args:
            model_path: 模型路径,如果为None则使用默认模型
        """
        # 初始化预处理器
        self.preprocessor = ChineseTextPreprocessor(use_jieba=False)
        
        # 加载StructBERT模型
        if model_path:
            self.model = SentenceTransformer(model_path)
        else:
            # 使用StructBERT文本相似度-中文-通用-large模型
            model_name = "structbert-large-chinese-similarity"
            self.model = SentenceTransformer(model_name)
        
        # 设置设备
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
    
    def preprocess_texts(self, texts: List[str]) -> List[str]:
        """预处理文本列表"""
        return self.preprocessor.batch_preprocess(texts)
    
    def calculate_similarity(self, text1: str, text2: str) -> float:
        """
        计算两个文本的相似度
        
        Args:
            text1: 第一个文本
            text2: 第二个文本
            
        Returns:
            相似度分数(0-1之间)
        """
        # 预处理
        processed1, processed2 = self.preprocessor.preprocess_for_similarity(text1, text2)
        
        # 编码为向量
        embeddings = self.model.encode([processed1, processed2], 
                                      convert_to_tensor=True,
                                      device=self.device)
        
        # 计算余弦相似度
        cos_sim = torch.nn.functional.cosine_similarity(embeddings[0].unsqueeze(0),
                                                       embeddings[1].unsqueeze(0))
        
        # 转换为0-1之间的分数
        similarity = (cos_sim.item() + 1) / 2
        
        return similarity
    
    def batch_similarity(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """
        批量计算相似度
        
        Args:
            text_pairs: 文本对列表
            
        Returns:
            相似度分数列表
        """
        similarities = []
        
        for text1, text2 in text_pairs:
            sim = self.calculate_similarity(text1, text2)
            similarities.append(sim)
        
        return similarities
    
    def find_most_similar(self, query: str, candidates: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
        """
        在候选文本中查找最相似的文本
        
        Args:
            query: 查询文本
            candidates: 候选文本列表
            top_k: 返回最相似的前k个
            
        Returns:
            (文本, 相似度) 列表
        """
        # 预处理所有文本
        processed_query = self.preprocessor.normalize_text(query)
        processed_candidates = self.preprocessor.batch_preprocess(candidates)
        
        # 编码所有文本
        all_texts = [processed_query] + processed_candidates
        embeddings = self.model.encode(all_texts, 
                                      convert_to_tensor=True,
                                      device=self.device)
        
        # 计算查询与所有候选的相似度
        query_embedding = embeddings[0]
        candidate_embeddings = embeddings[1:]
        
        similarities = []
        for i, emb in enumerate(candidate_embeddings):
            cos_sim = torch.nn.functional.cosine_similarity(query_embedding.unsqueeze(0),
                                                          emb.unsqueeze(0))
            similarity = (cos_sim.item() + 1) / 2
            similarities.append((candidates[i], similarity))
        
        # 按相似度排序
        similarities.sort(key=lambda x: x[1], reverse=True)
        
        return similarities[:top_k]


# 使用示例
def pipeline_example():
    """流水线使用示例"""
    
    # 初始化流水线
    pipeline = StructBERTSimilarityPipeline()
    
    # 示例文本对
    text_pairs = [
        ("今天天气很好,我们去公园玩吧!", "今天天氣很好,我們去公園玩吧!"),
        ("我喜欢吃苹果", "我爱吃香蕉"),
        ("深度学习是人工智能的一个重要分支", "深度学习属于人工智能领域"),
        ("这个产品的质量非常好", "这个商品的质量很棒"),
    ]
    
    print("文本相似度计算结果:")
    print("-" * 50)
    
    for text1, text2 in text_pairs:
        similarity = pipeline.calculate_similarity(text1, text2)
        print(f"文本1: {text1}")
        print(f"文本2: {text2}")
        print(f"相似度: {similarity:.4f}")
        print("-" * 50)
    
    # 查找最相似的文本
    query = "如何学习编程"
    candidates = [
        "编程学习方法",
        "学习编程的步骤",
        "编程入门指南",
        "今天天气真好",
        "人工智能的发展",
    ]
    
    print(f"\n查询: {query}")
    print("最相似的候选文本:")
    results = pipeline.find_most_similar(query, candidates, top_k=3)
    
    for text, sim in results:
        print(f"  {text} (相似度: {sim:.4f})")


if __name__ == "__main__":
    pipeline_example()

4.2 预处理效果对比

让我们看看预处理前后的效果差异:

def compare_preprocessing_effect():
    """对比预处理效果"""
    
    preprocessor = ChineseTextPreprocessor()
    pipeline = StructBERTSimilarityPipeline()
    
    # 测试用例
    test_cases = [
        {
            "text1": "Hello World! 你好世界!",
            "text2": "hello world! 你好世界!",
            "description": "大小写差异"
        },
        {
            "text1": "我喜欢吃苹果🍎",
            "text2": "我喜欢吃苹果",
            "description": "表情符号差异"
        },
        {
            "text1": "今天天气很好,我们去公园玩吧!",
            "text2": "今天天氣很好,我們去公園玩吧!",
            "description": "繁简体差异"
        },
        {
            "text1": "深度学习是AI的重要分支",
            "text2": "深度学习是人工智能的重要分支",
            "description": "缩写差异"
        },
    ]
    
    print("预处理效果对比:")
    print("=" * 80)
    
    for case in test_cases:
        text1 = case["text1"]
        text2 = case["text2"]
        desc = case["description"]
        
        # 预处理前
        sim_before = pipeline.calculate_similarity(text1, text2)
        
        # 预处理后
        processed1 = preprocessor.normalize_text(text1)
        processed2 = preprocessor.normalize_text(text2)
        sim_after = pipeline.calculate_similarity(processed1, processed2)
        
        print(f"\n案例: {desc}")
        print(f"原始文本1: {text1}")
        print(f"原始文本2: {text2}")
        print(f"处理后文本1: {processed1}")
        print(f"处理后文本2: {processed2}")
        print(f"预处理前相似度: {sim_before:.4f}")
        print(f"预处理后相似度: {sim_after:.4f}")
        print(f"差异: {abs(sim_after - sim_before):.4f}")
        print("-" * 80)


# 运行对比
compare_preprocessing_effect()

5. 实际应用场景

5.1 场景一:智能客服问答匹配

在智能客服系统中,用户的问题可能与标准问题库中的表述方式不同:

class CustomerServiceMatcher:
    """客服问答匹配器"""
    
    def __init__(self):
        self.pipeline = StructBERTSimilarityPipeline()
        self.qa_pairs = []  # 标准问答对
        
    def load_qa_pairs(self, qa_data: List[Tuple[str, str]]):
        """加载标准问答对"""
        self.qa_pairs = qa_data
    
    def find_best_answer(self, user_question: str, threshold: float = 0.8) -> str:
        """
        查找最佳答案
        
        Args:
            user_question: 用户问题
            threshold: 相似度阈值
            
        Returns:
            最佳答案或默认回复
        """
        if not self.qa_pairs:
            return "抱歉,我还没有学习相关知识。"
        
        # 预处理用户问题
        preprocessor = ChineseTextPreprocessor()
        processed_question = preprocessor.normalize_text(user_question)
        
        best_similarity = 0
        best_answer = ""
        
        # 遍历所有标准问题
        for std_question, answer in self.qa_pairs:
            processed_std = preprocessor.normalize_text(std_question)
            similarity = self.pipeline.calculate_similarity(processed_question, processed_std)
            
            if similarity > best_similarity:
                best_similarity = similarity
                best_answer = answer
        
        # 检查是否达到阈值
        if best_similarity >= threshold:
            return best_answer
        else:
            return f"我找到的相关答案是:{best_answer}(匹配度:{best_similarity:.2%})"


# 示例:客服问答匹配
def customer_service_example():
    """客服问答匹配示例"""
    
    matcher = CustomerServiceMatcher()
    
    # 标准问答库
    qa_data = [
        ("怎么重置密码", "您可以在登录页面点击'忘记密码',按照提示操作重置密码。"),
        ("如何修改个人信息", "请登录后进入'个人中心'-'账户设置'修改个人信息。"),
        ("客服电话是多少", "我们的客服电话是400-123-4567,工作时间是9:00-18:00。"),
        ("产品怎么退货", "在收货后7天内,商品完好无损的情况下可以申请退货。"),
    ]
    
    matcher.load_qa_pairs(qa_data)
    
    # 用户问题(可能有各种表述方式)
    user_questions = [
        "我忘记密码了怎么办?",  # 与"怎么重置密码"相似
        "想改一下我的资料",      # 与"如何修改个人信息"相似
        "你们的联系电话?",       # 与"客服电话是多少"相似
        "这个东西不想要了能退吗", # 与"产品怎么退货"相似
        "今天天气怎么样",         # 不在知识库中
    ]
    
    print("客服问答匹配示例:")
    print("=" * 60)
    
    for question in user_questions:
        answer = matcher.find_best_answer(question, threshold=0.7)
        print(f"用户问题: {question}")
        print(f"系统回复: {answer}")
        print("-" * 60)

5.2 场景二:内容去重检测

在内容平台或新闻聚合应用中,检测重复或高度相似的内容:

class ContentDeduplicator:
    """内容去重检测器"""
    
    def __init__(self, similarity_threshold: float = 0.9):
        self.pipeline = StructBERTSimilarityPipeline()
        self.preprocessor = ChineseTextPreprocessor()
        self.threshold = similarity_threshold
        self.content_database = []  # 存储已发布内容
    
    def add_content(self, content: str) -> bool:
        """
        添加新内容,检查是否重复
        
        Args:
            content: 新内容
            
        Returns:
            True表示可以添加(不重复),False表示重复
        """
        processed_new = self.preprocessor.normalize_text(content)
        
        # 与已有内容比较
        for existing in self.content_database:
            processed_existing = self.preprocessor.normalize_text(existing)
            similarity = self.pipeline.calculate_similarity(processed_new, processed_existing)
            
            if similarity >= self.threshold:
                print(f"检测到重复内容,相似度: {similarity:.2%}")
                print(f"新内容: {content[:50]}...")
                print(f"已有内容: {existing[:50]}...")
                return False
        
        # 没有重复,添加到数据库
        self.content_database.append(content)
        return True
    
    def batch_check(self, contents: List[str]) -> Tuple[List[str], List[str]]:
        """
        批量检查内容
        
        Args:
            contents: 内容列表
            
        Returns:
            (不重复内容列表, 重复内容列表)
        """
        unique_contents = []
        duplicate_contents = []
        
        for content in contents:
            if self.add_content(content):
                unique_contents.append(content)
            else:
                duplicate_contents.append(content)
        
        return unique_contents, duplicate_contents


# 示例:内容去重检测
def deduplication_example():
    """内容去重示例"""
    
    deduplicator = ContentDeduplicator(similarity_threshold=0.85)
    
    # 测试内容(有些是重复的,有些是相似的)
    test_contents = [
        "今天天气很好,适合出去散步。",
        "今日天气不错,很适合外出散步。",  # 与第一条高度相似
        "深度学习是人工智能的一个重要分支。",
        "人工智能领域中,深度学习是一个重要方向。",  # 与第三条相似
        "Python是一种流行的编程语言。",
        "Java也是一种常用的编程语言。",  # 与第五条不相似
        "今天天气很好,适合出去散步。",  # 与第一条完全相同
    ]
    
    print("内容去重检测示例:")
    print("=" * 60)
    
    unique, duplicates = deduplicator.batch_check(test_contents)
    
    print(f"\n原始内容数量: {len(test_contents)}")
    print(f"去重后唯一内容数量: {len(unique)}")
    print(f"检测到的重复内容数量: {len(duplicates)}")
    
    print("\n唯一内容:")
    for i, content in enumerate(unique, 1):
        print(f"{i}. {content}")
    
    print("\n重复内容:")
    for i, content in enumerate(duplicates, 1):
        print(f"{i}. {content}")

5.3 场景三:文档相似度搜索

在企业知识库或文档管理系统中,快速找到相似文档:

class DocumentSimilaritySearch:
    """文档相似度搜索引擎"""
    
    def __init__(self):
        self.pipeline = StructBERTSimilarityPipeline()
        self.preprocessor = ChineseTextPreprocessor()
        self.documents = []  # 文档列表
        self.document_embeddings = None  # 文档向量缓存
    
    def add_document(self, doc_id: str, content: str):
        """添加文档"""
        processed_content = self.preprocessor.normalize_text(content)
        self.documents.append({
            'id': doc_id,
            'content': content,
            'processed': processed_content
        })
        self.document_embeddings = None  # 清空缓存
    
    def build_index(self):
        """构建文档索引(预计算向量)"""
        if not self.documents:
            return
        
        # 获取所有处理后的文档内容
        processed_contents = [doc['processed'] for doc in self.documents]
        
        # 批量编码为向量
        self.document_embeddings = self.pipeline.model.encode(
            processed_contents,
            convert_to_tensor=True,
            device=self.pipeline.device
        )
    
    def search_similar(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
        """
        搜索相似文档
        
        Args:
            query: 查询文本
            top_k: 返回最相似的前k个文档
            
        Returns:
            (文档ID, 文档内容摘要, 相似度) 列表
        """
        if not self.documents:
            return []
        
        # 如果没有构建索引,先构建
        if self.document_embeddings is None:
            self.build_index()
        
        # 预处理查询
        processed_query = self.preprocessor.normalize_text(query)
        
        # 编码查询
        query_embedding = self.pipeline.model.encode(
            processed_query,
            convert_to_tensor=True,
            device=self.pipeline.device
        )
        
        # 计算相似度
        similarities = []
        for i, doc in enumerate(self.documents):
            doc_embedding = self.document_embeddings[i]
            cos_sim = torch.nn.functional.cosine_similarity(
                query_embedding.unsqueeze(0),
                doc_embedding.unsqueeze(0)
            )
            similarity = (cos_sim.item() + 1) / 2
            similarities.append((doc['id'], doc['content'], similarity))
        
        # 排序并返回top_k
        similarities.sort(key=lambda x: x[2], reverse=True)
        
        # 生成结果(包含内容摘要)
        results = []
        for doc_id, content, similarity in similarities[:top_k]:
            # 生成内容摘要(前100字符)
            summary = content[:100] + "..." if len(content) > 100 else content
            results.append((doc_id, summary, similarity))
        
        return results


# 示例:文档相似度搜索
def document_search_example():
    """文档搜索示例"""
    
    searcher = DocumentSimilaritySearch()
    
    # 添加示例文档
    documents = [
        ("doc1", "机器学习是人工智能的一个分支,它使计算机能够在没有明确编程的情况下学习。"),
        ("doc2", "深度学习是机器学习的一个子领域,它使用神经网络模拟人脑的工作方式。"),
        ("doc3", "自然语言处理是人工智能的一个领域,专注于计算机与人类语言之间的交互。"),
        ("doc4", "计算机视觉使计算机能够从图像和视频中获取高级理解。"),
        ("doc5", "强化学习是一种机器学习方法,智能体通过与环境互动来学习最佳行为。"),
        ("doc6", "监督学习使用标记数据训练模型,而非监督学习使用未标记数据。"),
        ("doc7", "神经网络由相互连接的节点组成,这些节点以类似于人脑神经元的方式工作。"),
        ("doc8", "卷积神经网络专门用于处理图像识别和计算机视觉任务。"),
        ("doc9", "循环神经网络适用于序列数据,如时间序列或自然语言。"),
        ("doc10", "Transformer模型在自然语言处理任务中取得了突破性进展。"),
    ]
    
    for doc_id, content in documents:
        searcher.add_document(doc_id, content)
    
    # 构建索引
    searcher.build_index()
    
    # 搜索查询
    queries = [
        "什么是人工智能学习",
        "如何处理图像识别",
        "自然语言的技术",
    ]
    
    print("文档相似度搜索示例:")
    print("=" * 80)
    
    for query in queries:
        print(f"\n查询: {query}")
        results = searcher.search_similar(query, top_k=3)
        
        print("最相关文档:")
        for doc_id, summary, similarity in results:
            print(f"  ID: {doc_id}")
            print(f"  摘要: {summary}")
            print(f"  相似度: {similarity:.4f}")
            print()
        print("-" * 80)

6. 性能优化与最佳实践

6.1 批量处理优化

当需要处理大量文本时,批量处理可以显著提高效率:

class OptimizedPreprocessor:
    """优化版预处理器(支持批量处理)"""
    
    def __init__(self, batch_size: int = 32):
        self.batch_size = batch_size
        self.preprocessor = ChineseTextPreprocessor()
    
    def parallel_preprocess(self, texts: List[str]) -> List[str]:
        """并行预处理(简化版)"""
        import concurrent.futures
        
        # 分批处理
        batches = [texts[i:i + self.batch_size] 
                  for i in range(0, len(texts), self.batch_size)]
        
        processed_texts = []
        
        # 使用线程池并行处理
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            for batch in batches:
                future = executor.submit(self._process_batch, batch)
                futures.append(future)
            
            for future in concurrent.futures.as_completed(futures):
                processed_texts.extend(future.result())
        
        return processed_texts
    
    def _process_batch(self, batch: List[str]) -> List[str]:
        """处理单个批次"""
        return [self.preprocessor.normalize_text(text) for text in batch]
    
    def preprocess_large_dataset(self, file_path: str, output_path: str):
        """处理大型数据集"""
        processed_lines = []
        
        # 逐行读取和处理
        with open(file_path, 'r', encoding='utf-8') as f:
            batch = []
            for line in f:
                line = line.strip()
                if line:  # 跳过空行
                    batch.append(line)
                
                # 达到批次大小时处理
                if len(batch) >= self.batch_size:
                    processed_batch = self._process_batch(batch)
                    processed_lines.extend(processed_batch)
                    batch = []
            
            # 处理剩余行
            if batch:
                processed_batch = self._process_batch(batch)
                processed_lines.extend(processed_batch)
        
        # 保存结果
        with open(output_path, 'w', encoding='utf-8') as f:
            for line in processed_lines:
                f.write(line + '\n')
        
        print(f"处理完成!共处理 {len(processed_lines)} 行文本。")

6.2 缓存机制

对于重复的文本处理,可以使用缓存提高效率:

import hashlib
from functools import lru_cache

class CachedPreprocessor:
    """带缓存的预处理器"""
    
    def __init__(self):
        self.preprocessor = ChineseTextPreprocessor()
        self.cache = {}  # 简单缓存字典
        self.cache_hits = 0
        self.cache_misses = 0
    
    def _get_hash(self, text: str) -> str:
        """生成文本哈希值(用于缓存键)"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()
    
    def normalize_with_cache(self, text: str) -> str:
        """带缓存的标准化处理"""
        # 生成缓存键
        cache_key = self._get_hash(text)
        
        # 检查缓存
        if cache_key in self.cache:
            self.cache_hits += 1
            return self.cache[cache_key]
        
        # 缓存未命中,进行处理
        self.cache_misses += 1
        processed_text = self.preprocessor.normalize_text(text)
        
        # 存入缓存
        self.cache[cache_key] = processed_text
        
        # 如果缓存太大,清理一部分
        if len(self.cache) > 10000:  # 最大缓存大小
            # 简单策略:清理前1000个
            keys_to_remove = list(self.cache.keys())[:1000]
            for key in keys_to_remove:
                del self.cache[key]
        
        return processed_text
    
    def get_cache_stats(self) -> dict:
        """获取缓存统计信息"""
        return {
            'cache_size': len(self.cache),
            'cache_hits': self.cache_hits,
            'cache_misses': self.cache_misses,
            'hit_rate': self.cache_hits / (self.cache_hits + self.cache_misses) 
                       if (self.cache_hits + self.cache_misses) > 0 else 0
        }
    
    @lru_cache(maxsize=10000)
    def normalize_lru(self, text: str) -> str:
        """使用LRU缓存的标准化处理"""
        return self.preprocessor.normalize_text(text)


# 测试缓存效果
def test_cache_performance():
    """测试缓存性能"""
    
    preprocessor = CachedPreprocessor()
    
    # 测试文本(有些重复)
    test_texts = [
        "今天天气很好",
        "今天天气很好",  # 重复
        "我喜欢编程",
        "深度学习很有趣",
        "今天天气很好",  # 重复
        "我喜欢编程",     # 重复
        "机器学习",
        "人工智能",
        "今天天气很好",  # 重复
    ] * 100  # 重复100次
    
    print("开始处理...")
    
    # 处理所有文本
    for text in test_texts:
        _ = preprocessor.normalize_with_cache(text)
    
    # 获取统计信息
    stats = preprocessor.get_cache_stats()
    
    print(f"处理完成!")
    print(f"缓存大小: {stats['cache_size']}")
    print(f"缓存命中: {stats['cache_hits']}")
    print(f"缓存未命中: {stats['cache_misses']}")
    print(f"命中率: {stats['hit_rate']:.2%}")

6.3 错误处理与日志

在生产环境中,良好的错误处理和日志记录很重要:

import logging
from datetime import datetime

class RobustPreprocessor:
    """健壮性预处理器(带错误处理和日志)"""
    
    def __init__(self, log_file: str = None):
        self.preprocessor = ChineseTextPreprocessor()
        
        # 配置日志
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        
        # 控制台处理器
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        
        # 文件处理器(如果指定了日志文件)
        if log_file:
            file_handler = logging.FileHandler(log_file, encoding='utf-8')
            file_handler.setLevel(logging.INFO)
            self.logger.addHandler(file_handler)
        
        # 格式化器
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        console_handler.setFormatter(formatter)
        if log_file:
            file_handler.setFormatter(formatter)
        
        self.logger.addHandler(console_handler)
    
    def safe_normalize(self, text: str, text_id: str = None) -> str:
        """
        安全的标准化处理(带错误处理)
        
        Args:
            text: 要处理的文本
            text_id: 文本标识(用于日志)
            
        Returns:
            处理后的文本,如果出错返回空字符串
        """
        try:
            start_time = datetime.now()
            
            # 检查输入
            if text is None:
                self.logger.warning(f"文本为空,text_id: {text_id}")
                return ""
            
            if not isinstance(text, str):
                self.logger.warning(f"文本不是字符串类型,text_id: {text_id}, 类型: {type(text)}")
                text = str(text)
            
            # 处理
            result = self.preprocessor.normalize_text(text)
            
            # 记录处理时间
            process_time = (datetime.now() - start_time).total_seconds() * 1000
            
            # 记录日志
            self.logger.info(
                f"文本处理完成 - "
                f"text_id: {text_id}, "
                f"原始长度: {len(text)}, "
                f"处理后长度: {len(result)}, "
                f"处理时间: {process_time:.2f}ms"
            )
            
            return result
            
        except Exception as e:
            # 记录错误
            self.logger.error(
                f"文本处理失败 - "
                f"text_id: {text_id}, "
                f"错误: {str(e)}, "
                f"文本前50字符: {str(text)[:50] if text else 'None'}"
            )
            
            # 返回空字符串或原始文本(根据需求)
            return ""
    
    def batch_safe_normalize(self, texts: List[str], text_ids: List[str] = None) -> List[str]:
        """
        批量安全处理
        
        Args:
            texts: 文本列表
            text_ids: 文本ID列表(可选)
            
        Returns:
            处理后的文本列表
        """
        if text_ids is None:
            text_ids = [f"text_{i}" for i in range(len(texts))]
        
        if len(texts) != len(text_ids):
            self.logger.warning("文本列表和ID列表长度不一致")
            text_ids = [f"text_{i}" for i in range(len(texts))]
        
        results = []
        success_count = 0
        fail_count = 0
        
        for text, text_id in zip(texts, text_ids):
            try:
                result = self.safe_normalize(text, text_id)
                results.append(result)
                success_count += 1
            except Exception as e:
                self.logger.error(f"批量处理失败 - text_id: {text_id}, 错误: {str(e)}")
                results.append("")  # 失败时返回空字符串
                fail_count += 1
        
        self.logger.info(
            f"批量处理完成 - "
            f"总数: {len(texts)}, "
            f"成功: {success_count}, "
            f"失败: {fail_count}, "
            f"成功率: {success_count/len(texts):.2%}"
        )
        
        return results


# 测试健壮性处理器
def test_robust_preprocessor():
    """测试健壮性预处理器"""
    
    # 创建预处理器(带日志)
    preprocessor = RobustPreprocessor(log_file="preprocessor.log")
    
    # 测试各种输入
    test_cases = [
        ("正常文本", "今天天气很好,我们去公园玩吧!"),
        ("空文本", ""),
        ("None输入", None),
        ("非字符串", 12345),
        ("特殊字符", "Hello\nWorld\t测试\r\n"),
        ("超长文本", "测试" * 1000),
    ]
    
    print("测试健壮性预处理器:")
    print("=" * 60)
    
    for case_name, text in test_cases:
        print(f"\n测试用例: {case_name}")
        print(f"输入: {str(text)[:50] if text else 'None'}")
        
        result = preprocessor.safe_normalize(text, f"test_{case_name}")
        
        print(f"输出: {result[:50] if result else '空字符串'}")
        print(f"输出长度: {len(result)}")

7. 总结

7.1 关键要点回顾

通过本文的讲解,你应该掌握了以下几个关键点:

  1. 预处理的重要性:中文文本标准化预处理能显著提升StructBERT相似度模型的准确性和稳定性。

  2. 完整的处理流程:从文本清理、字符标准化到标点统一,每个步骤都有其特定作用。

  3. 与模型的无缝集成:预处理脚本可以轻松集成到现有的StructBERT相似度计算流程中。

  4. 实际应用场景:预处理技术在智能客服、内容去重、文档搜索等场景中都有重要应用。

  5. 性能优化技巧:通过批量处理、缓存机制和错误处理,可以提升预处理效率和生产环境稳定性。

7.2 最佳实践建议

根据我的实践经验,给你几个实用建议:

  1. 预处理程度要适中:不要过度预处理,保留对语义理解重要的信息。

  2. 根据任务调整:不同的相似度任务可能需要不同的预处理策略。

  3. 测试验证:在实际数据上测试预处理效果,确保它真的提升了模型性能。

  4. 监控和优化:在生产环境中监控预处理效果,根据实际情况调整参数。

  5. 保持更新:中文语言使用习惯在变化,预处理规则也需要定期更新。

7.3 下一步学习建议

如果你对这个主题感兴趣,可以继续深入学习:

  1. 探索更多预处理技术:如词干还原、词形还原(对中文有限)、命名实体识别等。

  2. 了解其他相似度模型:除了StructBERT,还有BERT、RoBERTa、ERNIE等模型。

  3. 学习向量检索技术:如Faiss、Annoy等向量数据库,用于大规模相似度搜索。

  4. 实践完整项目:尝试构建一个完整的文本相似度应用系统。

  5. 关注最新研究:自然语言处理领域发展很快,新的预处理技术不断出现。

预处理可能看起来是个小细节,但在实际应用中,它往往是决定项目成败的关键因素之一。一个好的预处理流程能让模型性能提升10%-30%,这个投入是非常值得的。

希望这个预处理脚本和教程能帮助你在StructBERT相似度计算任务上取得更好的效果!


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐