GLM-OCR Open-Source OCR in Practice: Exposing a RESTful API via FastAPI for Java Systems

1. Introduction: Why Wrap GLM-OCR in an API?

If you are building a Java application, such as a document management system, an invoice recognition platform, or a mobile photo-to-text tool, you will sooner or later run into a very practical question: how can a Java system call a Python AI model?

GLM-OCR is genuinely powerful: it recognizes text, tables, and formulas. Its native interface, however, is a Python-based Gradio app. For Java developers, calling Python scripts directly is not only awkward but also hard to integrate into an existing microservice architecture. Imagine your Java service needs to recognize an invoice: are you really going to shell out with Runtime.exec() to run a Python script every time? Clearly not a good idea.

That is why we wrap GLM-OCR behind a RESTful API. With FastAPI we can stand up a standard HTTP service, so a Java system can obtain OCR results through plain HTTP requests, just as it would call any other microservice. This brings several clear benefits:

  • Decoupling: the Java system does not need to care about the Python environment, model loading, or other details
  • Standardization: any language can call the service over plain HTTP
  • Extensibility: load balancing, monitoring, authentication, and more can be added easily
  • Maintainability: the API service can be deployed and upgraded independently

In this article I will walk you through the whole process step by step. Even if you have barely touched FastAPI before, you can follow along. We start with environment setup, move on to API design and implementation, and finish with a Java client example. By the end, you will be able to integrate GLM-OCR into your own projects quickly.

2. Environment Setup and Project Structure

2.1 Checking the Existing Environment

First, confirm that GLM-OCR is installed correctly and can run. In this walkthrough, GLM-OCR is assumed to be deployed under /root/GLM-OCR. Let's verify:

# Enter the project directory
cd /root/GLM-OCR

# Check the conda environment
/opt/miniconda3/envs/py310/bin/python --version
# Expected output: something like Python 3.10.x

# Check the required Python packages
/opt/miniconda3/envs/py310/bin/pip list | grep -E "(gradio|transformers)"

If the environment is healthy, you should see gradio and transformers in the list. Next, install FastAPI and its companions.

2.2 Installing FastAPI and Related Dependencies

FastAPI is a modern, fast (high-performance) web framework for building APIs. It is based on standard Python type hints and generates API documentation automatically, which makes it a great fit for this scenario.

# Activate the conda environment (if not already active)
source /opt/miniconda3/bin/activate py310

# Install FastAPI, Uvicorn (ASGI server), and Pydantic (data validation)
pip install fastapi uvicorn pydantic python-multipart pillow

# Verify the installation
python -c "import fastapi; print(f'FastAPI version: {fastapi.__version__}')"

2.3 Creating the API Project Structure

A clean project layout keeps the code readable and maintainable. I suggest organizing it as follows:

/root/GLM-OCR-api/
├── app/
│   ├── __init__.py
│   ├── main.py              # FastAPI application entry point
│   ├── api/
│   │   ├── __init__.py
│   │   └── endpoints.py     # API endpoint definitions
│   ├── core/
│   │   ├── __init__.py
│   │   ├── config.py        # Configuration
│   │   └── ocr_service.py   # OCR service wrapper
│   ├── models/
│   │   ├── __init__.py
│   │   └── schemas.py       # Data model definitions
│   └── utils/
│       ├── __init__.py
│       └── file_utils.py    # File handling utilities
├── requirements.txt          # Dependency list
├── start_api.sh             # Startup script
└── README.md                # Project documentation

Create this directory structure first:

# Create the project directories
mkdir -p /root/GLM-OCR-api/app/{api,core,models,utils}

# Create the required files
touch /root/GLM-OCR-api/app/__init__.py
touch /root/GLM-OCR-api/app/main.py
touch /root/GLM-OCR-api/app/api/__init__.py
touch /root/GLM-OCR-api/app/api/endpoints.py
touch /root/GLM-OCR-api/app/core/__init__.py
touch /root/GLM-OCR-api/app/core/config.py
touch /root/GLM-OCR-api/app/core/ocr_service.py
touch /root/GLM-OCR-api/app/models/__init__.py
touch /root/GLM-OCR-api/app/models/schemas.py
touch /root/GLM-OCR-api/app/utils/__init__.py
touch /root/GLM-OCR-api/app/utils/file_utils.py
touch /root/GLM-OCR-api/requirements.txt
touch /root/GLM-OCR-api/start_api.sh
touch /root/GLM-OCR-api/README.md

2.4 Writing requirements.txt

Add the following to /root/GLM-OCR-api/requirements.txt:

fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pydantic-settings==2.1.0
python-multipart==0.0.6
pillow==10.1.0
requests==2.31.0
gradio==4.19.2
transformers==5.0.1.dev0

With that, the base environment is ready. Next, let's design the API.

3. API Design and Data Models

3.1 Pinning Down the Requirements

Before writing any code, we need to decide which capabilities the API should expose. Based on what GLM-OCR can do, the API should support:

  1. Text recognition: extract the text in an image
  2. Table recognition: extract tables and return structured data
  3. Formula recognition: extract mathematical formulas
  4. Health check: verify that the service is running
  5. Batch processing: handle multiple images in one request (optional, as needed)

3.2 Designing the API Endpoints

Following RESTful design principles, I suggest the following endpoints:

Method  Endpoint              Purpose               Request body          Response
POST    /api/v1/ocr/text      Text recognition      image file + params   recognition result
POST    /api/v1/ocr/table     Table recognition     image file + params   table data (JSON/CSV)
POST    /api/v1/ocr/formula   Formula recognition   image file + params   LaTeX formula
GET     /api/v1/health        Health check          (none)                service status
POST    /api/v1/ocr/batch     Batch recognition     multiple images       batch results
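To make the contract concrete before any server code exists, here is what a POST /api/v1/ocr/text request body looks like on the wire. The multipart builder below is a hypothetical stdlib-only sketch (any real HTTP client, Java's HttpClient included, assembles the same structure internally); the field names image, language, and confidence_threshold match the endpoint parameters defined later.

```python
import uuid


def build_multipart(fields: dict, file_field: str, filename: str, file_bytes: bytes):
    """Build a multipart/form-data body by hand, the way HTTP clients do internally."""
    boundary = f"----glm-ocr-{uuid.uuid4().hex}"
    parts = []
    # Plain form fields (language, confidence_threshold, ...)
    for name, value in fields.items():
        parts.append(
            (f"--{boundary}\r\n"
             f'Content-Disposition: form-data; name="{name}"\r\n\r\n'
             f"{value}\r\n").encode()
        )
    # The file field itself
    parts.append(
        (f"--{boundary}\r\n"
         f'Content-Disposition: form-data; name="{file_field}"; filename="{filename}"\r\n'
         f"Content-Type: application/octet-stream\r\n\r\n").encode()
    )
    parts.append(file_bytes + b"\r\n")
    parts.append(f"--{boundary}--\r\n".encode())
    return b"".join(parts), f"multipart/form-data; boundary={boundary}"


body, content_type = build_multipart(
    {"language": "zh", "confidence_threshold": "0.5"},
    "image", "invoice.png", b"\x89PNG...fake bytes...",
)
```

Post `body` with the returned `content_type` header and you have exactly the request the /ocr/text endpoint expects.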

3.3 Defining the Data Models

With Pydantic we can define clean data models that both validate input automatically and drive the generated API documentation.

Open /root/GLM-OCR-api/app/models/schemas.py and add the following:

from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List, Dict, Any
from enum import Enum

class TaskType(str, Enum):
    """OCR task type enum"""
    TEXT = "text"
    TABLE = "table"
    FORMULA = "formula"

class OCRRequest(BaseModel):
    """Base model for OCR requests"""
    task_type: TaskType = Field(
        default=TaskType.TEXT,
        description="Recognition task type: text, table, or formula"
    )
    language: Optional[str] = Field(
        default="zh",
        description="Language code, e.g. zh (Chinese), en (English)"
    )
    confidence_threshold: Optional[float] = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Confidence threshold, between 0 and 1"
    )

class TextOCRResponse(BaseModel):
    """Text recognition response"""
    success: bool = Field(description="Whether the request succeeded")
    text: Optional[str] = Field(default=None, description="Recognized text")
    confidence: Optional[float] = Field(default=None, description="Overall confidence")
    bounding_boxes: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="List of text region bounding boxes"
    )
    processing_time: float = Field(description="Processing time in seconds")
    error_message: Optional[str] = Field(default=None, description="Error message")

class TableOCRResponse(BaseModel):
    """Table recognition response"""
    success: bool = Field(description="Whether the request succeeded")
    table_data: Optional[List[List[str]]] = Field(
        default=None,
        description="Table data as a 2D array"
    )
    csv_data: Optional[str] = Field(default=None, description="CSV-formatted data")
    html_table: Optional[str] = Field(default=None, description="HTML table")
    processing_time: float = Field(description="Processing time in seconds")
    error_message: Optional[str] = Field(default=None, description="Error message")

class FormulaOCRResponse(BaseModel):
    """Formula recognition response"""
    success: bool = Field(description="Whether the request succeeded")
    latex: Optional[str] = Field(default=None, description="LaTeX formula")
    confidence: Optional[float] = Field(default=None, description="Confidence")
    processing_time: float = Field(description="Processing time in seconds")
    error_message: Optional[str] = Field(default=None, description="Error message")

class BatchOCRRequest(BaseModel):
    """Batch OCR request"""
    tasks: List[Dict[str, Any]] = Field(
        description="Task list; each task contains an image and its parameters"
    )

class BatchOCRResponse(BaseModel):
    """Batch OCR response"""
    success: bool = Field(description="Whether the request succeeded")
    results: List[Dict[str, Any]] = Field(description="List of recognition results")
    total_time: float = Field(description="Total processing time in seconds")
    error_message: Optional[str] = Field(default=None, description="Error message")

class HealthResponse(BaseModel):
    """Health check response"""
    # "model_" is a protected field prefix in Pydantic v2; this allows model_loaded
    model_config = ConfigDict(protected_namespaces=())

    status: str = Field(description="Service status")
    version: str = Field(description="API version")
    model_loaded: bool = Field(description="Whether the model is loaded")
    uptime: float = Field(description="Service uptime in seconds")
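For a Java consumer, the most important artifact is the JSON these models serialize to. A successful TextOCRResponse, for example, renders as below (the values are illustrative; confidence is a placeholder throughout this tutorial):

```python
import json

# The JSON body a client receives from POST /api/v1/ocr/text on success.
response_json = json.dumps({
    "success": True,
    "text": "Invoice No. 12345678",
    "confidence": 0.95,
    "bounding_boxes": None,
    "processing_time": 1.42,
    "error_message": None,
}, ensure_ascii=False)

# A Java client would map this onto a POJO; here we just round-trip it.
payload = json.loads(response_json)
```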

3.4 Configuration Management

Next, create the configuration module. Open /root/GLM-OCR-api/app/core/config.py:

import os
from typing import Optional
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    """Application settings"""
    
    # API settings
    API_TITLE: str = "GLM-OCR API"
    API_VERSION: str = "1.0.0"
    API_DESCRIPTION: str = "RESTful OCR service built on GLM-OCR"
    
    # Server settings
    HOST: str = "0.0.0.0"
    PORT: int = 8000
    RELOAD: bool = False  # keep False in production
    
    # GLM-OCR settings
    GLM_OCR_HOST: str = "http://localhost:7860"
    GLM_OCR_TIMEOUT: int = 300  # timeout in seconds
    
    # File settings
    UPLOAD_DIR: str = "/tmp/glm_ocr_uploads"
    MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024  # 10 MB
    ALLOWED_EXTENSIONS: set = {".png", ".jpg", ".jpeg", ".webp"}
    
    # Logging
    LOG_LEVEL: str = "INFO"
    LOG_FILE: str = "/var/log/glm_ocr_api.log"
    
    # Performance
    MAX_WORKERS: int = 4
    BATCH_SIZE: int = 10  # maximum number of items per batch
    
    class Config:
        env_file = ".env"
        case_sensitive = True

def get_settings() -> Settings:
    """Return a settings instance"""
    return Settings()

# Create the settings instance
settings = get_settings()

# Make sure the upload directory exists
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)

Note: we use pydantic-settings to manage configuration so that values can be read from environment variables or a .env file when needed. Install it with pip install pydantic-settings.
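For example, a .env file placed in the project root could override a few defaults. The names must match the Settings fields exactly (case_sensitive = True); the values below are illustrative:

```
# .env - overrides for app/core/config.py
PORT=8000
GLM_OCR_HOST=http://localhost:7860
MAX_UPLOAD_SIZE=10485760
LOG_LEVEL=INFO
```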

With the API design and data models in place, we can implement the core OCR service wrapper.

4. Wrapping the GLM-OCR Service

4.1 Creating the OCR Service Class

GLM-OCR is natively called through the Gradio client. We wrap that call path so it is easier to consume from the API.

Open /root/GLM-OCR-api/app/core/ocr_service.py:

import os
import time
import logging
from typing import Optional, Dict, Any, List, Tuple
from pathlib import Path
import tempfile

from gradio_client import Client

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class GLMOCRService:
    """Wrapper around the GLM-OCR Gradio service"""
    
    def __init__(self, host: str = "http://localhost:7860", timeout: int = 300):
        """
        Initialize the OCR service.
        
        Args:
            host: GLM-OCR service address
            timeout: timeout in seconds
        """
        self.host = host
        self.timeout = timeout
        self.client = None
        self._init_time = time.time()
        
    def connect(self) -> bool:
        """Connect to the GLM-OCR service"""
        try:
            logger.info(f"Connecting to GLM-OCR service: {self.host}")
            # Note: gradio_client.Client does not accept a timeout argument in
            # every version; self.timeout is kept for request-level use.
            self.client = Client(self.host)
            logger.info("Connected to GLM-OCR service")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to GLM-OCR service: {e}")
            return False
    
    def disconnect(self):
        """Drop the connection"""
        self.client = None
        logger.info("GLM-OCR service connection closed")
    
    def is_connected(self) -> bool:
        """Return whether a client is attached"""
        return self.client is not None
    
    def get_uptime(self) -> float:
        """Return service uptime in seconds"""
        return time.time() - self._init_time
    
    def _prepare_image(self, image_path: str) -> str:
        """
        Validate the image file.
        
        Args:
            image_path: path to the image
            
        Returns:
            the validated image path
        """
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file does not exist: {image_path}")
        
        # Check the file extension
        ext = Path(image_path).suffix.lower()
        if ext not in ['.png', '.jpg', '.jpeg', '.webp']:
            raise ValueError(f"Unsupported image format: {ext}")
        
        return image_path
    
    def _get_prompt_by_task(self, task_type: str) -> str:
        """
        Return the prompt for a given task type.
        
        Args:
            task_type: task type (text/table/formula)
            
        Returns:
            the prompt string
        """
        prompts = {
            "text": "Text Recognition:",
            "table": "Table Recognition:",
            "formula": "Formula Recognition:"
        }
        
        if task_type not in prompts:
            raise ValueError(f"Unsupported task type: {task_type}")
        
        return prompts[task_type]
    
    def recognize_text(
        self, 
        image_path: str, 
        language: str = "zh",
        confidence_threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Text recognition.
        
        Args:
            image_path: path to the image
            language: language code
            confidence_threshold: confidence threshold
            
        Returns:
            result dictionary
        """
        start_time = time.time()
        
        try:
            # Validate the image
            image_path = self._prepare_image(image_path)
            
            # Build the prompt
            prompt = self._get_prompt_by_task("text")
            
            # Call GLM-OCR
            logger.info(f"Starting text recognition: {image_path}")
            result = self.client.predict(
                image_path=image_path,
                prompt=prompt,
                api_name="/predict"
            )
            
            # Process the result
            processing_time = time.time() - start_time
            
            # Adjust this parsing to match what GLM-OCR actually returns;
            # the format may differ between deployments.
            if result and isinstance(result, str):
                # Assume plain text was returned
                return {
                    "success": True,
                    "text": result.strip(),
                    "confidence": 0.95,  # placeholder; extract from the real result if available
                    "processing_time": processing_time,
                    "language": language
                }
            else:
                return {
                    "success": False,
                    "error_message": "Recognition result is empty or malformed",
                    "processing_time": processing_time
                }
                
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Text recognition failed: {e}")
            return {
                "success": False,
                "error_message": str(e),
                "processing_time": processing_time
            }
    
    def recognize_table(
        self, 
        image_path: str,
        output_format: str = "json"
    ) -> Dict[str, Any]:
        """
        Table recognition.
        
        Args:
            image_path: path to the image
            output_format: output format (json/csv/html)
            
        Returns:
            table recognition result
        """
        start_time = time.time()
        
        try:
            # Validate the image
            image_path = self._prepare_image(image_path)
            
            # Build the prompt
            prompt = self._get_prompt_by_task("table")
            
            # Call GLM-OCR
            logger.info(f"Starting table recognition: {image_path}")
            result = self.client.predict(
                image_path=image_path,
                prompt=prompt,
                api_name="/predict"
            )
            
            processing_time = time.time() - start_time
            
            # Parse the table result; adjust the parsing to match the actual
            # GLM-OCR output format. The logic below is a sample implementation.
            table_data = self._parse_table_result(result)
            
            # Convert to the requested output format
            output_data = None
            if output_format == "csv" and table_data:
                output_data = self._table_to_csv(table_data)
            elif output_format == "html" and table_data:
                output_data = self._table_to_html(table_data)
            
            return {
                "success": True,
                "table_data": table_data,
                "csv_data": output_data if output_format == "csv" else None,
                "html_table": output_data if output_format == "html" else None,
                "processing_time": processing_time
            }
                
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Table recognition failed: {e}")
            return {
                "success": False,
                "error_message": str(e),
                "processing_time": processing_time
            }
    
    def recognize_formula(
        self, 
        image_path: str
    ) -> Dict[str, Any]:
        """
        Formula recognition.
        
        Args:
            image_path: path to the image
            
        Returns:
            formula recognition result
        """
        start_time = time.time()
        
        try:
            # Validate the image
            image_path = self._prepare_image(image_path)
            
            # Build the prompt
            prompt = self._get_prompt_by_task("formula")
            
            # Call GLM-OCR
            logger.info(f"Starting formula recognition: {image_path}")
            result = self.client.predict(
                image_path=image_path,
                prompt=prompt,
                api_name="/predict"
            )
            
            processing_time = time.time() - start_time
            
            # Parse the formula result; LaTeX output is assumed
            if result and isinstance(result, str):
                latex_formula = result.strip()
                return {
                    "success": True,
                    "latex": latex_formula,
                    "confidence": 0.9,  # placeholder
                    "processing_time": processing_time
                }
            else:
                return {
                    "success": False,
                    "error_message": "Formula recognition result is empty",
                    "processing_time": processing_time
                }
                
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Formula recognition failed: {e}")
            return {
                "success": False,
                "error_message": str(e),
                "processing_time": processing_time
            }
    
    def batch_recognize(
        self, 
        tasks: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Batch recognition.
        
        Args:
            tasks: task list; each task contains image_path and task_type
            
        Returns:
            batch recognition result
        """
        start_time = time.time()
        results = []
        
        for i, task in enumerate(tasks):
            try:
                image_path = task.get("image_path")
                task_type = task.get("task_type", "text")
                params = task.get("params", {})
                
                logger.info(f"Processing batch task {i+1}/{len(tasks)}: {task_type} - {image_path}")
                
                # Dispatch by task type
                if task_type == "text":
                    result = self.recognize_text(image_path, **params)
                elif task_type == "table":
                    result = self.recognize_table(image_path, **params)
                elif task_type == "formula":
                    result = self.recognize_formula(image_path)
                else:
                    result = {
                        "success": False,
                        "error_message": f"Unsupported task type: {task_type}"
                    }
                
                results.append({
                    "task_id": i,
                    "image_path": image_path,
                    "task_type": task_type,
                    "result": result
                })
                
            except Exception as e:
                results.append({
                    "task_id": i,
                    "image_path": task.get("image_path", "unknown"),
                    "task_type": task.get("task_type", "unknown"),
                    "result": {
                        "success": False,
                        "error_message": str(e)
                    }
                })
        
        total_time = time.time() - start_time
        
        return {
            "success": True,
            "results": results,
            "total_time": total_time,
            "processed_count": len(tasks),
            "success_count": sum(1 for r in results if r["result"].get("success"))
        }
    
    def _parse_table_result(self, result: Any) -> List[List[str]]:
        """
        Parse the table recognition result.
        
        Args:
            result: raw result returned by GLM-OCR
            
        Returns:
            the table as a 2D array
        """
        # Adjust this parsing to the actual GLM-OCR output format;
        # the logic below is a sample implementation.
        if not result:
            return []
        
        # Assume the result is a table in string form
        if isinstance(result, str):
            # Split into rows
            lines = result.strip().split('\n')
            table_data = []
            
            for line in lines:
                # Assume cells are separated by | or tab characters
                if '|' in line:
                    cells = [cell.strip() for cell in line.split('|') if cell.strip()]
                elif '\t' in line:
                    cells = [cell.strip() for cell in line.split('\t') if cell.strip()]
                else:
                    cells = [line.strip()]
                
                if cells:
                    table_data.append(cells)
            
            return table_data
        
        return []
    
    def _table_to_csv(self, table_data: List[List[str]]) -> str:
        """Convert table data to CSV"""
        import csv
        from io import StringIO
        
        output = StringIO()
        writer = csv.writer(output)
        writer.writerows(table_data)
        
        return output.getvalue()
    
    def _table_to_html(self, table_data: List[List[str]]) -> str:
        """Convert table data to an HTML table"""
        from html import escape
        
        if not table_data:
            return ""
        
        html = "<table border='1' style='border-collapse: collapse;'>\n"
        
        for row in table_data:
            html += "  <tr>\n"
            for cell in row:
                # Escape cell content so stray < > & characters cannot break the markup
                html += f"    <td>{escape(cell)}</td>\n"
            html += "  </tr>\n"
        
        html += "</table>"
        return html

# Global service instance
_ocr_service = None

def get_ocr_service() -> GLMOCRService:
    """Return the OCR service instance (singleton)"""
    global _ocr_service
    if _ocr_service is None:
        from .config import settings
        _ocr_service = GLMOCRService(
            host=settings.GLM_OCR_HOST,
            timeout=settings.GLM_OCR_TIMEOUT
        )
        if not _ocr_service.connect():
            raise RuntimeError("Unable to connect to the GLM-OCR service")
    return _ocr_service

This service class wraps all of GLM-OCR's capabilities behind a uniform interface. Next, we build the API endpoints.
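The pipe/tab splitting inside _parse_table_result is easy to sanity-check in isolation. The snippet below reproduces that parsing logic (plus the CSV conversion) as standalone functions so you can experiment without a running service; parse_table_text and table_to_csv are names introduced here just for the demo.

```python
import csv
from io import StringIO


def parse_table_text(result: str):
    """Same line/pipe splitting logic as GLMOCRService._parse_table_result."""
    table = []
    for line in result.strip().split("\n"):
        if "|" in line:
            cells = [c.strip() for c in line.split("|") if c.strip()]
        elif "\t" in line:
            cells = [c.strip() for c in line.split("\t") if c.strip()]
        else:
            cells = [line.strip()]
        if cells:
            table.append(cells)
    return table


def table_to_csv(table):
    """Same conversion as GLMOCRService._table_to_csv."""
    out = StringIO()
    csv.writer(out).writerows(table)
    return out.getvalue()


raw = "Item | Qty | Price\nWidget | 2 | 9.99"
table = parse_table_text(raw)
```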

5. Implementing the FastAPI Endpoints

5.1 Creating the API Endpoints

Open /root/GLM-OCR-api/app/api/endpoints.py:

import os
import time
import uuid
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse, FileResponse

from app.core.ocr_service import get_ocr_service
from app.models.schemas import (
    OCRRequest, TextOCRResponse, TableOCRResponse, 
    FormulaOCRResponse, BatchOCRRequest, BatchOCRResponse,
    HealthResponse, TaskType
)
from app.utils.file_utils import save_upload_file, cleanup_file

router = APIRouter()

@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    try:
        ocr_service = get_ocr_service()
        
        return HealthResponse(
            status="healthy",
            version="1.0.0",
            model_loaded=ocr_service.is_connected(),
            uptime=ocr_service.get_uptime()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Service error: {str(e)}")

@router.post("/ocr/text", response_model=TextOCRResponse)
async def recognize_text(
    background_tasks: BackgroundTasks,
    image: UploadFile = File(..., description="Uploaded image file"),
    language: str = Form("zh", description="Language code"),
    confidence_threshold: float = Form(0.5, description="Confidence threshold")
):
    """
    Text recognition endpoint.
    
    - **image**: uploaded image file (PNG/JPG/WEBP)
    - **language**: language code, default zh (Chinese)
    - **confidence_threshold**: confidence threshold between 0 and 1
    """
    # Validate the file type
    if not image.filename:
        raise HTTPException(status_code=400, detail="No filename provided")
    
    file_ext = os.path.splitext(image.filename)[1].lower()
    if file_ext not in ['.png', '.jpg', '.jpeg', '.webp']:
        raise HTTPException(status_code=400, detail="Unsupported file format")
    
    try:
        # Save the uploaded file
        temp_file_path = save_upload_file(image)
        
        # Schedule cleanup of the temporary file
        background_tasks.add_task(cleanup_file, temp_file_path)
        
        # Call the OCR service
        ocr_service = get_ocr_service()
        result = ocr_service.recognize_text(
            image_path=temp_file_path,
            language=language,
            confidence_threshold=confidence_threshold
        )
        
        return TextOCRResponse(**result)
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Text recognition failed: {str(e)}")

@router.post("/ocr/table", response_model=TableOCRResponse)
async def recognize_table(
    background_tasks: BackgroundTasks,
    image: UploadFile = File(..., description="Uploaded image file"),
    output_format: str = Form("json", description="Output format: json/csv/html")
):
    """
    Table recognition endpoint.
    
    - **image**: uploaded image file (PNG/JPG/WEBP)
    - **output_format**: output format, one of json/csv/html
    """
    # Validate the file type
    if not image.filename:
        raise HTTPException(status_code=400, detail="No filename provided")
    
    file_ext = os.path.splitext(image.filename)[1].lower()
    if file_ext not in ['.png', '.jpg', '.jpeg', '.webp']:
        raise HTTPException(status_code=400, detail="Unsupported file format")
    
    # Validate the output format
    if output_format not in ["json", "csv", "html"]:
        raise HTTPException(status_code=400, detail="Unsupported output format")
    
    try:
        # Save the uploaded file
        temp_file_path = save_upload_file(image)
        
        # Schedule cleanup of the temporary file
        background_tasks.add_task(cleanup_file, temp_file_path)
        
        # Call the OCR service
        ocr_service = get_ocr_service()
        result = ocr_service.recognize_table(
            image_path=temp_file_path,
            output_format=output_format
        )
        
        return TableOCRResponse(**result)
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table recognition failed: {str(e)}")

@router.post("/ocr/formula", response_model=FormulaOCRResponse)
async def recognize_formula(
    background_tasks: BackgroundTasks,
    image: UploadFile = File(..., description="Uploaded image file")
):
    """
    Formula recognition endpoint.
    
    - **image**: uploaded image file (PNG/JPG/WEBP)
    """
    # Validate the file type
    if not image.filename:
        raise HTTPException(status_code=400, detail="No filename provided")
    
    file_ext = os.path.splitext(image.filename)[1].lower()
    if file_ext not in ['.png', '.jpg', '.jpeg', '.webp']:
        raise HTTPException(status_code=400, detail="Unsupported file format")
    
    try:
        # Save the uploaded file
        temp_file_path = save_upload_file(image)
        
        # Schedule cleanup of the temporary file
        background_tasks.add_task(cleanup_file, temp_file_path)
        
        # Call the OCR service
        ocr_service = get_ocr_service()
        result = ocr_service.recognize_formula(
            image_path=temp_file_path
        )
        
        return FormulaOCRResponse(**result)
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Formula recognition failed: {str(e)}")

@router.post("/ocr/batch", response_model=BatchOCRResponse)
async def batch_recognize(
    request: BatchOCRRequest
):
    """
    Batch recognition endpoint.
    
    - **tasks**: task list; each task contains:
        - image_path: image path (local path or URL)
        - task_type: task type (text/table/formula)
        - params: extra parameters
    """
    try:
        # Cap the number of tasks per batch
        if len(request.tasks) > 100:
            raise HTTPException(status_code=400, detail="A batch may contain at most 100 tasks")
        
        # Call the OCR service
        ocr_service = get_ocr_service()
        result = ocr_service.batch_recognize(request.tasks)
        
        return BatchOCRResponse(**result)
        
    except HTTPException:
        # Let the 400 above pass through instead of turning it into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch recognition failed: {str(e)}")

@router.post("/ocr/url")
async def recognize_from_url(
    image_url: str = Form(..., description="Image URL"),
    task_type: TaskType = Form(TaskType.TEXT, description="Task type"),
    language: str = Form("zh", description="Language code"),
    output_format: str = Form("json", description="Output format (table recognition only)")
):
    """
    Recognize an image fetched from a URL.
    
    - **image_url**: URL of the image
    - **task_type**: task type
    - **language**: language code
    - **output_format**: output format
    """
    temp_file_path = None
    try:
        import requests  # third-party package, listed in requirements.txt
        
        # Download the image
        response = requests.get(image_url, timeout=30)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Unable to download the image")
        
        # Generate a temporary file name
        # (the .jpg extension is assumed; inspect the Content-Type header for robustness)
        temp_file_path = f"/tmp/glm_ocr_{uuid.uuid4().hex}.jpg"
        
        # Save the image
        with open(temp_file_path, 'wb') as f:
            f.write(response.content)
        
        # Call the OCR service
        ocr_service = get_ocr_service()
        
        if task_type == TaskType.TEXT:
            result = ocr_service.recognize_text(
                image_path=temp_file_path,
                language=language
            )
            return TextOCRResponse(**result)
            
        elif task_type == TaskType.TABLE:
            result = ocr_service.recognize_table(
                image_path=temp_file_path,
                output_format=output_format
            )
            return TableOCRResponse(**result)
            
        elif task_type == TaskType.FORMULA:
            result = ocr_service.recognize_formula(
                image_path=temp_file_path
            )
            return FormulaOCRResponse(**result)
        
        else:
            raise HTTPException(status_code=400, detail="Unsupported task type")
            
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400s above) pass through unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"URL recognition failed: {str(e)}")
    finally:
        # Clean up the temporary file
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)

@router.get("/ocr/supported_formats")
async def get_supported_formats():
    """Return the supported image formats"""
    return {
        "supported_formats": [".png", ".jpg", ".jpeg", ".webp"],
        "max_file_size": "10MB"
    }

5.2 Creating the File Utilities

Open /root/GLM-OCR-api/app/utils/file_utils.py:

import os
import uuid
from fastapi import UploadFile
from pathlib import Path

from app.core.config import settings

def save_upload_file(upload_file: UploadFile) -> str:
    """
    Save an uploaded file to the temporary directory.
    
    Args:
        upload_file: the uploaded file object
        
    Returns:
        path of the saved file
    """
    # Generate a unique file name
    file_ext = Path(upload_file.filename).suffix if upload_file.filename else ".jpg"
    filename = f"{uuid.uuid4().hex}{file_ext}"
    file_path = os.path.join(settings.UPLOAD_DIR, filename)
    
    # Make sure the directory exists
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
    
    # Check the size before writing, then save
    content = upload_file.file.read()
    if len(content) > settings.MAX_UPLOAD_SIZE:
        raise ValueError(f"File size exceeds the limit: {len(content)} > {settings.MAX_UPLOAD_SIZE}")
    with open(file_path, "wb") as buffer:
        buffer.write(content)
    
    return file_path

def cleanup_file(file_path: str):
    """
    Remove a temporary file.
    
    Args:
        file_path: path of the file
    """
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
    except Exception as e:
        # Log the error but do not interrupt the flow
        import logging
        logger = logging.getLogger(__name__)
        logger.warning(f"Failed to clean up file {file_path}: {e}")

def validate_image_file(file_path: str) -> bool:
    """
    Validate an image file.
    
    Args:
        file_path: path of the file
        
    Returns:
        whether the file is valid
    """
    if not os.path.exists(file_path):
        return False
    
    # Check the file extension
    ext = Path(file_path).suffix.lower()
    if ext not in settings.ALLOWED_EXTENSIONS:
        return False
    
    # Check the file size
    file_size = os.path.getsize(file_path)
    if file_size > settings.MAX_UPLOAD_SIZE:
        return False
    
    return True
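To see the naming and size rules from save_upload_file in isolation, here is a standalone sketch; make_upload_name and check_size are illustrative helpers introduced for this demo that mirror the logic above (including the .jpg fallback when no extension is present).

```python
import uuid
from pathlib import Path

MAX_UPLOAD_SIZE = 10 * 1024 * 1024  # mirrors settings.MAX_UPLOAD_SIZE


def make_upload_name(original_name: str) -> str:
    """Unique name with the original extension preserved, as in save_upload_file."""
    ext = Path(original_name).suffix or ".jpg"  # fallback when there is no extension
    return f"{uuid.uuid4().hex}{ext}"


def check_size(content: bytes) -> None:
    """Reject payloads larger than the configured limit."""
    if len(content) > MAX_UPLOAD_SIZE:
        raise ValueError(f"File too large: {len(content)} > {MAX_UPLOAD_SIZE}")


name = make_upload_name("invoice.PNG")
```

Note that the extension is kept verbatim, so "invoice.PNG" yields an uppercase ".PNG" suffix; the endpoints lowercase extensions separately when validating.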

5.3 Creating the Main Application

Open /root/GLM-OCR-api/app/main.py:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.openapi.utils import get_openapi

from app.core.config import settings
from app.api.endpoints import router as api_router

# Create the FastAPI application
app = FastAPI(
    title=settings.API_TITLE,
    version=settings.API_VERSION,
    description=settings.API_DESCRIPTION,
    docs_url=None,  # we serve a custom /docs below
    redoc_url="/redoc"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # restrict origins in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Middleware: log every request
@app.middleware("http")
async def log_requests(request, call_next):
    import time
    import logging
    
    logger = logging.getLogger("uvicorn.access")
    start_time = time.time()
    
    response = await call_next(request)
    
    process_time = time.time() - start_time
    logger.info(
        f"{request.method} {request.url.path} "
        f"completed in {process_time:.3f}s "
        f"status={response.status_code}"
    )
    
    return response

# Customize the OpenAPI schema
def custom_openapi():
    if app.openapi_schema:
        return app.openapi_schema
    
    openapi_schema = get_openapi(
        title=settings.API_TITLE,
        version=settings.API_VERSION,
        description=settings.API_DESCRIPTION,
        routes=app.routes,
    )
    
    # Add server information
    openapi_schema["servers"] = [
        {
            "url": "http://localhost:8000",
            "description": "Local development server"
        },
        {
            "url": "https://your-production-domain.com",
            "description": "Production server"
        }
    ]
    
    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi

# Custom Swagger UI
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=f"{settings.API_TITLE} - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
        swagger_js_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js",
        swagger_css_url="https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css",
    )

# Register the routes
app.include_router(api_router, prefix="/api/v1")

# Root path
@app.get("/")
async def root():
    return {
        "message": "GLM-OCR API service",
        "version": settings.API_VERSION,
        "docs": "/docs",
        "health": "/api/v1/health"
    }

# Startup event
@app.on_event("startup")
async def startup_event():
    import logging
    logger = logging.getLogger(__name__)
    logger.info("GLM-OCR API service starting...")
    
    # Initialize the OCR service
    try:
        from app.core.ocr_service import get_ocr_service
        ocr_service = get_ocr_service()
        logger.info("GLM-OCR service initialized")
    except Exception as e:
        logger.error(f"GLM-OCR service initialization failed: {e}")
        raise

@app.on_event("shutdown")
async def shutdown_event():
    import logging
    logger = logging.getLogger(__name__)
    logger.info("GLM-OCR API service shutting down...")
    
    # Release resources
    from app.core.ocr_service import _ocr_service
    if _ocr_service:
        _ocr_service.disconnect()
        logger.info("GLM-OCR service connection closed")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.HOST,
        port=settings.PORT,
        reload=settings.RELOAD,       # 开发环境热重载
        workers=settings.MAX_WORKERS  # 注意:uvicorn中reload与多workers互斥,reload开启时workers会被忽略
    )

5.4 创建启动脚本

创建/root/GLM-OCR-api/start_api.sh,写入以下内容:

#!/bin/bash

# GLM-OCR API启动脚本

# 设置环境变量
export PYTHONPATH=/root/GLM-OCR-api:$PYTHONPATH

# 激活conda环境
source /opt/miniconda3/bin/activate py310

# 检查GLM-OCR服务是否运行
GLM_OCR_PID=$(lsof -ti:7860)
if [ -z "$GLM_OCR_PID" ]; then
    echo "GLM-OCR服务未运行,正在启动..."
    cd /root/GLM-OCR
    nohup ./start_vllm.sh > glm_ocr.log 2>&1 &  # 日志写入文件,便于启动失败时排查
    echo "等待GLM-OCR服务启动..."
    sleep 30
else
    echo "GLM-OCR服务已在运行,PID: $GLM_OCR_PID"
fi

# 启动FastAPI服务
echo "启动GLM-OCR API服务..."
cd /root/GLM-OCR-api

# 使用uvicorn启动
uvicorn app.main:app \
    --host 0.0.0.0 \
    --port 8000 \
    --workers 4 \
    --log-level info \
    --access-log

给脚本添加执行权限:

chmod +x /root/GLM-OCR-api/start_api.sh

现在,我们的FastAPI服务已经完成了。接下来,我们创建Java客户端示例。

6. Java客户端调用示例

6.1 使用HttpClient调用(Java 11+)

如果你使用的是Java 11或更高版本,可以使用内置的HttpClient。这里是一个完整的示例:

import java.io.File;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse.BodyHandlers;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Map;
import java.util.HashMap;
import com.fasterxml.jackson.databind.ObjectMapper;

public class GLMOCRClient {
    
    private static final String API_BASE_URL = "http://localhost:8000/api/v1";
    private static final HttpClient httpClient = HttpClient.newBuilder()
            .connectTimeout(Duration.ofSeconds(30))
            .build();
    private static final ObjectMapper objectMapper = new ObjectMapper();
    
    /**
     * 文本识别
     */
    public static String recognizeText(File imageFile, String language) throws Exception {
        // 构建multipart/form-data请求
        Map<Object, Object> data = new HashMap<>();
        data.put("image", imageFile.toPath());
        data.put("language", language);
        data.put("confidence_threshold", "0.5");
        
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(API_BASE_URL + "/ocr/text"))
                .header("Content-Type", "multipart/form-data; boundary=boundary")
                .POST(ofMimeMultipartData(data, "boundary"))
                .timeout(Duration.ofSeconds(60))
                .build();
        
        HttpResponse<String> response = httpClient.send(request, BodyHandlers.ofString());
        
        if (response.statusCode() == 200) {
            // 解析响应
            Map<String, Object> result = objectMapper.readValue(response.body(), Map.class);
            if (Boolean.TRUE.equals(result.get("success"))) {
                return (String) result.get("text");
            } else {
                throw new RuntimeException("识别失败: " + result.get("error_message"));
            }
        } else {
            throw new RuntimeException("HTTP错误: " + response.statusCode());
        }
    }
    
    /**
     * 表格识别
     */
    public static String recognizeTable(File imageFile, String outputFormat) throws Exception {
        Map<Object, Object> data = new HashMap<>();
        data.put("image", imageFile.toPath());
        data.put("output_format", outputFormat);
        
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(API_BASE_URL + "/ocr/table"))
                .header("Content-Type", "multipart/form-data; boundary=boundary")
                .POST(ofMimeMultipartData(data, "boundary"))
                .timeout(Duration.ofSeconds(60))
                .build();
        
        HttpResponse<String> response = httpClient.send(request, BodyHandlers.ofString());
        
        if (response.statusCode() == 200) {
            Map<String, Object> result = objectMapper.readValue(response.body(), Map.class);
            if (Boolean.TRUE.equals(result.get("success"))) {
                if ("csv".equals(outputFormat)) {
                    return (String) result.get("csv_data");
                } else if ("html".equals(outputFormat)) {
                    return (String) result.get("html_table");
                } else {
                    return objectMapper.writeValueAsString(result.get("table_data"));
                }
            } else {
                throw new RuntimeException("识别失败: " + result.get("error_message"));
            }
        } else {
            throw new RuntimeException("HTTP错误: " + response.statusCode());
        }
    }
    
    /**
     * 公式识别
     */
    public static String recognizeFormula(File imageFile) throws Exception {
        Map<Object, Object> data = new HashMap<>();
        data.put("image", imageFile.toPath());
        
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(API_BASE_URL + "/ocr/formula"))
                .header("Content-Type", "multipart/form-data; boundary=boundary")
                .POST(ofMimeMultipartData(data, "boundary"))
                .timeout(Duration.ofSeconds(60))
                .build();
        
        HttpResponse<String> response = httpClient.send(request, BodyHandlers.ofString());
        
        if (response.statusCode() == 200) {
            Map<String, Object> result = objectMapper.readValue(response.body(), Map.class);
            if (Boolean.TRUE.equals(result.get("success"))) {
                return (String) result.get("latex");
            } else {
                throw new RuntimeException("识别失败: " + result.get("error_message"));
            }
        } else {
            throw new RuntimeException("HTTP错误: " + response.statusCode());
        }
    }
    
    /**
     * 健康检查
     */
    public static boolean healthCheck() throws Exception {
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(API_BASE_URL + "/health"))
                .GET()
                .timeout(Duration.ofSeconds(10))
                .build();
        
        HttpResponse<String> response = httpClient.send(request, BodyHandlers.ofString());
        
        if (response.statusCode() == 200) {
            Map<String, Object> result = objectMapper.readValue(response.body(), Map.class);
            return "healthy".equals(result.get("status"));
        }
        return false;
    }
    
    /**
     * 构建multipart/form-data请求体
     * 一个可直接使用的最小实现;生产环境也可改用Apache HttpClient等成熟库
     */
    private static HttpRequest.BodyPublisher ofMimeMultipartData(Map<Object, Object> data, String boundary)
            throws java.io.IOException {
        var byteArrays = new java.util.ArrayList<byte[]>();
        var utf8 = java.nio.charset.StandardCharsets.UTF_8;
        
        for (Map.Entry<Object, Object> entry : data.entrySet()) {
            if (entry.getValue() instanceof Path) {
                Path path = (Path) entry.getValue();
                byteArrays.add(("--" + boundary + "\r\n"
                        + "Content-Disposition: form-data; name=\"" + entry.getKey()
                        + "\"; filename=\"" + path.getFileName() + "\"\r\n"
                        + "Content-Type: application/octet-stream\r\n\r\n").getBytes(utf8));
                // 文件part:读取文件内容作为二进制数据
                byteArrays.add(java.nio.file.Files.readAllBytes(path));
                byteArrays.add("\r\n".getBytes(utf8));
            } else {
                // 普通文本字段
                byteArrays.add(("--" + boundary + "\r\n"
                        + "Content-Disposition: form-data; name=\"" + entry.getKey() + "\"\r\n\r\n"
                        + entry.getValue() + "\r\n").getBytes(utf8));
            }
        }
        byteArrays.add(("--" + boundary + "--\r\n").getBytes(utf8));
        
        return BodyPublishers.ofByteArrays(byteArrays);
    }
    
    public static void main(String[] args) {
        try {
            // 健康检查
            if (!healthCheck()) {
                System.err.println("OCR服务不可用");
                return;
            }
            
            System.out.println("OCR服务状态正常");
            
            // 示例:文本识别
            File imageFile = new File("test.png");
            if (imageFile.exists()) {
                String text = recognizeText(imageFile, "zh");
                System.out.println("识别结果: " + text);
            }
            
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

6.2 使用Spring Boot调用

如果你使用Spring Boot,可以创建一个更完整的服务。首先添加依赖:

<!-- pom.xml -->
<dependencies>
    <!-- Spring Boot Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    
    <!-- HttpClient(注意:MultipartEntityBuilder位于独立的httpmime模块) -->
    <dependency>
        <groupId>org.apache.httpcomponents</groupId>
        <artifactId>httpclient</artifactId>
        <version>4.5.13</version>
    </dependency>
    <dependency>
        <groupId>org.apache.httpcomponents</groupId>
        <artifactId>httpmime</artifactId>
        <version>4.5.13</version>
    </dependency>
    
    <!-- Jackson -->
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
    </dependency>
    
    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>

创建配置类:

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.client.config.RequestConfig;

@Configuration
public class HttpClientConfig {
    
    @Bean
    public CloseableHttpClient httpClient() {
        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout(30000)
                .setSocketTimeout(60000)
                .build();
        
        return HttpClients.custom()
                .setDefaultRequestConfig(requestConfig)
                .build();
    }
}

创建OCR服务类:

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.mime.MultipartEntityBuilder;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.File;
import java.io.IOException;
import java.util.Map;

@Slf4j
@Service
public class GLMOCRService {
    
    @Value("${glm.ocr.api.base-url:http://localhost:8000/api/v1}")
    private String apiBaseUrl;
    
    @Autowired
    private CloseableHttpClient httpClient;
    
    @Autowired
    private ObjectMapper objectMapper;
    
    /**
     * 文本识别
     */
    public OCRResult recognizeText(File imageFile, String language) throws IOException {
        String url = apiBaseUrl + "/ocr/text";
        
        HttpPost httpPost = new HttpPost(url);
        
        // 构建multipart请求
        MultipartEntityBuilder builder = MultipartEntityBuilder.create();
        builder.addBinaryBody("image", imageFile, 
                ContentType.APPLICATION_OCTET_STREAM, 
                imageFile.getName());
        builder.addTextBody("language", language);
        builder.addTextBody("confidence_threshold", "0.5");
        
        httpPost.setEntity(builder.build());
        
        try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
            String responseBody = EntityUtils.toString(response.getEntity());
            
            if (response.getStatusLine().getStatusCode() == 200) {
                Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
                
                OCRResult ocrResult = new OCRResult();
                ocrResult.setSuccess((Boolean) result.get("success"));
                ocrResult.setText((String) result.get("text"));
                ocrResult.setConfidence((Double) result.get("confidence"));
                ocrResult.setProcessingTime((Double) result.get("processing_time"));
                ocrResult.setErrorMessage((String) result.get("error_message"));
                
                return ocrResult;
            } else {
                throw new IOException("HTTP错误: " + response.getStatusLine().getStatusCode());
            }
        }
    }
    
    /**
     * 表格识别
     */
    public TableOCRResult recognizeTable(File imageFile, String outputFormat) throws IOException {
        String url = apiBaseUrl + "/ocr/table";
        
        HttpPost httpPost = new HttpPost(url);
        
        MultipartEntityBuilder builder = MultipartEntityBuilder.create();
        builder.addBinaryBody("image", imageFile, 
                ContentType.APPLICATION_OCTET_STREAM, 
                imageFile.getName());
        builder.addTextBody("output_format", outputFormat);
        
        httpPost.setEntity(builder.build());
        
        try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
            String responseBody = EntityUtils.toString(response.getEntity());
            
            if (response.getStatusLine().getStatusCode() == 200) {
                Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
                
                TableOCRResult tableResult = new TableOCRResult();
                tableResult.setSuccess((Boolean) result.get("success"));
                tableResult.setTableData((java.util.List) result.get("table_data"));
                tableResult.setCsvData((String) result.get("csv_data"));
                tableResult.setHtmlTable((String) result.get("html_table"));
                tableResult.setProcessingTime((Double) result.get("processing_time"));
                tableResult.setErrorMessage((String) result.get("error_message"));
                
                return tableResult;
            } else {
                throw new IOException("HTTP错误: " + response.getStatusLine().getStatusCode());
            }
        }
    }
    
    /**
     * 公式识别
     */
    public FormulaOCRResult recognizeFormula(File imageFile) throws IOException {
        String url = apiBaseUrl + "/ocr/formula";
        
        HttpPost httpPost = new HttpPost(url);
        
        MultipartEntityBuilder builder = MultipartEntityBuilder.create();
        builder.addBinaryBody("image", imageFile, 
                ContentType.APPLICATION_OCTET_STREAM, 
                imageFile.getName());
        
        httpPost.setEntity(builder.build());
        
        try (CloseableHttpResponse response = httpClient.execute(httpPost)) {
            String responseBody = EntityUtils.toString(response.getEntity());
            
            if (response.getStatusLine().getStatusCode() == 200) {
                Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
                
                FormulaOCRResult formulaResult = new FormulaOCRResult();
                formulaResult.setSuccess((Boolean) result.get("success"));
                formulaResult.setLatex((String) result.get("latex"));
                formulaResult.setConfidence((Double) result.get("confidence"));
                formulaResult.setProcessingTime((Double) result.get("processing_time"));
                formulaResult.setErrorMessage((String) result.get("error_message"));
                
                return formulaResult;
            } else {
                throw new IOException("HTTP错误: " + response.getStatusLine().getStatusCode());
            }
        }
    }
    
    /**
     * 健康检查
     */
    public boolean healthCheck() throws IOException {
        String url = apiBaseUrl + "/health";
        
        HttpGet httpGet = new HttpGet(url);
        
        try (CloseableHttpResponse response = httpClient.execute(httpGet)) {
            if (response.getStatusLine().getStatusCode() == 200) {
                String responseBody = EntityUtils.toString(response.getEntity());
                Map<String, Object> result = objectMapper.readValue(responseBody, Map.class);
                return "healthy".equals(result.get("status"));
            }
            return false;
        }
    }
    
    @Data
    public static class OCRResult {
        private boolean success;
        private String text;
        private Double confidence;
        private Double processingTime;
        private String errorMessage;
    }
    
    @Data
    public static class TableOCRResult {
        private boolean success;
        private java.util.List<java.util.List<String>> tableData;
        private String csvData;
        private String htmlTable;
        private Double processingTime;
        private String errorMessage;
    }
    
    @Data
    public static class FormulaOCRResult {
        private boolean success;
        private String latex;
        private Double confidence;
        private Double processingTime;
        private String errorMessage;
    }
}

创建REST控制器:

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

@RestController
@RequestMapping("/api/ocr")
public class OCRController {
    
    @Autowired
    private GLMOCRService ocrService;
    
    @PostMapping("/text")
    public GLMOCRService.OCRResult recognizeText(
            @RequestParam("file") MultipartFile file,
            @RequestParam(value = "language", defaultValue = "zh") String language) throws IOException {
        
        // 保存上传的文件到临时目录
        Path tempFile = Files.createTempFile("ocr_", ".tmp");
        Files.copy(file.getInputStream(), tempFile, StandardCopyOption.REPLACE_EXISTING);
        
        try {
            return ocrService.recognizeText(tempFile.toFile(), language);
        } finally {
            // 清理临时文件
            Files.deleteIfExists(tempFile);
        }
    }
    
    @PostMapping("/table")
    public GLMOCRService.TableOCRResult recognizeTable(
            @RequestParam("file") MultipartFile file,
            @RequestParam(value = "format", defaultValue = "json") String format) throws IOException {
        
        Path tempFile = Files.createTempFile("ocr_", ".tmp");
        Files.copy(file.getInputStream(), tempFile, StandardCopyOption.REPLACE_EXISTING);
        
        try {
            return ocrService.recognizeTable(tempFile.toFile(), format);
        } finally {
            Files.deleteIfExists(tempFile);
        }
    }
    
    @PostMapping("/formula")
    public GLMOCRService.FormulaOCRResult recognizeFormula(
            @RequestParam("file") MultipartFile file) throws IOException {
        
        Path tempFile = Files.createTempFile("ocr_", ".tmp");
        Files.copy(file.getInputStream(), tempFile, StandardCopyOption.REPLACE_EXISTING);
        
        try {
            return ocrService.recognizeFormula(tempFile.toFile());
        } finally {
            Files.deleteIfExists(tempFile);
        }
    }
    
    @GetMapping("/health")
    public String healthCheck() throws IOException {
        boolean healthy = ocrService.healthCheck();
        return healthy ? "OCR服务正常" : "OCR服务异常";
    }
}

配置application.properties:

# 应用配置
server.port=8080
spring.servlet.multipart.max-file-size=10MB
spring.servlet.multipart.max-request-size=10MB

# GLM-OCR API配置
glm.ocr.api.base-url=http://localhost:8000/api/v1

# HTTP客户端配置
http.client.connect-timeout=30000
http.client.socket-timeout=60000

这样,你就有了一个完整的Spring Boot OCR服务,可以供其他Java应用调用。

7. 部署与测试

7.1 启动GLM-OCR API服务

首先,确保GLM-OCR服务已经启动:

# 进入GLM-OCR目录
cd /root/GLM-OCR

# 启动GLM-OCR服务
./start_vllm.sh

等待1-2分钟,直到服务完全启动。然后启动我们的FastAPI服务:

# 进入API项目目录
cd /root/GLM-OCR-api

# 启动API服务
./start_api.sh

服务启动后,你可以通过以下方式访问:

  1. API文档:打开浏览器访问 http://localhost:8000/docs
  2. 健康检查:http://localhost:8000/api/v1/health
  3. 根路径:http://localhost:8000/

7.2 测试API接口

使用curl测试文本识别:

curl -X POST "http://localhost:8000/api/v1/ocr/text" \
  -F "image=@/path/to/your/image.png" \
  -F "language=zh" \
  -F "confidence_threshold=0.5"

使用Python测试:

import requests

# 文本识别
url = "http://localhost:8000/api/v1/ocr/text"
files = {"image": open("test.png", "rb")}
data = {"language": "zh", "confidence_threshold": 0.5}

response = requests.post(url, files=files, data=data)
print(response.json())

# 表格识别
url = "http://localhost:8000/api/v1/ocr/table"
files = {"image": open("table.png", "rb")}
data = {"output_format": "csv"}

response = requests.post(url, files=files, data=data)
print(response.json())

# 公式识别
url = "http://localhost:8000/api/v1/ocr/formula"
files = {"image": open("formula.png", "rb")}

response = requests.post(url, files=files)
print(response.json())

使用Postman测试:

  1. 新建POST请求:http://localhost:8000/api/v1/ocr/text
  2. 选择Body -> form-data
  3. 添加字段:
    • key: image, type: File, value: 选择图片文件
    • key: language, value: zh
    • key: confidence_threshold, value: 0.5
  4. 点击Send

7.3 性能优化建议

在实际生产环境中,你可能需要考虑以下优化:

  1. 连接池:为Java客户端配置HTTP连接池
  2. 超时设置:根据实际需求调整超时时间
  3. 重试机制:添加失败重试逻辑
  4. 异步处理:对于大文件或批量处理,使用异步API
  5. 缓存:对相同的图片或识别结果进行缓存
  6. 负载均衡:如果流量大,可以部署多个API实例
  7. 监控:添加Prometheus监控指标
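
上面第3点提到的重试机制,在Java客户端用一个很小的工具类即可实现。下面是一个指数退避重试的示意(RetryUtil类名为示意,可按需整合进前文的GLMOCRClient):

```java
import java.util.concurrent.Callable;

public class RetryUtil {

    /**
     * 带指数退避的简单重试:第n次失败后等待 baseDelayMs * 2^n 毫秒再试,
     * 全部尝试失败则抛出最后一次的异常
     */
    public static <T> T withRetry(Callable<T> task, int maxAttempts, long baseDelayMs) throws Exception {
        Exception last = null;
        for (int attempt = 0; attempt < maxAttempts; attempt++) {
            try {
                return task.call();
            } catch (Exception e) {
                last = e;
                if (attempt < maxAttempts - 1) {
                    Thread.sleep(baseDelayMs << attempt);  // 1x, 2x, 4x...逐次加倍
                }
            }
        }
        throw last;
    }
}
```

调用方式例如 `RetryUtil.withRetry(() -> recognizeText(imageFile, "zh"), 3, 500)`,表示最多尝试3次,初始间隔500毫秒并逐次加倍。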

7.4 错误处理与日志

我们的API已经包含了基本的错误处理,但在生产环境中,你可能需要:

  1. 更详细的日志:记录请求参数、处理时间、错误堆栈
  2. 告警机制:当服务异常时发送告警
  3. 限流:防止恶意请求
  4. 认证授权:添加API密钥验证
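
对于第4点的API密钥验证,服务端可在FastAPI中通过依赖注入校验请求头;Java客户端侧只需为所有请求统一附加密钥头即可。下面是一个客户端示意(请求头名X-API-Key和环境变量名GLM_OCR_API_KEY均为假设,需与服务端约定一致):

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.time.Duration;

public class AuthedRequestFactory {

    // 密钥从环境变量读取,避免硬编码在代码中;GLM_OCR_API_KEY为示例变量名
    private static final String API_KEY =
            System.getenv().getOrDefault("GLM_OCR_API_KEY", "change-me");

    /** 构建带统一认证头和超时的请求Builder,业务代码在此基础上继续设置方法和Body */
    public static HttpRequest.Builder authed(String url) {
        return HttpRequest.newBuilder()
                .uri(URI.create(url))
                .header("X-API-Key", API_KEY)
                .timeout(Duration.ofSeconds(60));
    }
}
```

例如健康检查可写成 `httpClient.send(AuthedRequestFactory.authed(API_BASE_URL + "/health").GET().build(), BodyHandlers.ofString())`,各OCR接口同理复用。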

8. 总结

通过本文的实践,我们成功地将GLM-OCR封装成了一个标准的RESTful API服务,并提供了完整的Java调用示例。让我们回顾一下关键点:

8.1 实现的核心价值

  1. 解耦与标准化:Java系统不再需要直接调用Python脚本,通过HTTP接口即可完成OCR识别
  2. 易于集成:标准的RESTful API让任何支持HTTP的语言都能轻松集成
  3. 功能完整:支持文本、表格、公式识别,以及批量处理和URL识别
  4. 生产就绪:包含错误处理、日志记录、健康检查等生产环境必需的功能

8.2 主要技术要点

  1. FastAPI框架:利用其自动生成文档、类型验证等特性,快速构建高质量API
  2. 服务封装:将GLM-OCR的Gradio接口封装成易于使用的Python类
  3. 数据模型:使用Pydantic定义清晰的数据结构,确保接口的稳定性
  4. 文件处理:正确处理文件上传、临时文件清理等细节
  5. Java客户端:提供了从简单HttpClient到完整Spring Boot的多种集成方案

8.3 实际应用建议

在实际项目中应用时,你可以根据具体需求进行调整:

  1. 根据业务需求定制:如果只需要文本识别,可以简化API接口
  2. 添加业务逻辑:在OCR识别前后添加预处理和后处理逻辑
  3. 集成到现有系统:将OCR服务作为微服务集成到你的架构中
  4. 性能监控:添加监控指标,了解服务的使用情况和性能表现
  5. 安全加固:根据实际需要添加认证、授权、限流等安全措施

8.4 扩展可能性

这个基础框架还有很多可以扩展的方向:

  1. 支持更多OCR功能:如果GLM-OCR更新了新的功能,可以快速添加到API中
  2. 多模型支持:可以扩展支持其他OCR模型,根据需求动态选择
  3. 异步处理:对于耗时的识别任务,可以提供异步接口
  4. WebSocket支持:实现实时OCR识别
  5. 客户端SDK:为不同语言提供更易用的SDK

通过本文的实践,你现在应该能够将GLM-OCR的能力轻松集成到任何Java系统中了。无论是传统的Spring应用,还是新的微服务架构,都可以通过这个API服务获得强大的OCR识别能力。


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。
