StructBERT相似度模型教程:中文文本标准化预处理脚本
本文介绍了如何在星图GPU平台上自动化部署StructBERT文本相似度-中文-通用-large镜像,并利用该模型进行中文文本相似度计算。通过使用配套的文本标准化预处理脚本,可以高效处理繁简体、标点等中文文本问题,提升模型在智能客服问答匹配、内容去重等场景下的准确性与稳定性。
StructBERT相似度模型教程:中文文本标准化预处理脚本
你是不是遇到过这样的问题?手里有一堆中文文本数据,想要用StructBERT模型计算它们的相似度,但直接扔给模型效果总是不理想?有时候明明意思相近的句子,模型给出的相似度分数却很低;有时候完全不同的句子,分数反而挺高。
这很可能不是模型的问题,而是你的文本数据"不干净"。
今天我要分享一个专门为StructBERT相似度模型设计的中文文本标准化预处理脚本。这个脚本能帮你把乱七八糟的中文文本变得"规规矩矩",让模型更好地理解文本的真正含义,从而得到更准确的相似度计算结果。
1. 为什么需要中文文本标准化?
1.1 中文文本的特殊性
中文文本处理比英文复杂得多,主要因为:
- 没有空格分隔:英文单词之间有空格,中文词语之间没有
- 繁简体混用:同一个词可能有简体、繁体两种写法
- 全角半角问题:标点符号、数字、字母都有全角和半角之分
- 特殊字符干扰:表情符号、HTML标签、乱码字符等
- 停用词影响:"的"、"了"、"在"这些词出现频率高但信息量低
1.2 StructBERT模型的特点
StructBERT文本相似度-中文-通用-large模型是在structbert-large-chinese预训练模型基础上,用多个中文数据集训练出来的。这个模型对文本质量比较敏感:
- 训练数据经过了标准化处理
- 模型期望输入的是"干净"的中文文本
- 非标准文本会影响词向量表示的质量
1.3 预处理带来的好处
做好文本标准化预处理,能带来几个明显的好处:
- 提高相似度计算准确率:减少噪声干扰,让模型专注于文本的语义
- 提升计算效率:清理掉无用字符,减少模型计算负担
- 结果更稳定:相同的文本经过预处理后,每次计算的结果更一致
- 便于后续处理:标准化后的文本更适合做其他NLP任务
2. 预处理脚本完整实现
下面是我在实际项目中使用的完整预处理脚本,你可以直接复制使用:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
StructBERT中文文本标准化预处理脚本
专门为StructBERT文本相似度-中文-通用-large模型设计
"""
import re
import jieba
import zhconv
from typing import List, Union, Tuple
import unicodedata
class ChineseTextPreprocessor:
"""中文文本标准化预处理类"""
def __init__(self, use_jieba: bool = True, custom_stopwords: List[str] = None):
"""
初始化预处理器
Args:
use_jieba: 是否使用jieba分词(默认True)
custom_stopwords: 自定义停用词列表
"""
self.use_jieba = use_jieba
# 基础停用词列表(针对中文相似度任务优化)
self.stopwords = set([
'的', '了', '在', '是', '我', '有', '和', '就',
'不', '人', '都', '一', '一个', '上', '也', '很',
'到', '说', '要', '去', '你', '会', '着', '没有',
'看', '好', '自己', '这', '那', '但', '把', '又',
'这', '那', '哪', '谁', '什么', '怎么', '为什么',
'可以', '可能', '应该', '能够', '需要', '必须'
])
# 添加自定义停用词
if custom_stopwords:
self.stopwords.update(custom_stopwords)
# 编译常用正则表达式(提高效率)
self.url_pattern = re.compile(r'https?://\S+|www\.\S+')
self.email_pattern = re.compile(r'\S+@\S+\.\S+')
self.html_pattern = re.compile(r'<[^>]+>')
self.emoji_pattern = re.compile(
"["
u"\U0001F600-\U0001F64F" # emoticons
u"\U0001F300-\U0001F5FF" # symbols & pictographs
u"\U0001F680-\U0001F6FF" # transport & map symbols
u"\U0001F1E0-\U0001F1FF" # flags (iOS)
"]+", flags=re.UNICODE
)
def normalize_text(self, text: str) -> str:
"""
文本标准化主函数
Args:
text: 原始文本
Returns:
标准化后的文本
"""
if not isinstance(text, str) or not text.strip():
return ""
# 1. 转换为字符串并去除首尾空白
text = str(text).strip()
# 2. 统一换行符
text = text.replace('\r\n', '\n').replace('\r', '\n')
# 3. 移除URL链接
text = self.url_pattern.sub(' ', text)
# 4. 移除邮箱地址
text = self.email_pattern.sub(' ', text)
# 5. 移除HTML标签
text = self.html_pattern.sub(' ', text)
# 6. 移除表情符号
text = self.emoji_pattern.sub(' ', text)
# 7. 繁体转简体
text = zhconv.convert(text, 'zh-cn')
# 8. 全角转半角
text = self._full_to_half(text)
# 9. 标准化标点符号
text = self._normalize_punctuation(text)
# 10. 移除多余空白字符
text = re.sub(r'\s+', ' ', text).strip()
return text
def _full_to_half(self, text: str) -> str:
"""全角字符转半角字符"""
result = []
for char in text:
code = ord(char)
# 全角空格直接转换
if code == 0x3000:
code = 0x20
# 全角字符转半角
elif 0xFF01 <= code <= 0xFF5E:
code -= 0xFEE0
result.append(chr(code))
return ''.join(result)
def _normalize_punctuation(self, text: str) -> str:
"""标准化中文标点符号"""
# 中文标点转英文标点(根据相似度任务需求)
punctuation_map = {
',': ',', '。': '.', '!': '!', '?': '?',
';': ';', ':': ':', '「': '"', '」': '"',
'『': '"', '』': '"', '(': '(', ')': ')',
'【': '[', '】': ']', '《': '<', '》': '>',
}
for cn_punc, en_punc in punctuation_map.items():
text = text.replace(cn_punc, en_punc)
return text
def tokenize_and_filter(self, text: str, remove_stopwords: bool = True) -> str:
"""
分词并过滤(可选)
Args:
text: 标准化后的文本
remove_stopwords: 是否移除停用词
Returns:
处理后的文本
"""
if not text:
return ""
if self.use_jieba:
# 使用jieba精确模式分词
words = jieba.lcut(text)
else:
# 按空格简单分割(适用于已分词的文本)
words = text.split()
if remove_stopwords:
# 过滤停用词
words = [word for word in words if word not in self.stopwords]
# 重新组合为字符串
return ' '.join(words)
def preprocess_for_similarity(self, text1: str, text2: str) -> Tuple[str, str]:
"""
专门为相似度计算准备的预处理
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
处理后的两个文本
"""
# 标准化处理
norm_text1 = self.normalize_text(text1)
norm_text2 = self.normalize_text(text2)
# 分词处理(StructBERT不需要分词,但可以用于分析)
# 这里返回标准化后的文本,StructBERT会自己处理分词
return norm_text1, norm_text2
def batch_preprocess(self, texts: List[str]) -> List[str]:
"""
批量预处理文本
Args:
texts: 文本列表
Returns:
处理后的文本列表
"""
return [self.normalize_text(text) for text in texts]
# 使用示例
def example_usage():
"""使用示例"""
# 初始化预处理器
preprocessor = ChineseTextPreprocessor()
# 示例文本(包含各种问题)
raw_text1 = "今天天气很好,我们去公园玩吧!😊 网址:https://example.com"
raw_text2 = "今天天氣很好,我們去公園玩吧!#開心"
print("原始文本1:", raw_text1)
print("原始文本2:", raw_text2)
print()
# 单个文本处理
processed1 = preprocessor.normalize_text(raw_text1)
processed2 = preprocessor.normalize_text(raw_text2)
print("处理后文本1:", processed1)
print("处理后文本2:", processed2)
print()
# 专门为相似度计算处理
text1, text2 = preprocessor.preprocess_for_similarity(raw_text1, raw_text2)
print("相似度计算专用处理:")
print("文本1:", text1)
print("文本2:", text2)
print()
# 批量处理
texts = [raw_text1, raw_text2, "Hello World! 你好世界!"]
processed_texts = preprocessor.batch_preprocess(texts)
print("批量处理结果:")
for i, (orig, proc) in enumerate(zip(texts, processed_texts)):
print(f"{i+1}. 原始: {orig}")
print(f" 处理: {proc}")
if __name__ == "__main__":
# 运行示例
example_usage()
3. 预处理步骤详解
3.1 文本清理:移除干扰内容
文本清理是预处理的第一步,目的是移除对语义理解没有帮助的噪声:
def clean_text(text: str) -> str:
"""文本清理函数"""
# 移除URL(模型不需要知道链接)
text = re.sub(r'https?://\S+|www\.\S+', ' ', text)
# 移除邮箱地址
text = re.sub(r'\S+@\S+\.\S+', ' ', text)
# 移除HTML标签
text = re.sub(r'<[^>]+>', ' ', text)
# 移除表情符号(StructBERT不支持)
emoji_pattern = re.compile(
"["
u"\U0001F600-\U0001F64F" # 表情符号
u"\U0001F300-\U0001F5FF" # 符号和象形文字
u"\U0001F680-\U0001F6FF" # 交通和地图符号
"]+", flags=re.UNICODE
)
text = emoji_pattern.sub(' ', text)
return text
为什么这么做?
- URL和邮箱不包含语义信息,反而可能干扰模型
- HTML标签是格式信息,不是内容信息
- 表情符号在文本相似度任务中通常不重要
3.2 字符标准化:统一文本格式
中文文本经常遇到字符格式不一致的问题:
def normalize_characters(text: str) -> str:
"""字符标准化"""
# 繁体转简体(使用zhconv库)
text = zhconv.convert(text, 'zh-cn')
# 全角转半角
result = []
for char in text:
code = ord(char)
if code == 0x3000: # 全角空格
code = 0x20 # 半角空格
elif 0xFF01 <= code <= 0xFF5E: # 全角字符范围
code -= 0xFEE0 # 转换为半角
result.append(chr(code))
return ''.join(result)
常见问题示例:
| 问题类型 | 示例 | 标准化后 |
|---|---|---|
| 繁简体混用 | "學習中文很重要" | "学习中文很重要" |
| 全角半角混用 | "Hello World!" | "Hello World!" |
| 空格不一致 | "今天 天气 很好" | "今天 天气 很好" |
3.3 标点标准化:统一标点符号
中文标点符号的标准化对相似度计算很重要:
def normalize_punctuation(text: str) -> str:
"""标点符号标准化"""
# 中文标点转英文标点
punctuation_map = {
',': ',', # 中文逗号转英文逗号
'。': '.', # 中文句号转英文句号
'!': '!', # 中文感叹号转英文感叹号
'?': '?', # 中文问号转英文问号
';': ';', # 中文分号转英文分号
':': ':', # 中文冒号转英文冒号
'「': '"', '」': '"', # 中文引号转英文引号
'『': '"', '』': '"',
'(': '(', ')': ')', # 中文括号转英文括号
'【': '[', '】': ']',
'《': '<', '》': '>',
}
for cn_punc, en_punc in punctuation_map.items():
text = text.replace(cn_punc, en_punc)
return text
为什么需要标点标准化?
- StructBERT在训练时使用的是标准标点
- 统一的标点有助于模型更好地理解句子结构
- 避免因为标点差异导致相似度计算偏差
3.4 停用词处理:根据任务决定
停用词处理需要根据具体任务来决定:
class StopwordHandler:
"""停用词处理器"""
def __init__(self, task_type: str = "similarity"):
self.task_type = task_type
self.stopwords = self._load_stopwords()
def _load_stopwords(self) -> set:
"""加载停用词"""
base_stopwords = {
'的', '了', '在', '是', '我', '有', '和', '就',
'不', '人', '都', '一', '一个', '上', '也', '很',
}
if self.task_type == "similarity":
# 相似度任务:保留更多停用词,因为它们可能影响语义
return set() # 不删除任何停用词
elif self.task_type == "classification":
# 分类任务:可以适当删除停用词
return base_stopwords
else:
return base_stopwords
def process(self, text: str) -> str:
"""处理停用词"""
if not self.stopwords:
return text
words = text.split()
filtered_words = [w for w in words if w not in self.stopwords]
return ' '.join(filtered_words)
对于StructBERT相似度模型的建议:
- 不要删除停用词:StructBERT能够理解停用词的语义作用
- "的"、"了"、"在"这些词在中文中承载着重要的语法信息
- 删除停用词可能改变句子的语义,影响相似度计算
4. 与StructBERT模型集成
4.1 完整的预处理流水线
下面是如何将预处理脚本与StructBERT模型集成的完整示例:
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Tuple
import numpy as np
class StructBERTSimilarityPipeline:
"""StructBERT相似度计算流水线"""
def __init__(self, model_path: str = None):
"""
初始化流水线
Args:
model_path: 模型路径,如果为None则使用默认模型
"""
# 初始化预处理器
self.preprocessor = ChineseTextPreprocessor(use_jieba=False)
# 加载StructBERT模型
if model_path:
self.model = SentenceTransformer(model_path)
else:
# 使用StructBERT文本相似度-中文-通用-large模型
model_name = "structbert-large-chinese-similarity"
self.model = SentenceTransformer(model_name)
# 设置设备
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model.to(self.device)
def preprocess_texts(self, texts: List[str]) -> List[str]:
"""预处理文本列表"""
return self.preprocessor.batch_preprocess(texts)
def calculate_similarity(self, text1: str, text2: str) -> float:
"""
计算两个文本的相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
相似度分数(0-1之间)
"""
# 预处理
processed1, processed2 = self.preprocessor.preprocess_for_similarity(text1, text2)
# 编码为向量
embeddings = self.model.encode([processed1, processed2],
convert_to_tensor=True,
device=self.device)
# 计算余弦相似度
cos_sim = torch.nn.functional.cosine_similarity(embeddings[0].unsqueeze(0),
embeddings[1].unsqueeze(0))
# 转换为0-1之间的分数
similarity = (cos_sim.item() + 1) / 2
return similarity
def batch_similarity(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
"""
批量计算相似度
Args:
text_pairs: 文本对列表
Returns:
相似度分数列表
"""
similarities = []
for text1, text2 in text_pairs:
sim = self.calculate_similarity(text1, text2)
similarities.append(sim)
return similarities
def find_most_similar(self, query: str, candidates: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
"""
在候选文本中查找最相似的文本
Args:
query: 查询文本
candidates: 候选文本列表
top_k: 返回最相似的前k个
Returns:
(文本, 相似度) 列表
"""
# 预处理所有文本
processed_query = self.preprocessor.normalize_text(query)
processed_candidates = self.preprocessor.batch_preprocess(candidates)
# 编码所有文本
all_texts = [processed_query] + processed_candidates
embeddings = self.model.encode(all_texts,
convert_to_tensor=True,
device=self.device)
# 计算查询与所有候选的相似度
query_embedding = embeddings[0]
candidate_embeddings = embeddings[1:]
similarities = []
for i, emb in enumerate(candidate_embeddings):
cos_sim = torch.nn.functional.cosine_similarity(query_embedding.unsqueeze(0),
emb.unsqueeze(0))
similarity = (cos_sim.item() + 1) / 2
similarities.append((candidates[i], similarity))
# 按相似度排序
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities[:top_k]
# 使用示例
def pipeline_example():
"""流水线使用示例"""
# 初始化流水线
pipeline = StructBERTSimilarityPipeline()
# 示例文本对
text_pairs = [
("今天天气很好,我们去公园玩吧!", "今天天氣很好,我們去公園玩吧!"),
("我喜欢吃苹果", "我爱吃香蕉"),
("深度学习是人工智能的一个重要分支", "深度学习属于人工智能领域"),
("这个产品的质量非常好", "这个商品的质量很棒"),
]
print("文本相似度计算结果:")
print("-" * 50)
for text1, text2 in text_pairs:
similarity = pipeline.calculate_similarity(text1, text2)
print(f"文本1: {text1}")
print(f"文本2: {text2}")
print(f"相似度: {similarity:.4f}")
print("-" * 50)
# 查找最相似的文本
query = "如何学习编程"
candidates = [
"编程学习方法",
"学习编程的步骤",
"编程入门指南",
"今天天气真好",
"人工智能的发展",
]
print(f"\n查询: {query}")
print("最相似的候选文本:")
results = pipeline.find_most_similar(query, candidates, top_k=3)
for text, sim in results:
print(f" {text} (相似度: {sim:.4f})")
if __name__ == "__main__":
pipeline_example()
4.2 预处理效果对比
让我们看看预处理前后的效果差异:
def compare_preprocessing_effect():
"""对比预处理效果"""
preprocessor = ChineseTextPreprocessor()
pipeline = StructBERTSimilarityPipeline()
# 测试用例
test_cases = [
{
"text1": "Hello World! 你好世界!",
"text2": "hello world! 你好世界!",
"description": "大小写差异"
},
{
"text1": "我喜欢吃苹果🍎",
"text2": "我喜欢吃苹果",
"description": "表情符号差异"
},
{
"text1": "今天天气很好,我们去公园玩吧!",
"text2": "今天天氣很好,我們去公園玩吧!",
"description": "繁简体差异"
},
{
"text1": "深度学习是AI的重要分支",
"text2": "深度学习是人工智能的重要分支",
"description": "缩写差异"
},
]
print("预处理效果对比:")
print("=" * 80)
for case in test_cases:
text1 = case["text1"]
text2 = case["text2"]
desc = case["description"]
# 预处理前
sim_before = pipeline.calculate_similarity(text1, text2)
# 预处理后
processed1 = preprocessor.normalize_text(text1)
processed2 = preprocessor.normalize_text(text2)
sim_after = pipeline.calculate_similarity(processed1, processed2)
print(f"\n案例: {desc}")
print(f"原始文本1: {text1}")
print(f"原始文本2: {text2}")
print(f"处理后文本1: {processed1}")
print(f"处理后文本2: {processed2}")
print(f"预处理前相似度: {sim_before:.4f}")
print(f"预处理后相似度: {sim_after:.4f}")
print(f"差异: {abs(sim_after - sim_before):.4f}")
print("-" * 80)
# 运行对比
compare_preprocessing_effect()
5. 实际应用场景
5.1 场景一:智能客服问答匹配
在智能客服系统中,用户的问题可能与标准问题库中的表述方式不同:
class CustomerServiceMatcher:
"""客服问答匹配器"""
def __init__(self):
self.pipeline = StructBERTSimilarityPipeline()
self.qa_pairs = [] # 标准问答对
def load_qa_pairs(self, qa_data: List[Tuple[str, str]]):
"""加载标准问答对"""
self.qa_pairs = qa_data
def find_best_answer(self, user_question: str, threshold: float = 0.8) -> str:
"""
查找最佳答案
Args:
user_question: 用户问题
threshold: 相似度阈值
Returns:
最佳答案或默认回复
"""
if not self.qa_pairs:
return "抱歉,我还没有学习相关知识。"
# 预处理用户问题
preprocessor = ChineseTextPreprocessor()
processed_question = preprocessor.normalize_text(user_question)
best_similarity = 0
best_answer = ""
# 遍历所有标准问题
for std_question, answer in self.qa_pairs:
processed_std = preprocessor.normalize_text(std_question)
similarity = self.pipeline.calculate_similarity(processed_question, processed_std)
if similarity > best_similarity:
best_similarity = similarity
best_answer = answer
# 检查是否达到阈值
if best_similarity >= threshold:
return best_answer
else:
return f"我找到的相关答案是:{best_answer}(匹配度:{best_similarity:.2%})"
# 示例:客服问答匹配
def customer_service_example():
"""客服问答匹配示例"""
matcher = CustomerServiceMatcher()
# 标准问答库
qa_data = [
("怎么重置密码", "您可以在登录页面点击'忘记密码',按照提示操作重置密码。"),
("如何修改个人信息", "请登录后进入'个人中心'-'账户设置'修改个人信息。"),
("客服电话是多少", "我们的客服电话是400-123-4567,工作时间是9:00-18:00。"),
("产品怎么退货", "在收货后7天内,商品完好无损的情况下可以申请退货。"),
]
matcher.load_qa_pairs(qa_data)
# 用户问题(可能有各种表述方式)
user_questions = [
"我忘记密码了怎么办?", # 与"怎么重置密码"相似
"想改一下我的资料", # 与"如何修改个人信息"相似
"你们的联系电话?", # 与"客服电话是多少"相似
"这个东西不想要了能退吗", # 与"产品怎么退货"相似
"今天天气怎么样", # 不在知识库中
]
print("客服问答匹配示例:")
print("=" * 60)
for question in user_questions:
answer = matcher.find_best_answer(question, threshold=0.7)
print(f"用户问题: {question}")
print(f"系统回复: {answer}")
print("-" * 60)
5.2 场景二:内容去重检测
在内容平台或新闻聚合应用中,检测重复或高度相似的内容:
class ContentDeduplicator:
"""内容去重检测器"""
def __init__(self, similarity_threshold: float = 0.9):
self.pipeline = StructBERTSimilarityPipeline()
self.preprocessor = ChineseTextPreprocessor()
self.threshold = similarity_threshold
self.content_database = [] # 存储已发布内容
def add_content(self, content: str) -> bool:
"""
添加新内容,检查是否重复
Args:
content: 新内容
Returns:
True表示可以添加(不重复),False表示重复
"""
processed_new = self.preprocessor.normalize_text(content)
# 与已有内容比较
for existing in self.content_database:
processed_existing = self.preprocessor.normalize_text(existing)
similarity = self.pipeline.calculate_similarity(processed_new, processed_existing)
if similarity >= self.threshold:
print(f"检测到重复内容,相似度: {similarity:.2%}")
print(f"新内容: {content[:50]}...")
print(f"已有内容: {existing[:50]}...")
return False
# 没有重复,添加到数据库
self.content_database.append(content)
return True
def batch_check(self, contents: List[str]) -> Tuple[List[str], List[str]]:
"""
批量检查内容
Args:
contents: 内容列表
Returns:
(不重复内容列表, 重复内容列表)
"""
unique_contents = []
duplicate_contents = []
for content in contents:
if self.add_content(content):
unique_contents.append(content)
else:
duplicate_contents.append(content)
return unique_contents, duplicate_contents
# 示例:内容去重检测
def deduplication_example():
"""内容去重示例"""
deduplicator = ContentDeduplicator(similarity_threshold=0.85)
# 测试内容(有些是重复的,有些是相似的)
test_contents = [
"今天天气很好,适合出去散步。",
"今日天气不错,很适合外出散步。", # 与第一条高度相似
"深度学习是人工智能的一个重要分支。",
"人工智能领域中,深度学习是一个重要方向。", # 与第三条相似
"Python是一种流行的编程语言。",
"Java也是一种常用的编程语言。", # 与第五条不相似
"今天天气很好,适合出去散步。", # 与第一条完全相同
]
print("内容去重检测示例:")
print("=" * 60)
unique, duplicates = deduplicator.batch_check(test_contents)
print(f"\n原始内容数量: {len(test_contents)}")
print(f"去重后唯一内容数量: {len(unique)}")
print(f"检测到的重复内容数量: {len(duplicates)}")
print("\n唯一内容:")
for i, content in enumerate(unique, 1):
print(f"{i}. {content}")
print("\n重复内容:")
for i, content in enumerate(duplicates, 1):
print(f"{i}. {content}")
5.3 场景三:文档相似度搜索
在企业知识库或文档管理系统中,快速找到相似文档:
class DocumentSimilaritySearch:
"""文档相似度搜索引擎"""
def __init__(self):
self.pipeline = StructBERTSimilarityPipeline()
self.preprocessor = ChineseTextPreprocessor()
self.documents = [] # 文档列表
self.document_embeddings = None # 文档向量缓存
def add_document(self, doc_id: str, content: str):
"""添加文档"""
processed_content = self.preprocessor.normalize_text(content)
self.documents.append({
'id': doc_id,
'content': content,
'processed': processed_content
})
self.document_embeddings = None # 清空缓存
def build_index(self):
"""构建文档索引(预计算向量)"""
if not self.documents:
return
# 获取所有处理后的文档内容
processed_contents = [doc['processed'] for doc in self.documents]
# 批量编码为向量
self.document_embeddings = self.pipeline.model.encode(
processed_contents,
convert_to_tensor=True,
device=self.pipeline.device
)
def search_similar(self, query: str, top_k: int = 5) -> List[Tuple[str, str, float]]:
"""
搜索相似文档
Args:
query: 查询文本
top_k: 返回最相似的前k个文档
Returns:
(文档ID, 文档内容摘要, 相似度) 列表
"""
if not self.documents:
return []
# 如果没有构建索引,先构建
if self.document_embeddings is None:
self.build_index()
# 预处理查询
processed_query = self.preprocessor.normalize_text(query)
# 编码查询
query_embedding = self.pipeline.model.encode(
processed_query,
convert_to_tensor=True,
device=self.pipeline.device
)
# 计算相似度
similarities = []
for i, doc in enumerate(self.documents):
doc_embedding = self.document_embeddings[i]
cos_sim = torch.nn.functional.cosine_similarity(
query_embedding.unsqueeze(0),
doc_embedding.unsqueeze(0)
)
similarity = (cos_sim.item() + 1) / 2
similarities.append((doc['id'], doc['content'], similarity))
# 排序并返回top_k
similarities.sort(key=lambda x: x[2], reverse=True)
# 生成结果(包含内容摘要)
results = []
for doc_id, content, similarity in similarities[:top_k]:
# 生成内容摘要(前100字符)
summary = content[:100] + "..." if len(content) > 100 else content
results.append((doc_id, summary, similarity))
return results
# 示例:文档相似度搜索
def document_search_example():
"""文档搜索示例"""
searcher = DocumentSimilaritySearch()
# 添加示例文档
documents = [
("doc1", "机器学习是人工智能的一个分支,它使计算机能够在没有明确编程的情况下学习。"),
("doc2", "深度学习是机器学习的一个子领域,它使用神经网络模拟人脑的工作方式。"),
("doc3", "自然语言处理是人工智能的一个领域,专注于计算机与人类语言之间的交互。"),
("doc4", "计算机视觉使计算机能够从图像和视频中获取高级理解。"),
("doc5", "强化学习是一种机器学习方法,智能体通过与环境互动来学习最佳行为。"),
("doc6", "监督学习使用标记数据训练模型,而非监督学习使用未标记数据。"),
("doc7", "神经网络由相互连接的节点组成,这些节点以类似于人脑神经元的方式工作。"),
("doc8", "卷积神经网络专门用于处理图像识别和计算机视觉任务。"),
("doc9", "循环神经网络适用于序列数据,如时间序列或自然语言。"),
("doc10", "Transformer模型在自然语言处理任务中取得了突破性进展。"),
]
for doc_id, content in documents:
searcher.add_document(doc_id, content)
# 构建索引
searcher.build_index()
# 搜索查询
queries = [
"什么是人工智能学习",
"如何处理图像识别",
"自然语言的技术",
]
print("文档相似度搜索示例:")
print("=" * 80)
for query in queries:
print(f"\n查询: {query}")
results = searcher.search_similar(query, top_k=3)
print("最相关文档:")
for doc_id, summary, similarity in results:
print(f" ID: {doc_id}")
print(f" 摘要: {summary}")
print(f" 相似度: {similarity:.4f}")
print()
print("-" * 80)
6. 性能优化与最佳实践
6.1 批量处理优化
当需要处理大量文本时,批量处理可以显著提高效率:
class OptimizedPreprocessor:
"""优化版预处理器(支持批量处理)"""
def __init__(self, batch_size: int = 32):
self.batch_size = batch_size
self.preprocessor = ChineseTextPreprocessor()
def parallel_preprocess(self, texts: List[str]) -> List[str]:
"""并行预处理(简化版)"""
import concurrent.futures
# 分批处理
batches = [texts[i:i + self.batch_size]
for i in range(0, len(texts), self.batch_size)]
processed_texts = []
# 使用线程池并行处理
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for batch in batches:
future = executor.submit(self._process_batch, batch)
futures.append(future)
for future in concurrent.futures.as_completed(futures):
processed_texts.extend(future.result())
return processed_texts
def _process_batch(self, batch: List[str]) -> List[str]:
"""处理单个批次"""
return [self.preprocessor.normalize_text(text) for text in batch]
def preprocess_large_dataset(self, file_path: str, output_path: str):
"""处理大型数据集"""
processed_lines = []
# 逐行读取和处理
with open(file_path, 'r', encoding='utf-8') as f:
batch = []
for line in f:
line = line.strip()
if line: # 跳过空行
batch.append(line)
# 达到批次大小时处理
if len(batch) >= self.batch_size:
processed_batch = self._process_batch(batch)
processed_lines.extend(processed_batch)
batch = []
# 处理剩余行
if batch:
processed_batch = self._process_batch(batch)
processed_lines.extend(processed_batch)
# 保存结果
with open(output_path, 'w', encoding='utf-8') as f:
for line in processed_lines:
f.write(line + '\n')
print(f"处理完成!共处理 {len(processed_lines)} 行文本。")
6.2 缓存机制
对于重复的文本处理,可以使用缓存提高效率:
import hashlib
from functools import lru_cache
class CachedPreprocessor:
"""带缓存的预处理器"""
def __init__(self):
self.preprocessor = ChineseTextPreprocessor()
self.cache = {} # 简单缓存字典
self.cache_hits = 0
self.cache_misses = 0
def _get_hash(self, text: str) -> str:
"""生成文本哈希值(用于缓存键)"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def normalize_with_cache(self, text: str) -> str:
"""带缓存的标准化处理"""
# 生成缓存键
cache_key = self._get_hash(text)
# 检查缓存
if cache_key in self.cache:
self.cache_hits += 1
return self.cache[cache_key]
# 缓存未命中,进行处理
self.cache_misses += 1
processed_text = self.preprocessor.normalize_text(text)
# 存入缓存
self.cache[cache_key] = processed_text
# 如果缓存太大,清理一部分
if len(self.cache) > 10000: # 最大缓存大小
# 简单策略:清理前1000个
keys_to_remove = list(self.cache.keys())[:1000]
for key in keys_to_remove:
del self.cache[key]
return processed_text
def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
return {
'cache_size': len(self.cache),
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate': self.cache_hits / (self.cache_hits + self.cache_misses)
if (self.cache_hits + self.cache_misses) > 0 else 0
}
@lru_cache(maxsize=10000)
def normalize_lru(self, text: str) -> str:
"""使用LRU缓存的标准化处理"""
return self.preprocessor.normalize_text(text)
# 测试缓存效果
def test_cache_performance():
"""测试缓存性能"""
preprocessor = CachedPreprocessor()
# 测试文本(有些重复)
test_texts = [
"今天天气很好",
"今天天气很好", # 重复
"我喜欢编程",
"深度学习很有趣",
"今天天气很好", # 重复
"我喜欢编程", # 重复
"机器学习",
"人工智能",
"今天天气很好", # 重复
] * 100 # 重复100次
print("开始处理...")
# 处理所有文本
for text in test_texts:
_ = preprocessor.normalize_with_cache(text)
# 获取统计信息
stats = preprocessor.get_cache_stats()
print(f"处理完成!")
print(f"缓存大小: {stats['cache_size']}")
print(f"缓存命中: {stats['cache_hits']}")
print(f"缓存未命中: {stats['cache_misses']}")
print(f"命中率: {stats['hit_rate']:.2%}")
6.3 错误处理与日志
在生产环境中,良好的错误处理和日志记录很重要:
import logging
from datetime import datetime
class RobustPreprocessor:
"""健壮性预处理器(带错误处理和日志)"""
def __init__(self, log_file: str = None):
self.preprocessor = ChineseTextPreprocessor()
# 配置日志
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.INFO)
# 控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# 文件处理器(如果指定了日志文件)
if log_file:
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setLevel(logging.INFO)
self.logger.addHandler(file_handler)
# 格式化器
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
console_handler.setFormatter(formatter)
if log_file:
file_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
def safe_normalize(self, text: str, text_id: str = None) -> str:
"""
安全的标准化处理(带错误处理)
Args:
text: 要处理的文本
text_id: 文本标识(用于日志)
Returns:
处理后的文本,如果出错返回空字符串
"""
try:
start_time = datetime.now()
# 检查输入
if text is None:
self.logger.warning(f"文本为空,text_id: {text_id}")
return ""
if not isinstance(text, str):
self.logger.warning(f"文本不是字符串类型,text_id: {text_id}, 类型: {type(text)}")
text = str(text)
# 处理
result = self.preprocessor.normalize_text(text)
# 记录处理时间
process_time = (datetime.now() - start_time).total_seconds() * 1000
# 记录日志
self.logger.info(
f"文本处理完成 - "
f"text_id: {text_id}, "
f"原始长度: {len(text)}, "
f"处理后长度: {len(result)}, "
f"处理时间: {process_time:.2f}ms"
)
return result
except Exception as e:
# 记录错误
self.logger.error(
f"文本处理失败 - "
f"text_id: {text_id}, "
f"错误: {str(e)}, "
f"文本前50字符: {str(text)[:50] if text else 'None'}"
)
# 返回空字符串或原始文本(根据需求)
return ""
def batch_safe_normalize(self, texts: List[str], text_ids: List[str] = None) -> List[str]:
"""
批量安全处理
Args:
texts: 文本列表
text_ids: 文本ID列表(可选)
Returns:
处理后的文本列表
"""
if text_ids is None:
text_ids = [f"text_{i}" for i in range(len(texts))]
if len(texts) != len(text_ids):
self.logger.warning("文本列表和ID列表长度不一致")
text_ids = [f"text_{i}" for i in range(len(texts))]
results = []
success_count = 0
fail_count = 0
for text, text_id in zip(texts, text_ids):
try:
result = self.safe_normalize(text, text_id)
results.append(result)
success_count += 1
except Exception as e:
self.logger.error(f"批量处理失败 - text_id: {text_id}, 错误: {str(e)}")
results.append("") # 失败时返回空字符串
fail_count += 1
self.logger.info(
f"批量处理完成 - "
f"总数: {len(texts)}, "
f"成功: {success_count}, "
f"失败: {fail_count}, "
f"成功率: {success_count/len(texts):.2%}"
)
return results
# 测试健壮性处理器
def test_robust_preprocessor():
"""测试健壮性预处理器"""
# 创建预处理器(带日志)
preprocessor = RobustPreprocessor(log_file="preprocessor.log")
# 测试各种输入
test_cases = [
("正常文本", "今天天气很好,我们去公园玩吧!"),
("空文本", ""),
("None输入", None),
("非字符串", 12345),
("特殊字符", "Hello\nWorld\t测试\r\n"),
("超长文本", "测试" * 1000),
]
print("测试健壮性预处理器:")
print("=" * 60)
for case_name, text in test_cases:
print(f"\n测试用例: {case_name}")
print(f"输入: {str(text)[:50] if text else 'None'}")
result = preprocessor.safe_normalize(text, f"test_{case_name}")
print(f"输出: {result[:50] if result else '空字符串'}")
print(f"输出长度: {len(result)}")
7. 总结
7.1 关键要点回顾
通过本文的讲解,你应该掌握了以下几个关键点:
-
预处理的重要性:中文文本标准化预处理能显著提升StructBERT相似度模型的准确性和稳定性。
-
完整的处理流程:从文本清理、字符标准化到标点统一,每个步骤都有其特定作用。
-
与模型的无缝集成:预处理脚本可以轻松集成到现有的StructBERT相似度计算流程中。
-
实际应用场景:预处理技术在智能客服、内容去重、文档搜索等场景中都有重要应用。
-
性能优化技巧:通过批量处理、缓存机制和错误处理,可以提升预处理效率和生产环境稳定性。
7.2 最佳实践建议
根据我的实践经验,给你几个实用建议:
-
预处理程度要适中:不要过度预处理,保留对语义理解重要的信息。
-
根据任务调整:不同的相似度任务可能需要不同的预处理策略。
-
测试验证:在实际数据上测试预处理效果,确保它真的提升了模型性能。
-
监控和优化:在生产环境中监控预处理效果,根据实际情况调整参数。
-
保持更新:中文语言使用习惯在变化,预处理规则也需要定期更新。
7.3 下一步学习建议
如果你对这个主题感兴趣,可以继续深入学习:
-
探索更多预处理技术:如词干还原、词形还原(对中文有限)、命名实体识别等。
-
了解其他相似度模型:除了StructBERT,还有BERT、RoBERTa、ERNIE等模型。
-
学习向量检索技术:如Faiss、Annoy等向量数据库,用于大规模相似度搜索。
-
实践完整项目:尝试构建一个完整的文本相似度应用系统。
-
关注最新研究:自然语言处理领域发展很快,新的预处理技术不断出现。
预处理可能看起来是个小细节,但在实际应用中,它往往是决定项目成败的关键因素之一。一个好的预处理流程能让模型性能提升10%-30%,这个投入是非常值得的。
希望这个预处理脚本和教程能帮助你在StructBERT相似度计算任务上取得更好的效果!
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。
更多推荐
所有评论(0)